batch : remove logits_all flag

ggml-ci

Repository: https://github.com/ggml-org/llama.cpp.git
Author: Georgi Gerganov
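With logits_all removed, whether a token produces logits is determined by the per-token llama_batch::logits flags that the caller sets, rather than an internal override. Below is a minimal caller-side sketch (not part of this commit) assuming the public llama_batch_init / llama_decode / llama_batch_free API from llama.h; decode_prompt is a hypothetical helper name.

#include "llama.h"

#include <vector>

// Hypothetical helper: decode a prompt and request logits only for its last token.
static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), /* embd */ 0, /* n_seq_max */ 1);

    batch.n_tokens = (int32_t) tokens.size();

    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = 0; // no output requested for this token
    }

    batch.logits[batch.n_tokens - 1] = 1; // request logits for the last token only

    const int ret = llama_decode(ctx, batch);

    llama_batch_free(batch);

    return ret;
}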
@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
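For reference, the removed branch above marked every token of the ubatch as an output whenever logits_all was set. A caller that wants that behaviour now expresses it directly in the batch, e.g. (hedged sketch; batch is assumed to be an already-populated llama_batch):

// Request an output for every token (e.g. perplexity-style evaluation),
// instead of relying on the removed logits_all flag.
for (int32_t i = 0; i < batch.n_tokens; ++i) {
    batch.logits[i] = 1;
}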
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed
@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
@@ -32,8 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -34,8 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
+            bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {
@@ -59,8 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;