mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	first proposal for private llama_batch
This commit is contained in:
		| @@ -231,29 +231,7 @@ extern "C" { | |||||||
|  |  | ||||||
|     typedef bool (*llama_progress_callback)(float progress, void * user_data); |     typedef bool (*llama_progress_callback)(float progress, void * user_data); | ||||||
|  |  | ||||||
|     // Input data for llama_decode |     struct llama_batch; | ||||||
|     // A llama_batch object can contain input about one or many sequences |  | ||||||
|     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens |  | ||||||
|     // |  | ||||||
|     // - token  : the token ids of the input (used when embd is NULL) |  | ||||||
|     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) |  | ||||||
|     // - pos    : the positions of the respective token in the sequence |  | ||||||
|     //            (if set to NULL, the token position will be tracked automatically by llama_decode) |  | ||||||
|     // - seq_id : the sequence to which the respective token belongs |  | ||||||
|     //            (if set to NULL, the sequence ID will be assumed to be 0) |  | ||||||
|     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output |  | ||||||
|     //            (if set to NULL, only the logits for last token will be returned) |  | ||||||
|     // |  | ||||||
|     typedef struct llama_batch { |  | ||||||
|         int32_t n_tokens; |  | ||||||
|  |  | ||||||
|         llama_token  *  token; |  | ||||||
|         float        *  embd; |  | ||||||
|         llama_pos    *  pos; |  | ||||||
|         int32_t      *  n_seq_id; |  | ||||||
|         llama_seq_id ** seq_id; |  | ||||||
|         int8_t       *  logits; // TODO: rename this to "output" |  | ||||||
|     } llama_batch; |  | ||||||
|  |  | ||||||
|     enum llama_model_kv_override_type { |     enum llama_model_kv_override_type { | ||||||
|         LLAMA_KV_OVERRIDE_TYPE_INT, |         LLAMA_KV_OVERRIDE_TYPE_INT, | ||||||
| @@ -829,7 +807,7 @@ extern "C" { | |||||||
|     // |     // | ||||||
|     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it |     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it | ||||||
|     // |     // | ||||||
|     LLAMA_API struct llama_batch llama_batch_get_one( |     LLAMA_API struct llama_batch * llama_batch_get_one( | ||||||
|                   llama_token * tokens, |                   llama_token * tokens, | ||||||
|                       int32_t   n_tokens); |                       int32_t   n_tokens); | ||||||
|  |  | ||||||
| @@ -840,13 +818,59 @@ extern "C" { | |||||||
|     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token |     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | ||||||
|     // The rest of the llama_batch members are allocated with size n_tokens |     // The rest of the llama_batch members are allocated with size n_tokens | ||||||
|     // All members are left uninitialized |     // All members are left uninitialized | ||||||
|     LLAMA_API struct llama_batch llama_batch_init( |     // LLAMA_API struct llama_batch llama_batch_init( | ||||||
|  |     //         int32_t n_tokens, | ||||||
|  |     //         int32_t embd, | ||||||
|  |     //         int32_t n_seq_max); | ||||||
|  |  | ||||||
|  |     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens | ||||||
|  |     // Each token can be assigned up to n_seq_max sequence ids | ||||||
|  |     // The batch has to be freed with llama_batch_free() | ||||||
|  |     LLAMA_API struct llama_batch * llama_batch_init( | ||||||
|             int32_t n_tokens, |             int32_t n_tokens, | ||||||
|             int32_t embd, |  | ||||||
|             int32_t n_seq_max); |             int32_t n_seq_max); | ||||||
|  |  | ||||||
|  |     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings | ||||||
|  |     LLAMA_API struct llama_batch * llama_batch_init_from_embd( | ||||||
|  |               float * embd, | ||||||
|  |             size_t    n_embd, | ||||||
|  |             int32_t   pos0, | ||||||
|  |             int32_t   seq_id); | ||||||
|  |  | ||||||
|  |     // Add text tokens to the batch | ||||||
|  |     // First token in the list starts at position pos0 | ||||||
|  |     // Return values: | ||||||
|  |     //  0 : success | ||||||
|  |     // -1 : not enough space in the batch | ||||||
|  |     // -2 : embd is already set, cannot add text tokens | ||||||
|  |     LLAMA_API int32_t llama_batch_add_text( | ||||||
|  |             struct llama_batch * batch, | ||||||
|  |                    llama_token * tokens, | ||||||
|  |                        size_t    n_tokens, | ||||||
|  |                        int32_t   pos0, | ||||||
|  |                        int32_t   seq_id); | ||||||
|  |  | ||||||
|  |     // Same as llama_batch_add_text, but accepts multiple sequences | ||||||
|  |     LLAMA_API int32_t llama_batch_add_text( | ||||||
|  |             struct llama_batch * batch, | ||||||
|  |                    llama_token * tokens, | ||||||
|  |                        size_t    n_tokens, | ||||||
|  |                        int32_t   pos0, | ||||||
|  |                        int32_t * seq_ids, | ||||||
|  |                        size_t    n_seq_ids); | ||||||
|  |  | ||||||
|  |     // Set logits for the token in the ith sequence | ||||||
|  |     // If pos == -1, logits will be set for the all tokens | ||||||
|  |     LLAMA_API int32_t llama_batch_set_logits( | ||||||
|  |             struct llama_batch * batch, | ||||||
|  |                        int32_t   pos, | ||||||
|  |                        int32_t   seq_id); | ||||||
|  |  | ||||||
|  |     // Remove everything from the batch | ||||||
|  |     LLAMA_API void llama_batch_clear(struct llama_batch * batch); | ||||||
|  |  | ||||||
|     // Frees a batch of tokens allocated with llama_batch_init() |     // Frees a batch of tokens allocated with llama_batch_init() | ||||||
|     LLAMA_API void llama_batch_free(struct llama_batch batch); |     LLAMA_API void llama_batch_free(struct llama_batch * batch); | ||||||
|  |  | ||||||
|     // Processes a batch of tokens with the ecoder part of the encoder-decoder model. |     // Processes a batch of tokens with the ecoder part of the encoder-decoder model. | ||||||
|     // Stores the encoder output internally for later use by the decoder cross-attention layers. |     // Stores the encoder output internally for later use by the decoder cross-attention layers. | ||||||
|   | |||||||
| @@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 | |||||||
| // interface implementation | // interface implementation | ||||||
| // | // | ||||||
|  |  | ||||||
| struct llama_batch llama_batch_get_one( | struct llama_batch * llama_batch_get_one( | ||||||
|              llama_token * tokens, |              llama_token * tokens, | ||||||
|                  int32_t   n_tokens) { |                  int32_t   n_tokens) { | ||||||
|     return { |     return new llama_batch{ | ||||||
|         /*n_tokens       =*/ n_tokens, |         /*n_tokens       =*/ n_tokens, | ||||||
|         /*tokens         =*/ tokens, |         /*tokens         =*/ tokens, | ||||||
|         /*embd           =*/ nullptr, |         /*embd           =*/ nullptr, | ||||||
| @@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one( | |||||||
|     }; |     }; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { | static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { | ||||||
|     llama_batch batch = { |     llama_batch * batch = new llama_batch{ | ||||||
|         /*n_tokens       =*/ 0, |         /*n_tokens       =*/ 0, | ||||||
|         /*tokens         =*/ nullptr, |         /*tokens         =*/ nullptr, | ||||||
|         /*embd           =*/ nullptr, |         /*embd           =*/ nullptr, | ||||||
| @@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ | |||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     if (embd) { |     if (embd) { | ||||||
|         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); |         batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); | ||||||
|     } else { |     } else { | ||||||
|         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); |         batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc); |     batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc); | ||||||
|     batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc); |     batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc); | ||||||
|     batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); |     batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); | ||||||
|     for (int i = 0; i < n_tokens_alloc; ++i) { |     for (int i = 0; i < n_tokens_alloc; ++i) { | ||||||
|         batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); |         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); | ||||||
|     } |     } | ||||||
|     batch.seq_id[n_tokens_alloc] = nullptr; |     batch->seq_id[n_tokens_alloc] = nullptr; | ||||||
|  |  | ||||||
|     batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc); |     batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc); | ||||||
|  |  | ||||||
|     return batch; |     return batch; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_batch_free(struct llama_batch batch) { | struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { | ||||||
|     if (batch.token)    free(batch.token); |     return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); | ||||||
|     if (batch.embd)     free(batch.embd); | } | ||||||
|     if (batch.pos)      free(batch.pos); |  | ||||||
|     if (batch.n_seq_id) free(batch.n_seq_id); | struct llama_batch * llama_batch_init_from_embd( | ||||||
|     if (batch.seq_id) { |               float * embd, | ||||||
|         for (int i = 0; batch.seq_id[i] != nullptr; ++i) { |             size_t    n_embd, | ||||||
|             free(batch.seq_id[i]); |             int32_t   pos0, | ||||||
|         } |             int32_t   seq_id) { | ||||||
|         free(batch.seq_id); |     struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); | ||||||
|     } |     memcpy(batch->embd, embd, n_embd * sizeof(float)); | ||||||
|     if (batch.logits)   free(batch.logits); |     for (int32_t i = 0; i < n_embd; i++) { | ||||||
|  |         batch->pos     [i] = pos0 + i; | ||||||
|  |         batch->n_seq_id[i] = 1; | ||||||
|  |         batch->seq_id  [i][0] = seq_id; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int32_t llama_batch_add_text( | ||||||
|  |         struct llama_batch * batch, | ||||||
|  |                llama_token * tokens, | ||||||
|  |                    size_t    n_tokens, | ||||||
|  |                    int32_t   pos0, | ||||||
|  |                    int32_t * seq_ids, | ||||||
|  |                    size_t    n_seq_ids) { | ||||||
|  |     if (batch->n_tokens + n_tokens > batch->n_tokens) { | ||||||
|  |         return -1; | ||||||
|  |     } | ||||||
|  |     if (batch->embd) { | ||||||
|  |         return -2; | ||||||
|  |     } | ||||||
|  |     for (int32_t i = 0; i < n_tokens; i++) { | ||||||
|  |         batch->token   [batch->n_tokens + i] = tokens[i]; | ||||||
|  |         batch->pos     [batch->n_tokens + i] = pos0 + i; | ||||||
|  |         batch->n_seq_id[batch->n_tokens + i] = n_seq_ids; | ||||||
|  |         for (int32_t j = 0; j < n_seq_ids; j++) { | ||||||
|  |             batch->seq_id[batch->n_tokens + i][j] = seq_ids[j]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int32_t llama_batch_add_text( | ||||||
|  |         struct llama_batch * batch, | ||||||
|  |                llama_token * tokens, | ||||||
|  |                    size_t    n_tokens, | ||||||
|  |                    int32_t   pos0, | ||||||
|  |                    int32_t   seq_id) { | ||||||
|  |     std::array<int32_t, 1> seq_ids = { seq_id }; | ||||||
|  |     return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int32_t llama_batch_set_logits( | ||||||
|  |         struct llama_batch * batch, | ||||||
|  |                    int32_t   pos, | ||||||
|  |                    int32_t   seq_id) { | ||||||
|  |     for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||||
|  |         // find the token having seq_id | ||||||
|  |         for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { | ||||||
|  |             if (batch->seq_id[i][j] == seq_id) { | ||||||
|  |                 // found the sequence | ||||||
|  |                 if (pos == -1 || pos == batch->pos[i]) { | ||||||
|  |                     batch->logits[i] = true; | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void llama_batch_clear(struct llama_batch * batch) { | ||||||
|  |     batch->n_tokens = 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void llama_batch_free(struct llama_batch * batch) { | ||||||
|  |     if (batch->token)    free(batch->token); | ||||||
|  |     if (batch->embd)     free(batch->embd); | ||||||
|  |     if (batch->pos)      free(batch->pos); | ||||||
|  |     if (batch->n_seq_id) free(batch->n_seq_id); | ||||||
|  |     if (batch->seq_id) { | ||||||
|  |         for (int i = 0; batch->seq_id[i] != nullptr; ++i) { | ||||||
|  |             free(batch->seq_id[i]); | ||||||
|  |         } | ||||||
|  |         free(batch->seq_id); | ||||||
|  |     } | ||||||
|  |     if (batch->logits)   free(batch->logits); | ||||||
|  |     delete batch; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,6 +5,30 @@ | |||||||
| #include <array> | #include <array> | ||||||
| #include <vector> | #include <vector> | ||||||
|  |  | ||||||
|  | // Input data for llama_decode | ||||||
|  | // A llama_batch object can contain input about one or many sequences | ||||||
|  | // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | ||||||
|  | // | ||||||
|  | // - token  : the token ids of the input (used when embd is NULL) | ||||||
|  | // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) | ||||||
|  | // - pos    : the positions of the respective token in the sequence | ||||||
|  | //            (if set to NULL, the token position will be tracked automatically by llama_decode) | ||||||
|  | // - seq_id : the sequence to which the respective token belongs | ||||||
|  | //            (if set to NULL, the sequence ID will be assumed to be 0) | ||||||
|  | // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | ||||||
|  | //            (if set to NULL, only the logits for last token will be returned) | ||||||
|  | // | ||||||
|  | struct llama_batch { | ||||||
|  |     int32_t n_tokens; | ||||||
|  |  | ||||||
|  |     llama_token  *  token; | ||||||
|  |     float        *  embd; | ||||||
|  |     llama_pos    *  pos; | ||||||
|  |     int32_t      *  n_seq_id; | ||||||
|  |     llama_seq_id ** seq_id; | ||||||
|  |     int8_t       *  logits; // TODO: rename this to "output" | ||||||
|  | }; | ||||||
|  |  | ||||||
| // very similar to llama_batch, | // very similar to llama_batch, | ||||||
| // but has more metadata about sequences | // but has more metadata about sequences | ||||||
| struct llama_ubatch { | struct llama_ubatch { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen