mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	return output ID from llama_batch_ext_add/set
This commit is contained in:
		| @@ -606,7 +606,7 @@ struct common_batch { | |||||||
|     } |     } | ||||||
|     void set_logits_last() { |     void set_logits_last() { | ||||||
|         if (!tokens.empty()) { |         if (!tokens.empty()) { | ||||||
|             llama_batch_ext_set_logits_last(batch.get()); |             llama_batch_ext_set_output_last(batch.get()); | ||||||
|             tokens.back().logits = true; |             tokens.back().logits = true; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -122,7 +122,7 @@ int main(int argc, char ** argv) { | |||||||
|                         llama_batch_ext_add_text(batch, 0, i, &j, 1, false); |                         llama_batch_ext_add_text(batch, 0, i, &j, 1, false); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 llama_batch_ext_set_logits_last(batch); |                 llama_batch_ext_set_output_last(batch); | ||||||
|  |  | ||||||
|                 const auto t_pp_start = ggml_time_us(); |                 const auto t_pp_start = ggml_time_us(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // llama_decode will output logits only for the last token of the prompt |     // llama_decode will output logits only for the last token of the prompt | ||||||
|     llama_batch_ext_set_logits_last(batch); |     llama_batch_ext_set_output_last(batch); | ||||||
|  |  | ||||||
|     if (llama_decode_ext(ctx, batch) != 0) { |     if (llama_decode_ext(ctx, batch) != 0) { | ||||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); |         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||||
|   | |||||||
| @@ -900,7 +900,7 @@ extern "C" { | |||||||
|     // |     // | ||||||
|     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( |     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( | ||||||
|                   llama_token * tokens, |                   llama_token * tokens, | ||||||
|                       int32_t   n_tokens), "use llama_batch_ext API instead"); |                       int32_t   n_tokens), "use llama_batch_ext_init_from_text instead"); | ||||||
|  |  | ||||||
|     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens |     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens | ||||||
|     // Each token can be assigned up to n_seq_max sequence ids |     // Each token can be assigned up to n_seq_max sequence ids | ||||||
| @@ -912,7 +912,7 @@ extern "C" { | |||||||
|     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( |     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( | ||||||
|                     int32_t n_tokens, |                     int32_t n_tokens, | ||||||
|                     int32_t embd, |                     int32_t embd, | ||||||
|                     int32_t n_seq_max), "use llama_batch_ext API instead"); |                     int32_t n_seq_max), "use llama_batch_ext_init instead"); | ||||||
|  |  | ||||||
|     // Frees a batch of tokens allocated with llama_batch_init() |     // Frees a batch of tokens allocated with llama_batch_init() | ||||||
|     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), |     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), | ||||||
| @@ -950,28 +950,32 @@ extern "C" { | |||||||
|  |  | ||||||
|     // Add text tokens to the batch |     // Add text tokens to the batch | ||||||
|     // Return values: |     // Return values: | ||||||
|     //  0 : success |  | ||||||
|     // -1 : not enough space in the batch |     // -1 : not enough space in the batch | ||||||
|     // -2 : embd is already set, cannot add text tokens |     // -2 : embd is already set, cannot add text tokens | ||||||
|  |     // otherwise, returns the output ID | ||||||
|     LLAMA_API int32_t llama_batch_ext_add_text( |     LLAMA_API int32_t llama_batch_ext_add_text( | ||||||
|         struct llama_batch_ext * batch, |         struct llama_batch_ext * batch, | ||||||
|                    llama_token   token, |                    llama_token   token, | ||||||
|                      llama_pos   pos, |                      llama_pos   pos, | ||||||
|             const llama_seq_id * seq_ids, |             const llama_seq_id * seq_ids, | ||||||
|                         size_t   n_seq_ids, |                         size_t   n_seq_ids, | ||||||
|                          float   logits); |                           bool   output); | ||||||
|  |  | ||||||
|     // Set logits for the token in the ith sequence |     // Set output (logits/embeddings) for the token in the ith sequence | ||||||
|     // If pos == -1, logits will be set for the all tokens |     // If pos == -1, output will be set for the all tokens | ||||||
|     // Returns -1 if the token is not in the batch |     // Return values: | ||||||
|     LLAMA_API int32_t llama_batch_ext_set_logits( |     // -1 : the token is not in the batch | ||||||
|  |     // otherwise, returns the output ID | ||||||
|  |     LLAMA_API int32_t llama_batch_ext_set_output( | ||||||
|         struct llama_batch_ext * batch, |         struct llama_batch_ext * batch, | ||||||
|                      llama_pos   pos, |                      llama_pos   pos, | ||||||
|                   llama_seq_id   seq_id); |                   llama_seq_id   seq_id); | ||||||
|  |  | ||||||
|     // Set logits for the last added token |     // Set output (logits/embeddings) for the last added token | ||||||
|     // Returns -1 if there is no tokens in the batch |     // Return values: | ||||||
|     LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); |     // -1 : the batch is empty | ||||||
|  |     // otherwise, returns the output ID | ||||||
|  |     LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Get a "view" from a number of tokens offset |     // Get a "view" from a number of tokens offset | ||||||
|     // Return returned batch must be freed with llama_batch_free() |     // Return returned batch must be freed with llama_batch_free() | ||||||
|   | |||||||
| @@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text( | |||||||
|                  llama_pos   pos, |                  llama_pos   pos, | ||||||
|         const llama_seq_id * seq_ids, |         const llama_seq_id * seq_ids, | ||||||
|                     size_t   n_seq_ids, |                     size_t   n_seq_ids, | ||||||
|                      float   logits) { |                       bool   output) { | ||||||
|     if (batch->n_tokens + 1 > batch->max_tokens) { |     if (batch->n_tokens + 1 > batch->max_tokens) { | ||||||
|         return -1; // llama_batch size exceeded |         return -1; // llama_batch size exceeded | ||||||
|     } |     } | ||||||
|     if (batch->embd) { |     if (batch->embd) { | ||||||
|         return -2; // embd is already set, cannot add text tokens |         return -2; // embd is already set, cannot add text tokens | ||||||
|     } |     } | ||||||
|     batch->token   [batch->n_tokens] = token; |     const int32_t output_id = batch->n_tokens; | ||||||
|     batch->pos     [batch->n_tokens] = pos; |     batch->token   [output_id] = token; | ||||||
|     batch->n_seq_id[batch->n_tokens] = n_seq_ids; |     batch->pos     [output_id] = pos; | ||||||
|  |     batch->n_seq_id[output_id] = n_seq_ids; | ||||||
|     for (size_t j = 0; j < n_seq_ids; j++) { |     for (size_t j = 0; j < n_seq_ids; j++) { | ||||||
|         batch->seq_id[batch->n_tokens][j] = seq_ids[j]; |         batch->seq_id[batch->n_tokens][j] = seq_ids[j]; | ||||||
|     } |     } | ||||||
|     batch->logits  [batch->n_tokens] = logits; |     batch->logits  [output_id] = output; | ||||||
|     batch->n_tokens++; |     batch->n_tokens++; | ||||||
|     return 0; |     return output_id; | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_ext_set_logits( | int32_t llama_batch_ext_set_output( | ||||||
|     struct llama_batch_ext * batch, |     struct llama_batch_ext * batch, | ||||||
|                  llama_pos   pos, |                  llama_pos   pos, | ||||||
|               llama_seq_id   seq_id) { |               llama_seq_id   seq_id) { | ||||||
| @@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits( | |||||||
|                 // found the sequence |                 // found the sequence | ||||||
|                 if (pos == -1 || pos == batch->pos[i]) { |                 if (pos == -1 || pos == batch->pos[i]) { | ||||||
|                     batch->logits[i] = true; |                     batch->logits[i] = true; | ||||||
|                     return 0; |                     return i; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits( | |||||||
|     return -1; // not found |     return -1; // not found | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { | int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) { | ||||||
|     if (batch->n_tokens == 0) { |     if (batch->n_tokens == 0) { | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|     batch->logits[batch->n_tokens - 1] = true; |     const int32_t output_id = batch->n_tokens - 1; | ||||||
|     return 0; |     batch->logits[output_id] = true; | ||||||
|  |     return output_id; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_batch_ext_clear(struct llama_batch_ext * batch) { | void llama_batch_ext_clear(struct llama_batch_ext * batch) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen