mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
This commit is contained in:
		| @@ -1,5 +1,9 @@ | |||||||
| #include "llama-batch.h" | #include "llama-batch.h" | ||||||
|  |  | ||||||
|  | #include "llama-impl.h" | ||||||
|  | #include "llama-cparams.h" | ||||||
|  | #include "llama-vocab.h" | ||||||
|  |  | ||||||
| #include <cassert> | #include <cassert> | ||||||
| #include <cstring> | #include <cstring> | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| @@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple | |||||||
|             ); |             ); | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { | llama_batch_allocr::llama_batch_allocr() = default; | ||||||
|     batch = in_batch; |  | ||||||
|  | bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) { | ||||||
|  |     clear(); | ||||||
|  |  | ||||||
|  |     batch = batch_inp; | ||||||
|  |  | ||||||
|     GGML_ASSERT(batch.n_tokens > 0); |     GGML_ASSERT(batch.n_tokens > 0); | ||||||
|  |  | ||||||
|  |     if (!batch.pos) { | ||||||
|  |         if (batch.seq_id) { | ||||||
|  |             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__); | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (batch.token) { | ||||||
|  |         for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|  |             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { | ||||||
|  |                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (batch.seq_id) { | ||||||
|  |         for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|  |             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { | ||||||
|  |                 if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { | ||||||
|  |                     LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES); | ||||||
|  |                     return false; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     if (!batch.pos) { |     if (!batch.pos) { | ||||||
|         assert(p0 >= 0); |         assert(p0 >= 0); | ||||||
|         pos.resize(batch.n_tokens); |         pos.resize(batch.n_tokens); | ||||||
| @@ -290,6 +327,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 | |||||||
|         } |         } | ||||||
|         batch.pos = pos.data(); |         batch.pos = pos.data(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!batch.n_seq_id) { |     if (!batch.n_seq_id) { | ||||||
|         n_seq_id.resize(batch.n_tokens); |         n_seq_id.resize(batch.n_tokens); | ||||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { |         for (int32_t i = 0; i < batch.n_tokens; i++) { | ||||||
| @@ -297,6 +335,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 | |||||||
|         } |         } | ||||||
|         batch.n_seq_id = n_seq_id.data(); |         batch.n_seq_id = n_seq_id.data(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!batch.seq_id) { |     if (!batch.seq_id) { | ||||||
|         seq_id.resize(batch.n_tokens + 1); |         seq_id.resize(batch.n_tokens + 1); | ||||||
|         seq_id[batch.n_tokens] = NULL; |         seq_id[batch.n_tokens] = NULL; | ||||||
| @@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 | |||||||
|         } |         } | ||||||
|         batch.seq_id = seq_id.data(); |         batch.seq_id = seq_id.data(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!batch.logits) { |     if (!batch.logits) { | ||||||
|         // by default return the output only for the last token |         // by default return the output only for the last token | ||||||
|         output.resize(batch.n_tokens); |         output.resize(batch.n_tokens); | ||||||
|         output[output.size() - 1] = true; |         output[output.size() - 1] = true; | ||||||
|         batch.logits = output.data(); |         batch.logits = output.data(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|  |         n_outputs += batch.logits[i] != 0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const llama_batch & llama_batch_allocr::get_batch() const { | ||||||
|  |     return batch; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | uint32_t llama_batch_allocr::get_n_outputs() const { | ||||||
|  |     return n_outputs; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void llama_batch_allocr::clear() { | ||||||
|  |     n_outputs = 0; | ||||||
|  |  | ||||||
|  |     batch = {}; | ||||||
|  |     pos.clear(); | ||||||
|  |     n_seq_id.clear(); | ||||||
|  |     seq_id.clear(); | ||||||
|  |     output.clear(); | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // | ||||||
|   | |||||||
| @@ -18,8 +18,8 @@ struct llama_ubatch { | |||||||
|     llama_token  *  token;    // [n_tokens] |     llama_token  *  token;    // [n_tokens] | ||||||
|     float        *  embd;     // [n_embd, n_tokens] |     float        *  embd;     // [n_embd, n_tokens] | ||||||
|     llama_pos    *  pos;      // [n_tokens] |     llama_pos    *  pos;      // [n_tokens] | ||||||
|     int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence |     int32_t      *  n_seq_id; // [n_seqs] | ||||||
|     llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id; |     llama_seq_id ** seq_id;   // [n_seqs] | ||||||
|     int8_t       *  output;   // [n_tokens] |     int8_t       *  output;   // [n_tokens] | ||||||
| }; | }; | ||||||
|  |  | ||||||
| @@ -78,15 +78,28 @@ struct llama_sbatch { | |||||||
| }; | }; | ||||||
|  |  | ||||||
| // temporary allocate memory for the input batch if needed | // temporary allocate memory for the input batch if needed | ||||||
| struct llama_batch_allocr { | class llama_batch_allocr { | ||||||
|     struct llama_batch batch; | public: | ||||||
|  |     llama_batch_allocr(); | ||||||
|  |  | ||||||
|  |     // optionally fulfill the batch returned by llama_batch_get_one | ||||||
|  |     bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0); | ||||||
|  |  | ||||||
|  |     const llama_batch & get_batch() const; | ||||||
|  |  | ||||||
|  |     uint32_t get_n_outputs() const; | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     void clear(); | ||||||
|  |  | ||||||
|  |     llama_batch batch; | ||||||
|  |  | ||||||
|  |     uint32_t n_outputs; | ||||||
|  |  | ||||||
|     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id |     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id | ||||||
|  |  | ||||||
|     std::vector<llama_pos>      pos; |     std::vector<llama_pos>      pos; | ||||||
|     std::vector<int32_t>        n_seq_id; |     std::vector<int32_t>        n_seq_id; | ||||||
|     std::vector<llama_seq_id *> seq_id; |     std::vector<llama_seq_id *> seq_id; | ||||||
|     std::vector<int8_t>         output; |     std::vector<int8_t>         output; | ||||||
|  |  | ||||||
|     // optionally fulfill the batch returned by llama_batch_get_one |  | ||||||
|     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); |  | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| #include "llama-context.h" | #include "llama-context.h" | ||||||
|  |  | ||||||
| #include "llama-impl.h" | #include "llama-impl.h" | ||||||
|  | #include "llama-batch.h" | ||||||
| #include "llama-io.h" | #include "llama-io.h" | ||||||
| #include "llama-memory.h" | #include "llama-memory.h" | ||||||
| #include "llama-mmap.h" | #include "llama-mmap.h" | ||||||
| @@ -18,7 +19,8 @@ | |||||||
| llama_context::llama_context( | llama_context::llama_context( | ||||||
|         const llama_model & model, |         const llama_model & model, | ||||||
|               llama_context_params params) : |               llama_context_params params) : | ||||||
|     model(model) { |     model(model), | ||||||
|  |     batch_allocr(std::make_unique<llama_batch_allocr>()) { | ||||||
|     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); | ||||||
|  |  | ||||||
|     t_start_us = model.t_start_us; |     t_start_us = model.t_start_us; | ||||||
| @@ -494,7 +496,7 @@ float * llama_context::get_logits() { | |||||||
| } | } | ||||||
|  |  | ||||||
| float * llama_context::get_logits_ith(int32_t i) { | float * llama_context::get_logits_ith(int32_t i) { | ||||||
|     int32_t j = -1; |     int64_t j = -1; | ||||||
|  |  | ||||||
|     try { |     try { | ||||||
|         if (logits == nullptr) { |         if (logits == nullptr) { | ||||||
| @@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) { | |||||||
|         } |         } | ||||||
|         if (j >= n_outputs) { |         if (j >= n_outputs) { | ||||||
|             // This should not happen |             // This should not happen | ||||||
|             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); |             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return logits + j*model.vocab.n_tokens(); |         return logits + j*model.vocab.n_tokens(); | ||||||
| @@ -536,7 +538,7 @@ float * llama_context::get_embeddings() { | |||||||
| } | } | ||||||
|  |  | ||||||
| float * llama_context::get_embeddings_ith(int32_t i) { | float * llama_context::get_embeddings_ith(int32_t i) { | ||||||
|     int32_t j = -1; |     int64_t j = -1; | ||||||
|  |  | ||||||
|     try { |     try { | ||||||
|         if (embd == nullptr) { |         if (embd == nullptr) { | ||||||
| @@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { | |||||||
|         } |         } | ||||||
|         if (j >= n_outputs) { |         if (j >= n_outputs) { | ||||||
|             // This should not happen |             // This should not happen | ||||||
|             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); |             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return embd + j*model.hparams.n_embd; |         return embd + j*model.hparams.n_embd; | ||||||
| @@ -719,40 +721,27 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, | |||||||
|     return res; |     return res; | ||||||
| } | } | ||||||
|  |  | ||||||
| int llama_context::encode(llama_batch & inp_batch) { | int llama_context::encode(const llama_batch & batch_inp) { | ||||||
|     if (inp_batch.n_tokens == 0) { |     if (batch_inp.n_tokens == 0) { | ||||||
|         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // temporary allocate memory for the input batch if needed |     // temporary allocate memory for the input batch if needed | ||||||
|     // note: during encode, we always pass the full sequence starting from pos = 0 |     // note: during encode, we always pass the full sequence starting from pos = 0 | ||||||
|     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0); |     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) { | ||||||
|  |         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); | ||||||
|     const llama_batch & batch = batch_allocr.batch; |  | ||||||
|     const int32_t n_tokens = batch.n_tokens; |  | ||||||
|  |  | ||||||
|     const auto & hparams = model.hparams; |  | ||||||
|  |  | ||||||
|     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT |  | ||||||
|  |  | ||||||
|     // TODO: move the validation to the llama_batch_allocr |  | ||||||
|     if (batch.token) { |  | ||||||
|         for (int32_t i = 0; i < n_tokens; ++i) { |  | ||||||
|             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { |  | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); |  | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|             if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { |     const llama_batch & batch = batch_allocr->get_batch(); | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); |  | ||||||
|                 throw -1; |     const uint32_t n_tokens = batch.n_tokens; | ||||||
|             } |  | ||||||
|         } |     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot | ||||||
|     GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens"); |     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); | ||||||
|  |  | ||||||
|     if (t_compute_start_us == 0) { |     if (t_compute_start_us == 0) { | ||||||
|         t_compute_start_us = ggml_time_us(); |         t_compute_start_us = ggml_time_us(); | ||||||
| @@ -763,6 +752,8 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|  |  | ||||||
|     n_queued_tokens += n_tokens; |     n_queued_tokens += n_tokens; | ||||||
|  |  | ||||||
|  |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
|     const int64_t n_embd = hparams.n_embd; |     const int64_t n_embd = hparams.n_embd; | ||||||
|  |  | ||||||
|     llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true); |     llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true); | ||||||
| @@ -775,7 +766,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|         return -2; |         return -2; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     for (int32_t i = 0; i < n_tokens; ++i) { |     for (uint32_t i = 0; i < n_tokens; ++i) { | ||||||
|         output_ids[i] = i; |         output_ids[i] = i; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -831,7 +822,8 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|  |  | ||||||
|                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits |                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits | ||||||
|  |  | ||||||
|                     for (int32_t i = 0; i < n_tokens; i++) { |                     // TODO: fix indexing [UBATCH_IDX] | ||||||
|  |                     for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|                         const llama_seq_id seq_id = ubatch.seq_id[i][0]; |                         const llama_seq_id seq_id = ubatch.seq_id[i][0]; | ||||||
|                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { |                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { | ||||||
|                             continue; |                             continue; | ||||||
| @@ -846,6 +838,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|                     auto & embd_seq_out = embd_seq; |                     auto & embd_seq_out = embd_seq; | ||||||
|                     const uint32_t n_cls_out = hparams.n_cls_out; |                     const uint32_t n_cls_out = hparams.n_cls_out; | ||||||
|  |  | ||||||
|  |                     // TODO: fix indexing [UBATCH_IDX] | ||||||
|                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { |                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { | ||||||
|                         const llama_seq_id seq_id = ubatch.seq_id[s][0]; |                         const llama_seq_id seq_id = ubatch.seq_id[s][0]; | ||||||
|                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { |                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { | ||||||
| @@ -878,13 +871,11 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); |         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); | ||||||
|  |  | ||||||
|         // remember the sequence ids used during the encoding - needed for cross attention later |         // remember the sequence ids used during the encoding - needed for cross attention later | ||||||
|         // TODO: the seuqence indexing here is likely not correct in the general case |  | ||||||
|         //       probably works only for split_simple |  | ||||||
|         cross.seq_ids_enc.resize(n_tokens); |         cross.seq_ids_enc.resize(n_tokens); | ||||||
|         for (int32_t i = 0; i < n_tokens; i++) { |         for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|             cross.seq_ids_enc[i].clear(); |             cross.seq_ids_enc[i].clear(); | ||||||
|             for (int s = 0; s < ubatch.n_seq_id[i]; s++) { |             for (int s = 0; s < batch.n_seq_id[i]; s++) { | ||||||
|                 llama_seq_id seq_id = ubatch.seq_id[i][s]; |                 llama_seq_id seq_id = batch.seq_id[i][s]; | ||||||
|                 cross.seq_ids_enc[i].insert(seq_id); |                 cross.seq_ids_enc[i].insert(seq_id); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -893,68 +884,44 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| int llama_context::decode(llama_batch & inp_batch) { | int llama_context::decode(const llama_batch & batch_inp) { | ||||||
|     if (!memory) { |     if (!memory) { | ||||||
|         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); | ||||||
|         return encode(inp_batch); |         return encode(batch_inp); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (inp_batch.n_tokens == 0) { |     if (batch_inp.n_tokens == 0) { | ||||||
|         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!inp_batch.pos) { |     // temporary allocate memory for the input batch if needed | ||||||
|         if (inp_batch.seq_id) { |     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) { | ||||||
|             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__); |         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // temporary allocate memory for the input batch if needed |     const llama_batch & batch = batch_allocr->get_batch(); | ||||||
|     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1); |  | ||||||
|  |  | ||||||
|     const llama_batch & batch = batch_allocr.batch; |  | ||||||
|  |  | ||||||
|     const auto & vocab   = model.vocab; |     const auto & vocab   = model.vocab; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
|     const int32_t n_vocab = vocab.n_tokens(); |     const int32_t n_vocab = vocab.n_tokens(); | ||||||
|  |  | ||||||
|     const int64_t n_tokens_all = batch.n_tokens; |  | ||||||
|     const int64_t n_embd  = hparams.n_embd; |     const int64_t n_embd  = hparams.n_embd; | ||||||
|  |  | ||||||
|  |     const uint32_t n_tokens_all = batch.n_tokens; | ||||||
|  |  | ||||||
|     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT |     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT | ||||||
|  |  | ||||||
|     // TODO: move the validation to the llama_batch_allocr |  | ||||||
|     if (batch.token) { |  | ||||||
|         for (int64_t i = 0; i < n_tokens_all; ++i) { |  | ||||||
|             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { |  | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); |  | ||||||
|                 return -1; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { |  | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); |  | ||||||
|                 return -1; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // this indicates we are doing pooled embedding |     // this indicates we are doing pooled embedding | ||||||
|     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; |     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; | ||||||
|  |  | ||||||
|     int64_t n_outputs_all = 0; |     const uint32_t n_outputs_all = batch_allocr->get_n_outputs(); | ||||||
|  |  | ||||||
|     // count outputs |  | ||||||
|     for (uint32_t i = 0; i < n_tokens_all; ++i) { |  | ||||||
|         n_outputs_all += batch.logits[i] != 0; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if (embd_pooled) { |     if (embd_pooled) { | ||||||
|         // require that all tokens are output |         // require that all tokens are output | ||||||
|         if (n_outputs_all != n_tokens_all) { |         if (n_outputs_all != n_tokens_all) { | ||||||
|             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n", |             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", | ||||||
|                     __func__, n_outputs_all, n_tokens_all); |                     __func__, n_outputs_all, n_tokens_all); | ||||||
|             return -1; |             return -1; | ||||||
|         } |         } | ||||||
| @@ -1024,7 +991,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|  |  | ||||||
|     // reserve output buffer |     // reserve output buffer | ||||||
|     if (output_reserve(n_outputs_all) < n_outputs_all) { |     if (output_reserve(n_outputs_all) < n_outputs_all) { | ||||||
|         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); |         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); | ||||||
|         return -2; |         return -2; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
| @@ -1063,6 +1030,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|                 pos_min[s] = std::numeric_limits<llama_pos>::max(); |                 pos_min[s] = std::numeric_limits<llama_pos>::max(); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             // TODO: fix sequence indexing | ||||||
|             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { | ||||||
|                 const auto & seq_id = ubatch.seq_id[i][0]; |                 const auto & seq_id = ubatch.seq_id[i][0]; | ||||||
|  |  | ||||||
| @@ -1176,14 +1144,14 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|     n_outputs = n_outputs_all; |     n_outputs = n_outputs_all; | ||||||
|  |  | ||||||
|     // set output mappings |     // set output mappings | ||||||
|     { |     if (n_outputs > 0) { | ||||||
|         bool sorted_output = true; |         bool sorted_output = true; | ||||||
|  |  | ||||||
|         auto & out_ids = mstate->out_ids(); |         auto & out_ids = mstate->out_ids(); | ||||||
|  |  | ||||||
|         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); |         GGML_ASSERT(out_ids.size() == (size_t) n_outputs); | ||||||
|  |  | ||||||
|         for (int64_t i = 0; i < n_outputs_all; ++i) { |         for (int64_t i = 0; i < n_outputs; ++i) { | ||||||
|             int64_t out_id = out_ids[i]; |             int64_t out_id = out_ids[i]; | ||||||
|             output_ids[out_id] = i; |             output_ids[out_id] = i; | ||||||
|             if (out_id != i) { |             if (out_id != i) { | ||||||
| @@ -1195,20 +1163,22 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|         // note: this is mostly relevant for recurrent models atm |         // note: this is mostly relevant for recurrent models atm | ||||||
|         if (!sorted_output) { |         if (!sorted_output) { | ||||||
|             const uint32_t n_vocab = model.vocab.n_tokens(); |             const uint32_t n_vocab = model.vocab.n_tokens(); | ||||||
|             const uint32_t n_embd  = model.hparams.n_embd; |             const uint64_t n_embd  = model.hparams.n_embd; | ||||||
|  |  | ||||||
|             GGML_ASSERT((size_t) n_outputs == out_ids.size()); |             GGML_ASSERT((size_t) n_outputs == out_ids.size()); | ||||||
|  |  | ||||||
|             // TODO: is there something more efficient which also minimizes swaps? |             // TODO: is there something more efficient which also minimizes swaps? | ||||||
|             // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |             // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) | ||||||
|             for (int32_t i = 0; i < n_outputs - 1; ++i) { |             for (uint32_t i = 0; i < n_outputs - 1; ++i) { | ||||||
|                 int32_t j_min = i; |                 uint32_t j_min = i; | ||||||
|                 for (int32_t j = i + 1; j < n_outputs; ++j) { |                 for (uint32_t j = i + 1; j < n_outputs; ++j) { | ||||||
|                     if (out_ids[j] < out_ids[j_min]) { |                     if (out_ids[j] < out_ids[j_min]) { | ||||||
|                         j_min = j; |                         j_min = j; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 if (j_min == i) { continue; } |                 if (j_min == i) { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|                 std::swap(out_ids[i], out_ids[j_min]); |                 std::swap(out_ids[i], out_ids[j_min]); | ||||||
|                 if (logits_size > 0) { |                 if (logits_size > 0) { | ||||||
|                     for (uint32_t k = 0; k < n_vocab; k++) { |                     for (uint32_t k = 0; k < n_vocab; k++) { | ||||||
| @@ -1221,8 +1191,10 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             std::fill(output_ids.begin(), output_ids.end(), -1); |             std::fill(output_ids.begin(), output_ids.end(), -1); | ||||||
|             for (int32_t i = 0; i < n_outputs; ++i) { |  | ||||||
|  |             for (uint32_t i = 0; i < n_outputs; ++i) { | ||||||
|                 output_ids[out_ids[i]] = i; |                 output_ids[out_ids[i]] = i; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -1242,7 +1214,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
| // output | // output | ||||||
| // | // | ||||||
|  |  | ||||||
| int32_t llama_context::output_reserve(int32_t n_outputs) { | uint32_t llama_context::output_reserve(int32_t n_outputs) { | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|     const auto & vocab   = model.vocab; |     const auto & vocab   = model.vocab; | ||||||
|  |  | ||||||
| @@ -1309,7 +1281,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { | |||||||
|     std::fill(output_ids.begin(), output_ids.end(), -1); |     std::fill(output_ids.begin(), output_ids.end(), -1); | ||||||
|  |  | ||||||
|     this->n_outputs = 0; |     this->n_outputs = 0; | ||||||
|     this->n_outputs_max = n_outputs_max; |  | ||||||
|  |  | ||||||
|     return n_outputs_max; |     return n_outputs_max; | ||||||
| } | } | ||||||
| @@ -1800,14 +1771,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { | |||||||
|  |  | ||||||
|         std::vector<int32_t> w_output_pos; |         std::vector<int32_t> w_output_pos; | ||||||
|  |  | ||||||
|         GGML_ASSERT(n_outputs <= n_outputs_max); |  | ||||||
|  |  | ||||||
|         w_output_pos.resize(n_outputs); |         w_output_pos.resize(n_outputs); | ||||||
|  |  | ||||||
|         // build a more compact representation of the output ids |         // build a more compact representation of the output ids | ||||||
|         for (size_t i = 0; i < n_batch(); ++i) { |         for (size_t i = 0; i < n_batch(); ++i) { | ||||||
|             // map an output id to a position in the batch |             // map an output id to a position in the batch | ||||||
|             int32_t pos = output_ids[i]; |             int64_t pos = output_ids[i]; | ||||||
|             if (pos >= 0) { |             if (pos >= 0) { | ||||||
|                 GGML_ASSERT(pos < n_outputs); |                 GGML_ASSERT(pos < n_outputs); | ||||||
|                 w_output_pos[pos] = i; |                 w_output_pos[pos] = i; | ||||||
| @@ -2082,7 +2051,7 @@ void llama_context::opt_epoch_iter( | |||||||
|  |  | ||||||
|         embd_seq.clear(); |         embd_seq.clear(); | ||||||
|  |  | ||||||
|         int64_t n_outputs_all = n_tokens_all; |         uint32_t n_outputs_all = n_tokens_all; | ||||||
|  |  | ||||||
|         auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled); |         auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled); | ||||||
|         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { | ||||||
| @@ -2092,7 +2061,7 @@ void llama_context::opt_epoch_iter( | |||||||
|  |  | ||||||
|         // reserve output buffer |         // reserve output buffer | ||||||
|         if (output_reserve(n_outputs_all) < n_outputs_all) { |         if (output_reserve(n_outputs_all) < n_outputs_all) { | ||||||
|             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); |             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); | ||||||
|             GGML_ABORT("TODO: handle this error"); |             GGML_ABORT("TODO: handle this error"); | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
| #include "llama.h" | #include "llama.h" | ||||||
| #include "llama-batch.h" |  | ||||||
| #include "llama-cparams.h" | #include "llama-cparams.h" | ||||||
| #include "llama-graph.h" | #include "llama-graph.h" | ||||||
| #include "llama-adapter.h" | #include "llama-adapter.h" | ||||||
| @@ -13,6 +12,7 @@ | |||||||
| #include <vector> | #include <vector> | ||||||
|  |  | ||||||
| struct llama_model; | struct llama_model; | ||||||
|  | class llama_batch_allocr; | ||||||
|  |  | ||||||
| class llama_io_read_i; | class llama_io_read_i; | ||||||
| class llama_io_write_i; | class llama_io_write_i; | ||||||
| @@ -102,8 +102,8 @@ struct llama_context { | |||||||
|             llama_memory_state_i * mstate, |             llama_memory_state_i * mstate, | ||||||
|                      ggml_status & ret); |                      ggml_status & ret); | ||||||
|  |  | ||||||
|     int encode(llama_batch & inp_batch); |     int encode(const llama_batch & batch_inp); | ||||||
|     int decode(llama_batch & inp_batch); |     int decode(const llama_batch & batch_inp); | ||||||
|  |  | ||||||
|     // |     // | ||||||
|     // state save/load |     // state save/load | ||||||
| @@ -181,7 +181,7 @@ private: | |||||||
|  |  | ||||||
|     // Make sure enough space is available for outputs. |     // Make sure enough space is available for outputs. | ||||||
|     // Returns max number of outputs for which space was reserved. |     // Returns max number of outputs for which space was reserved. | ||||||
|     int32_t output_reserve(int32_t n_outputs); |     uint32_t output_reserve(int32_t n_outputs); | ||||||
|  |  | ||||||
|     // |     // | ||||||
|     // graph |     // graph | ||||||
| @@ -246,8 +246,10 @@ private: | |||||||
|     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE |     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE | ||||||
|     std::map<llama_seq_id, std::vector<float>> embd_seq; |     std::map<llama_seq_id, std::vector<float>> embd_seq; | ||||||
|  |  | ||||||
|     int32_t n_outputs     = 0; // number of actually-used outputs in the current ubatch or last logical batch |     // reuse the batch_allocr to avoid unnecessary memory allocations | ||||||
|     int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers |     std::unique_ptr<llama_batch_allocr> batch_allocr; | ||||||
|  |  | ||||||
|  |     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch | ||||||
|  |  | ||||||
|     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers |     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers | ||||||
|  |  | ||||||
|   | |||||||
| @@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { | |||||||
|  |  | ||||||
|         std::vector<uint64_t> sum(n_tokens, 0); |         std::vector<uint64_t> sum(n_tokens, 0); | ||||||
|  |  | ||||||
|  |         // TODO: fix indexing [UBATCH_IDX] | ||||||
|         for (int s = 0; s < n_seqs; ++s) { |         for (int s = 0; s < n_seqs; ++s) { | ||||||
|             const llama_seq_id seq_id = ubatch->seq_id[s][0]; |             const llama_seq_id seq_id = ubatch->seq_id[s][0]; | ||||||
|  |  | ||||||
| @@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         // TODO: fix indexing [UBATCH_IDX] | ||||||
|         for (int s = 0; s < n_seqs; ++s) { |         for (int s = 0; s < n_seqs; ++s) { | ||||||
|             const llama_seq_id seq_id = ubatch->seq_id[s][0]; |             const llama_seq_id seq_id = ubatch->seq_id[s][0]; | ||||||
|  |  | ||||||
| @@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | |||||||
|         uint32_t * data = (uint32_t *) cls->data; |         uint32_t * data = (uint32_t *) cls->data; | ||||||
|         memset(cls->data, 0, n_tokens * ggml_element_size(cls)); |         memset(cls->data, 0, n_tokens * ggml_element_size(cls)); | ||||||
|  |  | ||||||
|  |         // TODO: fix indexing [UBATCH_IDX] | ||||||
|         for (int s = 0; s < n_seqs; ++s) { |         for (int s = 0; s < n_seqs; ++s) { | ||||||
|             const llama_seq_id seq_id = ubatch->seq_id[s][0]; |             const llama_seq_id seq_id = ubatch->seq_id[s][0]; | ||||||
|  |  | ||||||
| @@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | |||||||
|         std::vector<int> last_pos(n_tokens, -1); |         std::vector<int> last_pos(n_tokens, -1); | ||||||
|         std::vector<int> last_row(n_tokens, -1); |         std::vector<int> last_row(n_tokens, -1); | ||||||
|  |  | ||||||
|  |         // TODO: fix indexing [UBATCH_IDX] | ||||||
|         for (int s = 0; s < n_seqs; ++s) { |         for (int s = 0; s < n_seqs; ++s) { | ||||||
|             const llama_seq_id seq_id = ubatch->seq_id[s][0]; |             const llama_seq_id seq_id = ubatch->seq_id[s][0]; | ||||||
|  |  | ||||||
| @@ -283,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { | |||||||
|                                 const int32_t ti = s0*n_seq_tokens + i; |                                 const int32_t ti = s0*n_seq_tokens + i; | ||||||
|                                 float f = -INFINITY; |                                 float f = -INFINITY; | ||||||
|  |  | ||||||
|  |                                 // TODO: fix indexing [UBATCH_IDX] | ||||||
|                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { |                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { | ||||||
|                                     if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { |                                     if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { | ||||||
|                                         if (hparams.use_alibi) { |                                         if (hparams.use_alibi) { | ||||||
| @@ -322,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { | |||||||
|                                 const int32_t ti = s0*n_seq_tokens + i; |                                 const int32_t ti = s0*n_seq_tokens + i; | ||||||
|                                 float f = -INFINITY; |                                 float f = -INFINITY; | ||||||
|  |  | ||||||
|  |                                 // TODO: fix indexing [UBATCH_IDX] | ||||||
|                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { |                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { | ||||||
|                                     if (ubatch->seq_id[s0][s] == seq_id) { |                                     if (ubatch->seq_id[s0][s] == seq_id) { | ||||||
|                                         if (hparams.use_alibi) { |                                         if (hparams.use_alibi) { | ||||||
| @@ -377,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { | |||||||
|             for (int j = 0; j < n_tokens; ++j) { |             for (int j = 0; j < n_tokens; ++j) { | ||||||
|                 for (int i = 0; i < n_enc; ++i) { |                 for (int i = 0; i < n_enc; ++i) { | ||||||
|                     float f = -INFINITY; |                     float f = -INFINITY; | ||||||
|  |                     // TODO: fix indexing [UBATCH_IDX] | ||||||
|                     for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { |                     for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { | ||||||
|                         const llama_seq_id seq_id = ubatch->seq_id[j][s]; |                         const llama_seq_id seq_id = ubatch->seq_id[j][s]; | ||||||
|                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { |                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { | ||||||
|   | |||||||
| @@ -378,7 +378,7 @@ struct llm_graph_params { | |||||||
|     const llama_memory_state_i * mstate; |     const llama_memory_state_i * mstate; | ||||||
|     const llama_cross          * cross; |     const llama_cross          * cross; | ||||||
|  |  | ||||||
|     int32_t n_outputs; |     uint32_t n_outputs; | ||||||
|  |  | ||||||
|     const llm_graph_cb & cb; |     const llm_graph_cb & cb; | ||||||
| }; | }; | ||||||
| @@ -412,8 +412,8 @@ struct llm_graph_context { | |||||||
|     const float norm_eps; |     const float norm_eps; | ||||||
|     const float norm_rms_eps; |     const float norm_rms_eps; | ||||||
|  |  | ||||||
|     const int32_t n_tokens; |     const int64_t n_tokens; | ||||||
|     const int32_t n_outputs; |     const int64_t n_outputs; | ||||||
|     const int32_t n_ctx_orig; // yarn |     const int32_t n_ctx_orig; // yarn | ||||||
|  |  | ||||||
|     const enum llama_pooling_type pooling_type; |     const enum llama_pooling_type pooling_type; | ||||||
|   | |||||||
| @@ -674,6 +674,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch | |||||||
|  |  | ||||||
|             cells.pos_set(head_cur + idx, ubatch.pos[idx]); |             cells.pos_set(head_cur + idx, ubatch.pos[idx]); | ||||||
|  |  | ||||||
|  |             // TODO: fix indexing [UBATCH_IDX] | ||||||
|             for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) { |             for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) { | ||||||
|                 cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]); |                 cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]); | ||||||
|             } |             } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov