mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	context : simplify output counting logic during decode
ggml-ci
This commit is contained in:
		| @@ -306,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 | |||||||
|         batch.seq_id = seq_id.data(); |         batch.seq_id = seq_id.data(); | ||||||
|     } |     } | ||||||
|     if (!batch.logits) { |     if (!batch.logits) { | ||||||
|         logits.resize(batch.n_tokens); |         // by default return the output only for the last token | ||||||
|         logits[logits.size() - 1] = true; |         output.resize(batch.n_tokens); | ||||||
|         batch.logits = logits.data(); |         output[output.size() - 1] = true; | ||||||
|  |         batch.logits = output.data(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -85,7 +85,7 @@ struct llama_batch_allocr { | |||||||
|     std::vector<llama_pos>      pos; |     std::vector<llama_pos>      pos; | ||||||
|     std::vector<int32_t>        n_seq_id; |     std::vector<int32_t>        n_seq_id; | ||||||
|     std::vector<llama_seq_id *> seq_id; |     std::vector<llama_seq_id *> seq_id; | ||||||
|     std::vector<int8_t>         logits; |     std::vector<int8_t>         output; | ||||||
|  |  | ||||||
|     // optionally fulfill the batch returned by llama_batch_get_one |     // optionally fulfill the batch returned by llama_batch_get_one | ||||||
|     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); |     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); | ||||||
|   | |||||||
| @@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|         t_compute_start_us = ggml_time_us(); |         t_compute_start_us = ggml_time_us(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // TODO: this clear of the buffer can easily be forgotten - need something better | ||||||
|     embd_seq.clear(); |     embd_seq.clear(); | ||||||
|  |  | ||||||
|     n_queued_tokens += n_tokens; |     n_queued_tokens += n_tokens; | ||||||
| @@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens | ||||||
|  |     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; | ||||||
|  |  | ||||||
|  |     int64_t n_outputs_all = 0; | ||||||
|  |  | ||||||
|  |     // count outputs | ||||||
|  |     for (uint32_t i = 0; i < n_tokens_all; ++i) { | ||||||
|  |         n_outputs_all += batch.logits[i] != 0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (embd_pooled) { | ||||||
|  |         // require that all tokens are output | ||||||
|  |         if (n_outputs_all != n_tokens_all) { | ||||||
|  |             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n", | ||||||
|  |                     __func__, n_outputs_all, n_tokens_all); | ||||||
|  |             return -1; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     GGML_ASSERT(n_tokens_all <= cparams.n_batch); |     GGML_ASSERT(n_tokens_all <= cparams.n_batch); | ||||||
|  |  | ||||||
|     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); | ||||||
| @@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|     } |     } | ||||||
|     n_queued_tokens += n_tokens_all; |     n_queued_tokens += n_tokens_all; | ||||||
|  |  | ||||||
|     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens |     // TODO: this clear of the buffer can easily be forgotten - need something better | ||||||
|     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; |  | ||||||
|  |  | ||||||
|     embd_seq.clear(); |     embd_seq.clear(); | ||||||
|  |  | ||||||
|     int64_t n_outputs_all = 0; |  | ||||||
|  |  | ||||||
|     // count outputs |  | ||||||
|     if (batch.logits && !embd_pooled) { |  | ||||||
|         for (uint32_t i = 0; i < n_tokens_all; ++i) { |  | ||||||
|             n_outputs_all += batch.logits[i] != 0; |  | ||||||
|         } |  | ||||||
|     } else if (embd_pooled) { |  | ||||||
|         n_outputs_all = n_tokens_all; |  | ||||||
|     } else { |  | ||||||
|         // keep last output only |  | ||||||
|         n_outputs_all = 1; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     bool did_optimize = false; |     bool did_optimize = false; | ||||||
|  |  | ||||||
|     // handle any pending defrags/shifts |     // handle any pending defrags/shifts | ||||||
| @@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||
|     do { |     do { | ||||||
|         const auto & ubatch = mstate->get_ubatch(); |         const auto & ubatch = mstate->get_ubatch(); | ||||||
|  |  | ||||||
|         // count the outputs in this u_batch |         // count the outputs in this ubatch | ||||||
|         { |         { | ||||||
|             int32_t n_outputs_new = 0; |             int32_t n_outputs_new = 0; | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov