mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	move to llama_batch_ext
This commit is contained in:
		| @@ -1610,20 +1610,29 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons | |||||||
| // Batch utils | // Batch utils | ||||||
| // | // | ||||||
|  |  | ||||||
| void common_batch_clear(struct llama_batch * batch) { | // DEPRECATED | ||||||
|     llama_batch_clear(batch); | void common_batch_clear(struct llama_batch & batch) { | ||||||
|  |     batch.n_tokens = 0; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // DEPRECATED | ||||||
| void common_batch_add( | void common_batch_add( | ||||||
|                  struct llama_batch * batch, |                  struct llama_batch & batch, | ||||||
|                         llama_token   id, |                         llama_token   id, | ||||||
|                           llama_pos   pos, |                           llama_pos   pos, | ||||||
|     const std::vector<llama_seq_id> & seq_ids, |     const std::vector<llama_seq_id> & seq_ids, | ||||||
|                                bool   logits) { |                                bool   logits) { | ||||||
|     int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits); |     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); | ||||||
|     if (res == -1) { |  | ||||||
|         LOG_ERR("%s: llama_batch size exceeded\n", __func__); |     batch.token   [batch.n_tokens] = id; | ||||||
|  |     batch.pos     [batch.n_tokens] = pos; | ||||||
|  |     batch.n_seq_id[batch.n_tokens] = seq_ids.size(); | ||||||
|  |     for (size_t i = 0; i < seq_ids.size(); ++i) { | ||||||
|  |         batch.seq_id[batch.n_tokens][i] = seq_ids[i]; | ||||||
|     } |     } | ||||||
|  |     batch.logits  [batch.n_tokens] = logits; | ||||||
|  |  | ||||||
|  |     batch.n_tokens++; | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // | ||||||
|   | |||||||
| @@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap | |||||||
| // Batch utils | // Batch utils | ||||||
| // | // | ||||||
|  |  | ||||||
| void common_batch_clear(struct llama_batch * batch); | // DEPRECATED | ||||||
|  | void common_batch_clear(struct llama_batch & batch); | ||||||
|  |  | ||||||
|  | // DEPRECATED | ||||||
| void common_batch_add( | void common_batch_add( | ||||||
|                  struct llama_batch * batch, |                  struct llama_batch & batch, | ||||||
|                         llama_token   id, |                         llama_token   id, | ||||||
|                           llama_pos   pos, |                           llama_pos   pos, | ||||||
|     const std::vector<llama_seq_id> & seq_ids, |     const std::vector<llama_seq_id> & seq_ids, | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ struct common_speculative { | |||||||
|     struct llama_context * ctx; |     struct llama_context * ctx; | ||||||
|     struct common_sampler * smpl; |     struct common_sampler * smpl; | ||||||
|  |  | ||||||
|     llama_batch * batch; |     llama_batch batch; | ||||||
|     llama_tokens prompt; |     llama_tokens prompt; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( | |||||||
|     auto * result = new common_speculative { |     auto * result = new common_speculative { | ||||||
|         /* .ctx    = */ ctx_dft, |         /* .ctx    = */ ctx_dft, | ||||||
|         /* .smpl   = */ nullptr, |         /* .smpl   = */ nullptr, | ||||||
|         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1), |         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), | ||||||
|         /* .prompt = */ {}, |         /* .prompt = */ {}, | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
| @@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft( | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // we should rarely end-up here during normal decoding |     // we should rarely end-up here during normal decoding | ||||||
|     if (llama_batch_get_n_tokens(batch) > 0) { |     if (batch.n_tokens > 0) { | ||||||
|         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); |         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); | ||||||
|  |  | ||||||
|         llama_decode(ctx, batch); |         llama_decode(ctx, batch); | ||||||
|   | |||||||
| @@ -24,12 +24,12 @@ struct llama_adapter_lora_deleter { | |||||||
|     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } |     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct llama_batch_deleter { | struct llama_batch_ext_deleter { | ||||||
|     void operator()(llama_batch * batch) { llama_batch_free(batch); } |     void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; | typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; | ||||||
| typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; | typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; | ||||||
| typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; | typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; | ||||||
| typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr; | typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr; | ||||||
| typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr; | typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr; | ||||||
|   | |||||||
							
								
								
									
										109
									
								
								include/llama.h
									
									
									
									
									
								
							
							
						
						
									
										109
									
								
								include/llama.h
									
									
									
									
									
								
							| @@ -231,9 +231,38 @@ extern "C" { | |||||||
|  |  | ||||||
|     typedef bool (*llama_progress_callback)(float progress, void * user_data); |     typedef bool (*llama_progress_callback)(float progress, void * user_data); | ||||||
|  |  | ||||||
|     struct llama_batch; |     // Input data for llama_decode | ||||||
|  |     // | ||||||
|  |     // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead | ||||||
|  |     // | ||||||
|  |     // A llama_batch object can contain input about one or many sequences | ||||||
|  |     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | ||||||
|  |     // | ||||||
|  |     // - token  : the token ids of the input (used when embd is NULL) | ||||||
|  |     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) | ||||||
|  |     // - pos    : the positions of the respective token in the sequence | ||||||
|  |     //            (if set to NULL, the token position will be tracked automatically by llama_decode) | ||||||
|  |     // - seq_id : the sequence to which the respective token belongs | ||||||
|  |     //            (if set to NULL, the sequence ID will be assumed to be 0) | ||||||
|  |     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | ||||||
|  |     //            (if set to NULL, only the logits for last token will be returned) | ||||||
|  |     // | ||||||
|  |     typedef struct llama_batch { | ||||||
|  |         int32_t n_tokens; | ||||||
|  |  | ||||||
|     struct llama_batch_token_info { |         llama_token  *  token; | ||||||
|  |         float        *  embd; | ||||||
|  |         llama_pos    *  pos; | ||||||
|  |         int32_t      *  n_seq_id; | ||||||
|  |         llama_seq_id ** seq_id; | ||||||
|  |         int8_t       *  logits; // TODO: rename this to "output" | ||||||
|  |     } llama_batch; | ||||||
|  |  | ||||||
|  |     // Input data for llama_decode / llama_encode | ||||||
|  |     // It can contain text tokens and embeddings for one or many sequences | ||||||
|  |     struct llama_batch_ext; | ||||||
|  |  | ||||||
|  |     struct llama_batch_ext_token_info { | ||||||
|         llama_token    token; |         llama_token    token; | ||||||
|         llama_pos      pos; |         llama_pos      pos; | ||||||
|         int32_t        n_seq_id; |         int32_t        n_seq_id; | ||||||
| @@ -815,9 +844,9 @@ extern "C" { | |||||||
|     // |     // | ||||||
|     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it |     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it | ||||||
|     // |     // | ||||||
|     LLAMA_API struct llama_batch * llama_batch_get_one( |     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( | ||||||
|                   llama_token * tokens, |                   llama_token * tokens, | ||||||
|                       int32_t   n_tokens); |                       int32_t   n_tokens), "use llama_batch_ext API instead"); | ||||||
|  |  | ||||||
|     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens |     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens | ||||||
|     // Each token can be assigned up to n_seq_max sequence ids |     // Each token can be assigned up to n_seq_max sequence ids | ||||||
| @@ -826,30 +855,47 @@ extern "C" { | |||||||
|     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token |     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | ||||||
|     // The rest of the llama_batch members are allocated with size n_tokens |     // The rest of the llama_batch members are allocated with size n_tokens | ||||||
|     // All members are left uninitialized |     // All members are left uninitialized | ||||||
|     // LLAMA_API struct llama_batch llama_batch_init( |     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( | ||||||
|     //         int32_t n_tokens, |                     int32_t n_tokens, | ||||||
|     //         int32_t embd, |                     int32_t embd, | ||||||
|     //         int32_t n_seq_max); |                     int32_t n_seq_max), "use llama_batch_ext API instead"); | ||||||
|  |  | ||||||
|  |     // Frees a batch of tokens allocated with llama_batch_init() | ||||||
|  |     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), | ||||||
|  |             "use llama_batch_ext API instead"); | ||||||
|  |  | ||||||
|     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens |     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens | ||||||
|     // Each token can be assigned up to n_seq_max sequence ids |     // Each token can be assigned up to n_seq_max sequence ids | ||||||
|     // The batch has to be freed with llama_batch_free() |     // The batch has to be freed with llama_batch_ext_free() | ||||||
|     LLAMA_API struct llama_batch * llama_batch_init( |     LLAMA_API struct llama_batch_ext * llama_batch_ext_init( | ||||||
|             int32_t n_tokens, |             int32_t n_tokens, | ||||||
|             int32_t n_seq_max); |             int32_t n_seq_max); | ||||||
|  |  | ||||||
|  |     // Same with llama_batch_init, but initializes the batch with the provided text tokens | ||||||
|  |     // First token will be at position pos0 | ||||||
|  |     // The sequence ID will be fixed to seq_id | ||||||
|  |     // The batch has to be freed with llama_batch_ext_free() | ||||||
|  |     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text( | ||||||
|  |             llama_token * tokens, | ||||||
|  |                 int32_t   n_tokens, | ||||||
|  |                 int32_t   pos0, | ||||||
|  |                 int32_t   seq_id); | ||||||
|  |  | ||||||
|     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings |     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings | ||||||
|     LLAMA_API struct llama_batch * llama_batch_init_from_embd( |     // First token will be at position pos0 | ||||||
|  |     // The sequence ID will be fixed to seq_id | ||||||
|  |     // The batch has to be freed with llama_batch_ext_free() | ||||||
|  |     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( | ||||||
|               float * embd, |               float * embd, | ||||||
|             size_t    n_embd, |             size_t    n_embd, | ||||||
|             int32_t   pos0, |             int32_t   pos0, | ||||||
|             int32_t   seq_id); |             int32_t   seq_id); | ||||||
|  |  | ||||||
|     // Get the number of tokens in the batch |     // Get the number of tokens in the batch | ||||||
|     LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch); |     LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     LLAMA_API struct llama_batch_token_info llama_batch_get_token_info( |     LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info( | ||||||
|             struct llama_batch * batch, |         struct llama_batch_ext * batch, | ||||||
|                        int32_t   i); |                        int32_t   i); | ||||||
|  |  | ||||||
|     // Add text tokens to the batch |     // Add text tokens to the batch | ||||||
| @@ -857,8 +903,8 @@ extern "C" { | |||||||
|     //  0 : success |     //  0 : success | ||||||
|     // -1 : not enough space in the batch |     // -1 : not enough space in the batch | ||||||
|     // -2 : embd is already set, cannot add text tokens |     // -2 : embd is already set, cannot add text tokens | ||||||
|     LLAMA_API int32_t llama_batch_add_text_token( |     LLAMA_API int32_t llama_batch_ext_add_text_token( | ||||||
|             struct llama_batch * batch, |         struct llama_batch_ext * batch, | ||||||
|                    llama_token   token, |                    llama_token   token, | ||||||
|                      llama_pos   pos, |                      llama_pos   pos, | ||||||
|             const llama_seq_id * seq_ids, |             const llama_seq_id * seq_ids, | ||||||
| @@ -868,43 +914,50 @@ extern "C" { | |||||||
|     // Set logits for the token in the ith sequence |     // Set logits for the token in the ith sequence | ||||||
|     // If pos == -1, logits will be set for the all tokens |     // If pos == -1, logits will be set for the all tokens | ||||||
|     // Returns -1 if the token is not in the batch |     // Returns -1 if the token is not in the batch | ||||||
|     LLAMA_API int32_t llama_batch_set_logits( |     LLAMA_API int32_t llama_batch_ext_set_logits( | ||||||
|             struct llama_batch * batch, |         struct llama_batch_ext * batch, | ||||||
|                      llama_pos   pos, |                      llama_pos   pos, | ||||||
|                   llama_seq_id   seq_id); |                   llama_seq_id   seq_id); | ||||||
|  |  | ||||||
|     // Set logits for the last added token |     // Set logits for the last added token | ||||||
|     // Returns -1 if there is no tokens in the batch |     // Returns -1 if there is no tokens in the batch | ||||||
|     LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch); |     LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Get a "view" from a number of tokens offset |     // Get a "view" from a number of tokens offset | ||||||
|     // Return returned batch must be freed with llama_batch_free() |     // Return returned batch must be freed with llama_batch_free() | ||||||
|     LLAMA_API struct llama_batch * llama_batch_get_view( |     LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view( | ||||||
|             struct llama_batch * batch, |         struct llama_batch_ext * batch, | ||||||
|                        int32_t   offset, |                        int32_t   offset, | ||||||
|                        int32_t   n_tokens); |                        int32_t   n_tokens); | ||||||
|  |  | ||||||
|     // Remove everything from the batch |     // Remove everything from the batch | ||||||
|     LLAMA_API void llama_batch_clear(struct llama_batch * batch); |     LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Frees a batch of tokens allocated with llama_batch_init() |     // Frees a batch of tokens allocated with llama_batch_ext_init() | ||||||
|     LLAMA_API void llama_batch_free(struct llama_batch * batch); |     // If this is a view, the original batch is not freed | ||||||
|  |     LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Processes a batch of tokens with the ecoder part of the encoder-decoder model. |     // Processes a batch of tokens with the ecoder part of the encoder-decoder model. | ||||||
|     // Stores the encoder output internally for later use by the decoder cross-attention layers. |     // Stores the encoder output internally for later use by the decoder cross-attention layers. | ||||||
|     //   0 - success |     //   0 - success | ||||||
|     // < 0 - error. the KV cache state is restored to the state before this call |     // < 0 - error. the KV cache state is restored to the state before this call | ||||||
|     LLAMA_API int32_t llama_encode( |     DEPRECATED(LLAMA_API int32_t llama_encode( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|               struct llama_batch * batch); |               struct llama_batch   batch), "use llama_batch_ext API instead"); | ||||||
|  |     LLAMA_API int32_t llama_text_encode( | ||||||
|  |             struct llama_context * ctx, | ||||||
|  |           struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Positive return values does not mean a fatal error, but rather a warning. |     // Positive return values does not mean a fatal error, but rather a warning. | ||||||
|     //   0 - success |     //   0 - success | ||||||
|     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) |     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) | ||||||
|     // < 0 - error. the KV cache state is restored to the state before this call |     // < 0 - error. the KV cache state is restored to the state before this call | ||||||
|     LLAMA_API int32_t llama_decode( |     DEPRECATED(LLAMA_API int32_t llama_decode( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|               struct llama_batch * batch); |               struct llama_batch batch), "use llama_batch_ext API instead"); | ||||||
|  |     LLAMA_API int32_t llama_text_decode( | ||||||
|  |             struct llama_context * ctx, | ||||||
|  |           struct llama_batch_ext * batch); | ||||||
|  |  | ||||||
|     // Set the number of threads used for decoding |     // Set the number of threads used for decoding | ||||||
|     // n_threads is the number of threads used for generation (single token) |     // n_threads is the number of threads used for generation (single token) | ||||||
|   | |||||||
| @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { | |||||||
|     return ubatch; |     return ubatch; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { | void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) { | ||||||
|     GGML_ASSERT(batch.n_tokens >= 0); |     GGML_ASSERT(batch.n_tokens >= 0); | ||||||
|     this->batch = &batch; |     this->batch = &batch; | ||||||
|     this->n_embd = n_embd; |     this->n_embd = n_embd; | ||||||
| @@ -273,49 +273,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim | |||||||
|             ); |             ); | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { | llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) { | ||||||
|     batch = in_batch; |     batch = new llama_batch_ext{ | ||||||
|     GGML_ASSERT(batch.n_tokens > 0); |         /*n_tokens       =*/ in_batch.n_tokens, | ||||||
|     if (!batch.pos) { |         /*max_tokens     =*/ in_batch.n_tokens, | ||||||
|         pos.resize(batch.n_tokens); |         /*is_view        =*/ false, | ||||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { |         /*tokens         =*/ in_batch.token, | ||||||
|  |         /*embd           =*/ in_batch.embd, | ||||||
|  |         /*pos            =*/ in_batch.pos, | ||||||
|  |         /*n_seq_id       =*/ in_batch.n_seq_id, | ||||||
|  |         /*seq_id         =*/ in_batch.seq_id, | ||||||
|  |         /*logits         =*/ in_batch.logits, | ||||||
|  |     }; | ||||||
|  |     GGML_ASSERT(batch->n_tokens > 0); | ||||||
|  |     if (!in_batch.pos) { | ||||||
|  |         pos.resize(batch->n_tokens); | ||||||
|  |         for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||||
|             pos[i] = i + p0; |             pos[i] = i + p0; | ||||||
|         } |         } | ||||||
|         batch.pos = pos.data(); |         batch->pos = pos.data(); | ||||||
|     } |     } | ||||||
|     if (!batch.n_seq_id) { |     if (!batch->n_seq_id) { | ||||||
|         n_seq_id.resize(batch.n_tokens); |         n_seq_id.resize(batch->n_tokens); | ||||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { |         for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||||
|             n_seq_id[i] = seq_id_0.size(); |             n_seq_id[i] = seq_id_0.size(); | ||||||
|         } |         } | ||||||
|         batch.n_seq_id = n_seq_id.data(); |         batch->n_seq_id = n_seq_id.data(); | ||||||
|     } |     } | ||||||
|     if (!batch.seq_id) { |     if (!batch->seq_id) { | ||||||
|         seq_id.resize(batch.n_tokens + 1); |         seq_id.resize(batch->n_tokens + 1); | ||||||
|         seq_id[batch.n_tokens] = NULL; |         seq_id[batch->n_tokens] = NULL; | ||||||
|         for (int32_t i = 0; i < batch.n_tokens; i++) { |         for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||||
|             seq_id[i] = seq_id_0.data(); |             seq_id[i] = seq_id_0.data(); | ||||||
|         } |         } | ||||||
|         batch.seq_id = seq_id.data(); |         batch->seq_id = seq_id.data(); | ||||||
|     } |     } | ||||||
|     if (!batch.logits) { |     if (!batch->logits) { | ||||||
|         logits.resize(batch.n_tokens); |         logits.resize(batch->n_tokens); | ||||||
|         logits[logits.size() - 1] = true; |         logits[logits.size() - 1] = true; | ||||||
|         batch.logits = logits.data(); |         batch->logits = logits.data(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | llama_batch_allocr::~llama_batch_allocr() { | ||||||
|  |     delete batch; | ||||||
|  | } | ||||||
|  |  | ||||||
| // | // | ||||||
| // interface implementation | // interface implementation | ||||||
| // | // | ||||||
|  |  | ||||||
| struct llama_batch * llama_batch_get_one( | struct llama_batch llama_batch_get_one( | ||||||
|             llama_token * tokens, |             llama_token * tokens, | ||||||
|                 int32_t   n_tokens) { |                 int32_t   n_tokens) { | ||||||
|     return new llama_batch{ |     return llama_batch{ | ||||||
|         /*n_tokens       =*/ n_tokens, |         /*n_tokens       =*/ n_tokens, | ||||||
|         /*max_tokens     =*/ n_tokens, |  | ||||||
|         /*is_view        =*/ false, |  | ||||||
|         /*tokens         =*/ tokens, |         /*tokens         =*/ tokens, | ||||||
|         /*embd           =*/ nullptr, |         /*embd           =*/ nullptr, | ||||||
|         /*pos            =*/ nullptr, |         /*pos            =*/ nullptr, | ||||||
| @@ -325,8 +337,20 @@ struct llama_batch * llama_batch_get_one( | |||||||
|     }; |     }; | ||||||
| } | } | ||||||
|  |  | ||||||
| static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { | struct llama_batch_ext * llama_batch_ext_init_from_text( | ||||||
|     llama_batch * batch = new llama_batch{ |             llama_token * tokens, | ||||||
|  |                 int32_t   n_tokens, | ||||||
|  |                 int32_t   pos0, | ||||||
|  |                 int32_t   seq_id) { | ||||||
|  |     llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); | ||||||
|  |     for (int32_t i = 0; i < n_tokens; i++) { | ||||||
|  |         llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false); | ||||||
|  |     } | ||||||
|  |     return batch; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { | ||||||
|  |     llama_batch_ext * batch = new llama_batch_ext{ | ||||||
|         /*n_tokens       =*/ 0, |         /*n_tokens       =*/ 0, | ||||||
|         /*max_tokens     =*/ n_tokens_alloc, |         /*max_tokens     =*/ n_tokens_alloc, | ||||||
|         /*is_view        =*/ false, |         /*is_view        =*/ false, | ||||||
| @@ -357,16 +381,16 @@ static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_ | |||||||
|     return batch; |     return batch; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { | struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) { | ||||||
|     return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); |     return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max); | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch * llama_batch_init_from_embd( | struct llama_batch_ext * llama_batch_ext_init_from_embd( | ||||||
|               float * embd, |               float * embd, | ||||||
|             size_t    n_embd, |             size_t    n_embd, | ||||||
|             int32_t   pos0, |             int32_t   pos0, | ||||||
|             int32_t   seq_id) { |             int32_t   seq_id) { | ||||||
|     struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); |     struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1); | ||||||
|     memcpy(batch->embd, embd, n_embd * sizeof(float)); |     memcpy(batch->embd, embd, n_embd * sizeof(float)); | ||||||
|     for (size_t i = 0; i < n_embd; i++) { |     for (size_t i = 0; i < n_embd; i++) { | ||||||
|         batch->pos     [i] = pos0 + i; |         batch->pos     [i] = pos0 + i; | ||||||
| @@ -376,12 +400,12 @@ struct llama_batch * llama_batch_init_from_embd( | |||||||
|     return batch; |     return batch; | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) { | int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { | ||||||
|     return batch->n_tokens; |     return batch->n_tokens; | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_add_text_token( | int32_t llama_batch_ext_add_text_token( | ||||||
|         struct llama_batch * batch, |     struct llama_batch_ext * batch, | ||||||
|                llama_token   token, |                llama_token   token, | ||||||
|                  llama_pos   pos, |                  llama_pos   pos, | ||||||
|         const llama_seq_id * seq_ids, |         const llama_seq_id * seq_ids, | ||||||
| @@ -404,8 +428,8 @@ int32_t llama_batch_add_text_token( | |||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_set_logits( | int32_t llama_batch_ext_set_logits( | ||||||
|         struct llama_batch * batch, |     struct llama_batch_ext * batch, | ||||||
|                  llama_pos   pos, |                  llama_pos   pos, | ||||||
|               llama_seq_id   seq_id) { |               llama_seq_id   seq_id) { | ||||||
|     for (int32_t i = 0; i < batch->n_tokens; i++) { |     for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||||
| @@ -423,7 +447,7 @@ int32_t llama_batch_set_logits( | |||||||
|     return -1; // not found |     return -1; // not found | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_batch_set_logits_last(struct llama_batch * batch) { | int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { | ||||||
|     if (batch->n_tokens == 0) { |     if (batch->n_tokens == 0) { | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
| @@ -431,18 +455,18 @@ int32_t llama_batch_set_logits_last(struct llama_batch * batch) { | |||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_batch_clear(struct llama_batch * batch) { | void llama_batch_ext_clear(struct llama_batch_ext * batch) { | ||||||
|     batch->n_tokens = 0; |     batch->n_tokens = 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch * llama_batch_get_view( | struct llama_batch_ext * llama_batch_ext_get_view( | ||||||
|         struct llama_batch * batch, |     struct llama_batch_ext * batch, | ||||||
|                    int32_t   offset, |                    int32_t   offset, | ||||||
|                    int32_t   n_tokens) { |                    int32_t   n_tokens) { | ||||||
|     if (batch->embd) { |     if (batch->embd) { | ||||||
|         return nullptr; // not yet supported |         return nullptr; // not yet supported | ||||||
|     } |     } | ||||||
|     llama_batch * batch_view = new llama_batch{ |     llama_batch_ext * batch_view = new llama_batch_ext{ | ||||||
|         /*n_tokens       =*/ n_tokens, |         /*n_tokens       =*/ n_tokens, | ||||||
|         /*max_tokens     =*/ n_tokens, |         /*max_tokens     =*/ n_tokens, | ||||||
|         /*is_view        =*/ true, |         /*is_view        =*/ true, | ||||||
| @@ -456,11 +480,11 @@ struct llama_batch * llama_batch_get_view( | |||||||
|     return batch_view; |     return batch_view; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch_token_info llama_batch_get_token_info( | struct llama_batch_ext_token_info llama_batch_ext_get_token_info( | ||||||
|         struct llama_batch * batch, |     struct llama_batch_ext * batch, | ||||||
|                    int32_t   i) { |                    int32_t   i) { | ||||||
|     GGML_ASSERT(i >= 0 && i < batch->n_tokens); |     GGML_ASSERT(i >= 0 && i < batch->n_tokens); | ||||||
|     return llama_batch_token_info{ |     return llama_batch_ext_token_info{ | ||||||
|         /*token    =*/ batch->token   [i], |         /*token    =*/ batch->token   [i], | ||||||
|         /*pos      =*/ batch->pos     [i], |         /*pos      =*/ batch->pos     [i], | ||||||
|         /*n_seq_id =*/ batch->n_seq_id[i], |         /*n_seq_id =*/ batch->n_seq_id[i], | ||||||
| @@ -469,7 +493,7 @@ struct llama_batch_token_info llama_batch_get_token_info( | |||||||
|     }; |     }; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_batch_free(struct llama_batch * batch) { | void llama_batch_ext_free(struct llama_batch_ext * batch) { | ||||||
|     // do not free the members if it's a view |     // do not free the members if it's a view | ||||||
|     if (!batch->is_view) { |     if (!batch->is_view) { | ||||||
|         if (batch->token)    free(batch->token); |         if (batch->token)    free(batch->token); | ||||||
|   | |||||||
| @@ -5,8 +5,8 @@ | |||||||
| #include <array> | #include <array> | ||||||
| #include <vector> | #include <vector> | ||||||
|  |  | ||||||
| // Input data for llama_decode | // Input data for llama_decode / llama_encode | ||||||
| // A llama_batch object can contain input about one or many sequences | // A llama_batch_ext object can contain input about one or many sequences | ||||||
| // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | ||||||
| // | // | ||||||
| // - token  : the token ids of the input (used when embd is NULL) | // - token  : the token ids of the input (used when embd is NULL) | ||||||
| @@ -18,7 +18,7 @@ | |||||||
| // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | ||||||
| //            (if set to NULL, only the logits for last token will be returned) | //            (if set to NULL, only the logits for last token will be returned) | ||||||
| // | // | ||||||
| struct llama_batch { | struct llama_batch_ext { | ||||||
|     int32_t n_tokens; |     int32_t n_tokens; | ||||||
|     int32_t max_tokens; |     int32_t max_tokens; | ||||||
|     bool is_view; |     bool is_view; | ||||||
| @@ -73,7 +73,7 @@ struct llama_sbatch { | |||||||
|     std::vector<size_t> out_ids; |     std::vector<size_t> out_ids; | ||||||
|     std::vector<llama_sbatch_seq> seq; |     std::vector<llama_sbatch_seq> seq; | ||||||
|  |  | ||||||
|     const llama_batch * batch = nullptr; |     const llama_batch_ext * batch = nullptr; | ||||||
|  |  | ||||||
|     // buffers for the ubatch |     // buffers for the ubatch | ||||||
|     std::vector<llama_token>    ubatch_token; |     std::vector<llama_token>    ubatch_token; | ||||||
| @@ -96,12 +96,12 @@ struct llama_sbatch { | |||||||
|     // sequence-wise split |     // sequence-wise split | ||||||
|     llama_ubatch split_seq(size_t n_ubatch); |     llama_ubatch split_seq(size_t n_ubatch); | ||||||
|  |  | ||||||
|     void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); |     void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // temporary allocate memory for the input batch if needed | // temporary allocate memory for the input batch if needed | ||||||
| struct llama_batch_allocr { | struct llama_batch_allocr { | ||||||
|     struct llama_batch batch; |     struct llama_batch_ext * batch; | ||||||
|  |  | ||||||
|     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id |     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id | ||||||
|     std::vector<llama_pos>      pos; |     std::vector<llama_pos>      pos; | ||||||
| @@ -110,5 +110,7 @@ struct llama_batch_allocr { | |||||||
|     std::vector<int8_t>         logits; |     std::vector<int8_t>         logits; | ||||||
|  |  | ||||||
|     // optionally fulfill the batch returned by llama_batch_get_one |     // optionally fulfill the batch returned by llama_batch_get_one | ||||||
|     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); |     llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0); | ||||||
|  |  | ||||||
|  |     ~llama_batch_allocr(); | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -8445,7 +8445,7 @@ static enum ggml_status llama_graph_compute( | |||||||
|  |  | ||||||
| static int llama_prepare_sbatch( | static int llama_prepare_sbatch( | ||||||
|         llama_context     & lctx, |         llama_context     & lctx, | ||||||
|         const llama_batch & batch, |     const llama_batch_ext & batch, | ||||||
|         uint32_t          & n_outputs) { |         uint32_t          & n_outputs) { | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
| @@ -8585,7 +8585,7 @@ static int llama_prepare_ubatch( | |||||||
| // | // | ||||||
| static int llama_decode_impl( | static int llama_decode_impl( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch   inp_batch) { |        llama_batch_ext & inp_batch) { | ||||||
|  |  | ||||||
|     lctx.is_encoding = false; |     lctx.is_encoding = false; | ||||||
|  |  | ||||||
| @@ -8594,10 +8594,6 @@ static int llama_decode_impl( | |||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // temporarily allocate memory for the input batch if needed |  | ||||||
|     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); |  | ||||||
|     const llama_batch & batch = batch_allocr.batch; |  | ||||||
|  |  | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & vocab   = model.vocab; |     const auto & vocab   = model.vocab; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
| @@ -8616,7 +8612,7 @@ static int llama_decode_impl( | |||||||
|     uint32_t n_outputs_prev = 0; |     uint32_t n_outputs_prev = 0; | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         const int ret = llama_prepare_sbatch(lctx, batch, n_outputs); |         const int ret = llama_prepare_sbatch(lctx, inp_batch, n_outputs); | ||||||
|         if (ret != 0) { |         if (ret != 0) { | ||||||
|             return ret; |             return ret; | ||||||
|         } |         } | ||||||
| @@ -8625,7 +8621,7 @@ static int llama_decode_impl( | |||||||
|     while (lctx.sbatch.n_tokens > 0) { |     while (lctx.sbatch.n_tokens > 0) { | ||||||
|         llama_ubatch ubatch; |         llama_ubatch ubatch; | ||||||
|         { |         { | ||||||
|             const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens); |             const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, inp_batch.n_tokens); | ||||||
|             if (ret != 0) { |             if (ret != 0) { | ||||||
|                 return ret; |                 return ret; | ||||||
|             } |             } | ||||||
| @@ -8832,7 +8828,7 @@ static int llama_decode_impl( | |||||||
| // | // | ||||||
| static int llama_encode_impl( | static int llama_encode_impl( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch   inp_batch) { |        llama_batch_ext & inp_batch) { | ||||||
|  |  | ||||||
|     lctx.is_encoding = true; |     lctx.is_encoding = true; | ||||||
|  |  | ||||||
| @@ -8841,22 +8837,18 @@ static int llama_encode_impl( | |||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // temporary allocate memory for the input batch if needed |     const uint32_t n_tokens = inp_batch.n_tokens; | ||||||
|     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); |  | ||||||
|  |  | ||||||
|     const llama_batch & batch = batch_allocr.batch; |  | ||||||
|     const uint32_t n_tokens = batch.n_tokens; |  | ||||||
|  |  | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|     const auto & cparams = lctx.cparams; |     const auto & cparams = lctx.cparams; | ||||||
|  |  | ||||||
|     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT |     GGML_ASSERT((!inp_batch.token && inp_batch.embd) || (inp_batch.token && !inp_batch.embd)); // NOLINT | ||||||
|  |  | ||||||
|     if (batch.token) { |     if (inp_batch.token) { | ||||||
|         for (uint32_t i = 0; i < n_tokens; ++i) { |         for (uint32_t i = 0; i < n_tokens; ++i) { | ||||||
|             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { |             if (inp_batch.token[i] < 0 || (uint32_t) inp_batch.token[i] >= model.vocab.n_tokens()) { | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); |                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, inp_batch.token[i]); | ||||||
|                 return -1; |                 return -1; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -8873,7 +8865,7 @@ static int llama_encode_impl( | |||||||
|  |  | ||||||
|     const int64_t n_embd = hparams.n_embd; |     const int64_t n_embd = hparams.n_embd; | ||||||
|  |  | ||||||
|     lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); |     lctx.sbatch.from_batch(inp_batch, n_embd, /* simple_split */ true, /* logits_all */ true); | ||||||
|  |  | ||||||
|     const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); |     const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); | ||||||
|  |  | ||||||
| @@ -9976,9 +9968,32 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { | |||||||
|  |  | ||||||
| /// | /// | ||||||
|  |  | ||||||
|  |  | ||||||
|  | // DEPRECATED | ||||||
| int32_t llama_encode( | int32_t llama_encode( | ||||||
|         struct llama_context * ctx, |         struct llama_context * ctx, | ||||||
|           struct llama_batch * batch) { |           struct llama_batch batch) { | ||||||
|  |     // temporarily allocate memory for the input batch if needed | ||||||
|  |     // also convert llama_batch to llama_batch_ext | ||||||
|  |     llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); | ||||||
|  |     llama_batch_ext * batch_ext = batch_allocr.batch; | ||||||
|  |     return llama_text_encode(ctx, batch_ext); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DEPRECATED | ||||||
|  | int32_t llama_decode( | ||||||
|  |         struct llama_context * ctx, | ||||||
|  |           struct llama_batch batch) { | ||||||
|  |     // temporarily allocate memory for the input batch if needed | ||||||
|  |     // also convert llama_batch to llama_batch_ext | ||||||
|  |     llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); | ||||||
|  |     llama_batch_ext * batch_ext = batch_allocr.batch; | ||||||
|  |     return llama_text_decode(ctx, batch_ext); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int32_t llama_text_encode( | ||||||
|  |         struct llama_context * ctx, | ||||||
|  |       struct llama_batch_ext * batch) { | ||||||
|     const int ret = llama_encode_impl(*ctx, *batch); |     const int ret = llama_encode_impl(*ctx, *batch); | ||||||
|     if (ret != 0) { |     if (ret != 0) { | ||||||
|         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); | ||||||
| @@ -9987,9 +10002,9 @@ int32_t llama_encode( | |||||||
|     return ret; |     return ret; | ||||||
| } | } | ||||||
|  |  | ||||||
| int32_t llama_decode( | int32_t llama_text_decode( | ||||||
|         struct llama_context * ctx, |         struct llama_context * ctx, | ||||||
|           struct llama_batch * batch) { |       struct llama_batch_ext * batch) { | ||||||
|     const int ret = llama_decode_impl(*ctx, *batch); |     const int ret = llama_decode_impl(*ctx, *batch); | ||||||
|     if (ret != 0) { |     if (ret != 0) { | ||||||
|         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen