mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	batch : auto-gen positions + verify multi-sequence input (#14177)
* batch : verify multi-sequence input batches ggml-ci * cont : auto-gen positions + verify multi-seq input ggml-ci * cont : first print debug info, then perform validation ggml-ci * cont : fix position auto-gen + add comments ggml-ci
This commit is contained in:
		| @@ -243,14 +243,14 @@ extern "C" { | ||||
|  | ||||
|     typedef bool (*llama_progress_callback)(float progress, void * user_data); | ||||
|  | ||||
|     // Input data for llama_decode | ||||
|     // Input data for llama_encode/llama_decode | ||||
|     // A llama_batch object can contain input about one or many sequences | ||||
|     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | ||||
|     // | ||||
|     // - token  : the token ids of the input (used when embd is NULL) | ||||
|     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) | ||||
|     // - pos    : the positions of the respective token in the sequence | ||||
|     //            (if set to NULL, the token position will be tracked automatically by llama_decode) | ||||
|     //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) | ||||
|     // - seq_id : the sequence to which the respective token belongs | ||||
|     //            (if set to NULL, the sequence ID will be assumed to be 0) | ||||
|     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| #include "llama-impl.h" | ||||
| #include "llama-cparams.h" | ||||
| #include "llama-vocab.h" | ||||
| #include "llama-memory.h" | ||||
|  | ||||
| #include <cassert> | ||||
| #include <cstring> | ||||
| @@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple | ||||
| llama_batch_allocr::llama_batch_allocr() { | ||||
|     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG"); | ||||
|     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0; | ||||
|  | ||||
|     seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES); | ||||
|     seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES); | ||||
|     for (auto & cur : seq_cpl) { | ||||
|         cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES); | ||||
|     } | ||||
| } | ||||
|  | ||||
| bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) { | ||||
| bool llama_batch_allocr::init( | ||||
|         const llama_batch & batch_inp, | ||||
|         const llama_vocab & vocab, | ||||
|         const llama_memory_i * memory) { | ||||
|     clear(); | ||||
|  | ||||
|     batch = batch_inp; | ||||
|  | ||||
|     GGML_ASSERT(batch.n_tokens > 0); | ||||
|  | ||||
|     if (!batch.pos) { | ||||
|         if (batch.seq_id) { | ||||
|             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__); | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|     // | ||||
|     // validate input batch | ||||
|     // | ||||
|  | ||||
|     if (batch.token) { | ||||
|         for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||
| @@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (!batch.pos) { | ||||
|         assert(p0 >= 0); | ||||
|         pos.resize(batch.n_tokens); | ||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { | ||||
|             pos[i] = p0 + i; | ||||
|         } | ||||
|         batch.pos = pos.data(); | ||||
|     } | ||||
|     // | ||||
|     // auto-generate missing fields | ||||
|     // | ||||
|  | ||||
|     if (!batch.n_seq_id) { | ||||
|         n_seq_id.resize(batch.n_tokens); | ||||
| @@ -349,6 +351,32 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & | ||||
|         batch.seq_id = seq_id.data(); | ||||
|     } | ||||
|  | ||||
|     if (!batch.pos) { | ||||
|         pos.resize(batch.n_tokens); | ||||
|  | ||||
|         // initialize the starting position for each sequence based on the positions in the memory | ||||
|         llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES]; | ||||
|         for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { | ||||
|             if (!memory) { | ||||
|                 p0[s] = 0; | ||||
|             } else { | ||||
|                 p0[s] = memory->seq_pos_max(s) + 1; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { | ||||
|             const llama_seq_id seq_id = batch.seq_id[i][0]; | ||||
|  | ||||
|             pos[i] = p0[seq_id]; | ||||
|  | ||||
|             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { | ||||
|                 p0[batch.seq_id[i][s]] = pos[i] + 1; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         batch.pos = pos.data(); | ||||
|     } | ||||
|  | ||||
|     if (!batch.logits) { | ||||
|         // by default return the output only for the last token | ||||
|         output.resize(batch.n_tokens); | ||||
| @@ -356,13 +384,36 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & | ||||
|         batch.logits = output.data(); | ||||
|     } | ||||
|  | ||||
|     // | ||||
|     // compute stats | ||||
|     // | ||||
|  | ||||
|     for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||
|         n_outputs += batch.logits[i] != 0; | ||||
|     } | ||||
|  | ||||
|     // determine coupled sequences | ||||
|     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them | ||||
|     for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||
|         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { | ||||
|             seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]); | ||||
|  | ||||
|             if (s > 0) { | ||||
|                 const llama_seq_id s0 = batch.seq_id[i][0]; | ||||
|                 const llama_seq_id s1 = batch.seq_id[i][s]; | ||||
|  | ||||
|                 // mark that sequence s1 is coupled to s0 | ||||
|                 seq_cpl[s1][s0] = true; | ||||
|  | ||||
|                 // note: the other way around is not necessary for now | ||||
|                 //seq_cpl[s0][s1] = true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (debug > 0) { | ||||
|         LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0); | ||||
|         LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__, batch.n_tokens); | ||||
|         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__); | ||||
|         LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__,          batch.n_tokens); | ||||
|         LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token); | ||||
|         LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd); | ||||
|         LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos); | ||||
| @@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & | ||||
|                         batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]); | ||||
|             } | ||||
|             LLAMA_LOG_DEBUG("%s:   ]\n", __func__); | ||||
|  | ||||
|             LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__); | ||||
|             for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) { | ||||
|                 if (seq_pos[s0].empty()) { | ||||
|                     continue; | ||||
|                 } | ||||
|  | ||||
|                 std::stringstream ss; | ||||
|                 for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) { | ||||
|                     if (seq_cpl[s0][s1]) { | ||||
|                         ss << s1 << " "; | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n", | ||||
|                         __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str()); | ||||
|             } | ||||
|             LLAMA_LOG_DEBUG("%s:   ]\n", __func__); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // | ||||
|     // consistency checks | ||||
|     // | ||||
|  | ||||
|     for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { | ||||
|         if (seq_pos[s].empty()) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { | ||||
|             LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { | ||||
|             LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s); | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (memory) { | ||||
|         for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) { | ||||
|             for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) { | ||||
|                 if (seq_cpl[s0][s1]) { | ||||
|                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || | ||||
|                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { | ||||
|                         LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1); | ||||
|                         return false; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const { | ||||
|     return n_outputs; | ||||
| } | ||||
|  | ||||
| llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const { | ||||
|     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin(); | ||||
| } | ||||
|  | ||||
| llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { | ||||
|     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin(); | ||||
| } | ||||
|  | ||||
| void llama_batch_allocr::clear() { | ||||
|     n_outputs = 0; | ||||
|  | ||||
| @@ -426,6 +537,14 @@ void llama_batch_allocr::clear() { | ||||
|     n_seq_id.clear(); | ||||
|     seq_id.clear(); | ||||
|     output.clear(); | ||||
|  | ||||
|     for (auto & cur : seq_pos) { | ||||
|         cur.clear(); | ||||
|     } | ||||
|  | ||||
|     for (auto & cur : seq_cpl) { | ||||
|         std::fill(cur.begin(), cur.end(), false); | ||||
|     } | ||||
| } | ||||
|  | ||||
| // | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
|  | ||||
| #include <array> | ||||
| #include <vector> | ||||
| #include <set> | ||||
|  | ||||
| // very similar to llama_batch, | ||||
| // but has more metadata about sequences | ||||
| @@ -77,18 +78,25 @@ struct llama_sbatch { | ||||
|     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false); | ||||
| }; | ||||
|  | ||||
| // temporary allocate memory for the input batch if needed | ||||
| // a helper for sanitizing and fulfilling a batch | ||||
| class llama_batch_allocr { | ||||
| public: | ||||
|     llama_batch_allocr(); | ||||
|  | ||||
|     // optionally fulfill the batch returned by llama_batch_get_one | ||||
|     bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0); | ||||
|     // sanitize and auto-gen missing data in the input batch | ||||
|     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions | ||||
|     bool init( | ||||
|             const llama_batch & batch_inp, | ||||
|             const llama_vocab & vocab, | ||||
|             const llama_memory_i * memory); | ||||
|  | ||||
|     const llama_batch & get_batch() const; | ||||
|  | ||||
|     uint32_t get_n_outputs() const; | ||||
|  | ||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const; | ||||
|     llama_pos seq_pos_max(llama_seq_id seq_id) const; | ||||
|  | ||||
| private: | ||||
|     void clear(); | ||||
|  | ||||
| @@ -103,5 +111,8 @@ private: | ||||
|     std::vector<llama_seq_id *> seq_id; | ||||
|     std::vector<int8_t>         output; | ||||
|  | ||||
|     std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s | ||||
|     std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 | ||||
|  | ||||
|     int debug; | ||||
| }; | ||||
|   | ||||
| @@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) { | ||||
|         return -1; | ||||
|     } | ||||
|  | ||||
|     // temporary allocate memory for the input batch if needed | ||||
|     // note: during encode, we always pass the full sequence starting from pos = 0 | ||||
|     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) { | ||||
|     if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); | ||||
|         return -1; | ||||
|     } | ||||
| @@ -895,8 +894,7 @@ int llama_context::decode(const llama_batch & batch_inp) { | ||||
|         return -1; | ||||
|     } | ||||
|  | ||||
|     // temporary allocate memory for the input batch if needed | ||||
|     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) { | ||||
|     if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); | ||||
|         return -1; | ||||
|     } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
|  | ||||
| #include <cstdint> | ||||
|  | ||||
| // TODO: rename to something shorter | ||||
| #define LLAMA_MAX_PARALLEL_SEQUENCES 64 | ||||
|  | ||||
| struct llama_cparams { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov