mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	return output ID from llama_batch_ext_add/set
This commit is contained in:
		| @@ -606,7 +606,7 @@ struct common_batch { | ||||
|     } | ||||
|     void set_logits_last() { | ||||
|         if (!tokens.empty()) { | ||||
|             llama_batch_ext_set_logits_last(batch.get()); | ||||
|             llama_batch_ext_set_output_last(batch.get()); | ||||
|             tokens.back().logits = true; | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -122,7 +122,7 @@ int main(int argc, char ** argv) { | ||||
|                         llama_batch_ext_add_text(batch, 0, i, &j, 1, false); | ||||
|                     } | ||||
|                 } | ||||
|                 llama_batch_ext_set_logits_last(batch); | ||||
|                 llama_batch_ext_set_output_last(batch); | ||||
|  | ||||
|                 const auto t_pp_start = ggml_time_us(); | ||||
|  | ||||
|   | ||||
| @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { | ||||
|     } | ||||
|  | ||||
|     // llama_decode will output logits only for the last token of the prompt | ||||
|     llama_batch_ext_set_logits_last(batch); | ||||
|     llama_batch_ext_set_output_last(batch); | ||||
|  | ||||
|     if (llama_decode_ext(ctx, batch) != 0) { | ||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|   | ||||
| @@ -900,7 +900,7 @@ extern "C" { | ||||
|     // | ||||
|     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( | ||||
|                   llama_token * tokens, | ||||
|                       int32_t   n_tokens), "use llama_batch_ext API instead"); | ||||
|                       int32_t   n_tokens), "use llama_batch_ext_init_from_text instead"); | ||||
|  | ||||
|     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens | ||||
|     // Each token can be assigned up to n_seq_max sequence ids | ||||
| @@ -912,7 +912,7 @@ extern "C" { | ||||
|     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( | ||||
|                     int32_t n_tokens, | ||||
|                     int32_t embd, | ||||
|                     int32_t n_seq_max), "use llama_batch_ext API instead"); | ||||
|                     int32_t n_seq_max), "use llama_batch_ext_init instead"); | ||||
|  | ||||
|     // Frees a batch of tokens allocated with llama_batch_init() | ||||
|     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), | ||||
| @@ -950,28 +950,32 @@ extern "C" { | ||||
|  | ||||
|     // Add text tokens to the batch | ||||
|     // Return values: | ||||
|     //  0 : success | ||||
|     // -1 : not enough space in the batch | ||||
|     // -2 : embd is already set, cannot add text tokens | ||||
|     // otherwise, returns the output ID | ||||
|     LLAMA_API int32_t llama_batch_ext_add_text( | ||||
|         struct llama_batch_ext * batch, | ||||
|                    llama_token   token, | ||||
|                      llama_pos   pos, | ||||
|             const llama_seq_id * seq_ids, | ||||
|                         size_t   n_seq_ids, | ||||
|                          float   logits); | ||||
|                           bool   output); | ||||
|  | ||||
|     // Set logits for the token in the ith sequence | ||||
|     // If pos == -1, logits will be set for the all tokens | ||||
|     // Returns -1 if the token is not in the batch | ||||
|     LLAMA_API int32_t llama_batch_ext_set_logits( | ||||
|     // Set output (logits/embeddings) for the token in the ith sequence | ||||
|     // If pos == -1, output will be set for the all tokens | ||||
|     // Return values: | ||||
|     // -1 : the token is not in the batch | ||||
|     // otherwise, returns the output ID | ||||
|     LLAMA_API int32_t llama_batch_ext_set_output( | ||||
|         struct llama_batch_ext * batch, | ||||
|                      llama_pos   pos, | ||||
|                   llama_seq_id   seq_id); | ||||
|  | ||||
|     // Set logits for the last added token | ||||
|     // Returns -1 if there is no tokens in the batch | ||||
|     LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); | ||||
|     // Set output (logits/embeddings) for the last added token | ||||
|     // Return values: | ||||
|     // -1 : the batch is empty | ||||
|     // otherwise, returns the output ID | ||||
|     LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch); | ||||
|  | ||||
|     // Get a "view" from a number of tokens offset | ||||
|     // Return returned batch must be freed with llama_batch_free() | ||||
|   | ||||
| @@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text( | ||||
|                  llama_pos   pos, | ||||
|         const llama_seq_id * seq_ids, | ||||
|                     size_t   n_seq_ids, | ||||
|                      float   logits) { | ||||
|                       bool   output) { | ||||
|     if (batch->n_tokens + 1 > batch->max_tokens) { | ||||
|         return -1; // llama_batch size exceeded | ||||
|     } | ||||
|     if (batch->embd) { | ||||
|         return -2; // embd is already set, cannot add text tokens | ||||
|     } | ||||
|     batch->token   [batch->n_tokens] = token; | ||||
|     batch->pos     [batch->n_tokens] = pos; | ||||
|     batch->n_seq_id[batch->n_tokens] = n_seq_ids; | ||||
|     const int32_t output_id = batch->n_tokens; | ||||
|     batch->token   [output_id] = token; | ||||
|     batch->pos     [output_id] = pos; | ||||
|     batch->n_seq_id[output_id] = n_seq_ids; | ||||
|     for (size_t j = 0; j < n_seq_ids; j++) { | ||||
|         batch->seq_id[batch->n_tokens][j] = seq_ids[j]; | ||||
|     } | ||||
|     batch->logits  [batch->n_tokens] = logits; | ||||
|     batch->logits  [output_id] = output; | ||||
|     batch->n_tokens++; | ||||
|     return 0; | ||||
|     return output_id; | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_ext_set_logits( | ||||
| int32_t llama_batch_ext_set_output( | ||||
|     struct llama_batch_ext * batch, | ||||
|                  llama_pos   pos, | ||||
|               llama_seq_id   seq_id) { | ||||
| @@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits( | ||||
|                 // found the sequence | ||||
|                 if (pos == -1 || pos == batch->pos[i]) { | ||||
|                     batch->logits[i] = true; | ||||
|                     return 0; | ||||
|                     return i; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits( | ||||
|     return -1; // not found | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { | ||||
| int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) { | ||||
|     if (batch->n_tokens == 0) { | ||||
|         return -1; | ||||
|     } | ||||
|     batch->logits[batch->n_tokens - 1] = true; | ||||
|     return 0; | ||||
|     const int32_t output_id = batch->n_tokens - 1; | ||||
|     batch->logits[output_id] = true; | ||||
|     return output_id; | ||||
| } | ||||
|  | ||||
| void llama_batch_ext_clear(struct llama_batch_ext * batch) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen