llama : improve llama_batch API + simplify parallel example
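Taken together, the changes below replace the per-example token/pos/seq_id/logits vectors with a single reusable llama_batch allocated through the new llama_batch_init()/llama_batch_free() calls. A minimal sketch of the resulting workflow, not code from the commit itself (assumes a valid llama_context `ctx` and an already tokenized prompt `tokens`; note that llama_decode() still takes an explicit thread count at this point):

    #include "llama.h"

    #include <vector>

    // Sketch: prefill a single-sequence prompt with the new llama_batch API.
    static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, int n_threads) {
        // one allocation for the token/pos/seq_id/logits arrays, sized for the whole prompt (embd == 0)
        llama_batch batch = llama_batch_init((uint32_t) tokens.size(), 0);

        batch.n_tokens = (uint32_t) tokens.size();
        for (uint32_t i = 0; i < batch.n_tokens; ++i) {
            batch.token [i] = tokens[i];
            batch.pos   [i] = i;     // absolute position within the sequence
            batch.seq_id[i] = 0;     // everything goes to sequence 0
            batch.logits[i] = false; // no logits for prompt tokens ...
        }
        if (batch.n_tokens > 0) {
            batch.logits[batch.n_tokens - 1] = true; // ... except the last one, which will be sampled from
        }

        const int ret = llama_decode(ctx, batch, n_threads);

        llama_batch_free(batch); // releases the arrays allocated by llama_batch_init()

        return ret;
    }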
		| @@ -127,11 +127,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     llama_seq_id g_seq_id = 0; | ||||
|  | ||||
|     std::vector<llama_token>  batch_token; | ||||
|     std::vector<llama_pos>    batch_pos; | ||||
|     std::vector<llama_seq_id> batch_seq_id; | ||||
|     std::vector<int8_t>       batch_logits; | ||||
|     std::vector<client *>     batch_clients; | ||||
|     llama_batch batch = llama_batch_init(params.n_batch, 0); | ||||
|  | ||||
|     int32_t n_total_prompt = 0; | ||||
|     int32_t n_total_gen    = 0; | ||||
| @@ -146,24 +142,15 @@ int main(int argc, char ** argv) { | ||||
|     { | ||||
|         LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); | ||||
|  | ||||
|         batch_pos.clear(); | ||||
|         batch_seq_id.clear(); | ||||
|         batch.n_tokens = n_tokens_system; | ||||
|  | ||||
|         for (size_t i = 0; i < n_tokens_system; ++i) { | ||||
|             batch_pos.push_back(i); | ||||
|             batch_seq_id.push_back(0); | ||||
|         for (uint32_t i = 0; i < batch.n_tokens; ++i) { | ||||
|             batch.token[i]  = tokens_system[i]; | ||||
|             batch.pos[i]    = i; | ||||
|             batch.seq_id[i] = 0; | ||||
|             batch.logits[i] = false; | ||||
|         } | ||||
|  | ||||
|         llama_batch batch = { | ||||
|             n_tokens_system, | ||||
|             tokens_system.data(), | ||||
|             nullptr, | ||||
|             batch_pos.data(), | ||||
|             batch_seq_id.data(), | ||||
|             nullptr, | ||||
|             0, 0, 0, // unused | ||||
|         }; | ||||
|  | ||||
|         if (llama_decode(ctx, batch, params.n_threads) != 0) { | ||||
|             LOG_TEE("%s: llama_decode() failed\n", __func__); | ||||
|             return 1; | ||||
| @@ -180,63 +167,72 @@ int main(int argc, char ** argv) { | ||||
|     LOG_TEE("Processing requests ...\n\n"); | ||||
|  | ||||
|     while (true) { | ||||
|         uint32_t n_tokens = 0; | ||||
|  | ||||
|         batch_token.clear(); | ||||
|         batch_pos.clear(); | ||||
|         batch_seq_id.clear(); | ||||
|         batch_logits.clear(); | ||||
|         batch.n_tokens = 0; | ||||
|  | ||||
|         // decode any currently ongoing sequences | ||||
|         for (auto & client : clients) { | ||||
|             if (client.seq_id == -1) { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             batch_token.push_back(client.sampled); | ||||
|             batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded); | ||||
|             batch_seq_id.push_back(client.id); | ||||
|             batch_logits.push_back(true); | ||||
|             batch_clients.push_back(&client); | ||||
|             batch.token [batch.n_tokens] = client.sampled; | ||||
|             batch.pos   [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded; | ||||
|             batch.seq_id[batch.n_tokens] = client.id; | ||||
|             batch.logits[batch.n_tokens] = true; | ||||
|  | ||||
|             client.n_decoded += 1; | ||||
|             client.i_batch = batch_token.size() - 1; | ||||
|             client.i_batch = batch.n_tokens; | ||||
|  | ||||
|             batch.n_tokens += 1; | ||||
|         } | ||||
|  | ||||
|         if (batch_token.empty()) { | ||||
|         if (batch.n_tokens == 0) { | ||||
|             // all sequences have ended - clear the entire KV cache | ||||
|             for (int i = 0; i < n_clients; ++i) { | ||||
|                 llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (cont_batching || batch_token.empty()) { | ||||
|         // insert new sequences for decoding | ||||
|         if (cont_batching || batch.n_tokens == 0) { | ||||
|             for (auto & client : clients) { | ||||
|                 if (client.seq_id == -1 && g_seq_id < n_seq) { | ||||
|                     client.seq_id = g_seq_id; | ||||
|  | ||||
|                     client.t_start_prompt = ggml_time_us(); | ||||
|                     client.t_start_gen    = 0; | ||||
|  | ||||
|                     client.input    = k_prompts[rand() % k_prompts.size()]; | ||||
|                     client.prompt   = client.input + "\nAssistant:"; | ||||
|                     client.response = ""; | ||||
|  | ||||
|                     std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); | ||||
|  | ||||
|                     std::vector<llama_token> tokens_prompt; | ||||
|                     tokens_prompt = ::llama_tokenize(ctx, client.prompt, true); | ||||
|  | ||||
|                     for (size_t i = 0; i < tokens_prompt.size(); ++i) { | ||||
|                         batch_token.push_back(tokens_prompt[i]); | ||||
|                         batch_pos.push_back(i + n_tokens_system); | ||||
|                         batch_seq_id.push_back(client.id); | ||||
|                         batch_clients.push_back(&client); | ||||
|                         batch_logits.push_back(false); | ||||
|                         batch.token [batch.n_tokens] = tokens_prompt[i]; | ||||
|                         batch.pos   [batch.n_tokens] = i + n_tokens_system; | ||||
|                         batch.seq_id[batch.n_tokens] = client.id; | ||||
|                         batch.logits[batch.n_tokens] = false; | ||||
|                         batch.n_tokens += 1; | ||||
|                     } | ||||
|  | ||||
|                     // extract the logits only for the last token | ||||
|                     if (batch.n_tokens > 0) { | ||||
|                         batch.logits[batch.n_tokens - 1] = true; | ||||
|                     } | ||||
|                     batch_logits.back() = true; | ||||
|  | ||||
|                     client.n_prompt  = tokens_prompt.size(); | ||||
|                     client.n_decoded = 0; | ||||
|                     client.i_batch   = batch_token.size() - 1; | ||||
|                     client.i_batch   = batch.n_tokens - 1; | ||||
|  | ||||
|                     LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); | ||||
|  | ||||
|                     g_seq_id += 1; | ||||
|  | ||||
|                     // insert new requests one-by-one | ||||
|                     //if (cont_batching) { | ||||
|                     //    break; | ||||
|                     //} | ||||
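A note on the llama_kv_cache_seq_rm() call above: despite the "clear the entire KV cache" comment, it removes only positions from n_tokens_system onward, so the shared system prompt stays cached and can be reused by the next request. A short restatement of that intent (identifiers as in the example; the p1 = -1 argument means "to the end of the sequence"):

    // all sequences have ended - reset every client's slot, but keep the shared
    // system prompt, which occupies positions [0, n_tokens_system) in each sequence
    for (int i = 0; i < n_clients; ++i) {
        llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
    }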
| @@ -244,34 +240,35 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (batch_token.empty()) { | ||||
|         if (batch.n_tokens == 0) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         // process in chunks of params.n_batch | ||||
|         int32_t n_batch = params.n_batch; | ||||
|  | ||||
|         for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) { | ||||
|             n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i)); | ||||
|         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { | ||||
|             const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); | ||||
|  | ||||
|             llama_batch batch = { | ||||
|             llama_batch batch_view = { | ||||
|                 n_tokens, | ||||
|                 batch_token.data() + i, | ||||
|                 batch.token  + i, | ||||
|                 nullptr, | ||||
|                 batch_pos.data() + i, | ||||
|                 batch_seq_id.data() + i, | ||||
|                 batch_logits.data() + i, | ||||
|                 batch.pos    + i, | ||||
|                 batch.seq_id + i, | ||||
|                 batch.logits + i, | ||||
|                 0, 0, 0, // unused | ||||
|             }; | ||||
|  | ||||
|             const int ret = llama_decode(ctx, batch, params.n_threads); | ||||
|             const int ret = llama_decode(ctx, batch_view, params.n_threads); | ||||
|             if (ret != 0) { | ||||
|                 if (n_batch == 1 || ret < 0) { | ||||
|                     LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); | ||||
|                     // if you get here, it means the KV cache is full - try increasing it via the context size | ||||
|                     LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); | ||||
|                     return 1; | ||||
|                 } | ||||
|  | ||||
|                 LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2); | ||||
|                 LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); | ||||
|  | ||||
|                 n_cache_miss += 1; | ||||
|  | ||||
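The hunk above is cut off before the recovery path, but the new log line already shows the strategy: when llama_decode() cannot find a KV slot, the example retries the same tokens with a smaller chunk. A self-contained sketch of that pattern, assuming `ctx`, a filled `batch` and an initial chunk size `n_batch` (an illustration of the approach, not the exact continuation of the example):

    #include "llama.h"

    #include <algorithm>

    // Sketch: decode a large batch in chunks, halving the chunk size when no KV slot is found.
    static int decode_in_chunks(llama_context * ctx, const llama_batch & batch, int32_t n_batch, int n_threads) {
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
            const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

            // a "view" is just the parent batch's pointers offset by i - nothing is copied
            llama_batch batch_view = {
                n_tokens,
                batch.token  + i,
                nullptr,            // embd is unused when tokens are provided
                batch.pos    + i,
                batch.seq_id + i,
                batch.logits + i,
                0, 0, 0,            // all_pos_0 / all_pos_1 / all_seq_id - unused when pos/seq_id are set
            };

            const int ret = llama_decode(ctx, batch_view, n_threads);
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    return 1; // fatal: no KV slot even for a single token - a larger context is needed
                }

                // non-fatal: retry the same tokens with half the chunk size
                n_batch /= 2;
                i -= n_batch; // cancels the upcoming i += n_batch so the failed chunk is reprocessed
                continue;
            }
        }

        return 0;
    }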
| @@ -357,6 +354,8 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     llama_print_timings(ctx); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|  | ||||
|     llama_free(ctx); | ||||
|     llama_free_model(model); | ||||
|  | ||||
| @@ -419,7 +419,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par | ||||
| } | ||||
|  | ||||
| static std::vector<float> hellaswag_evaluate_tokens( | ||||
|     llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread | ||||
|     llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread | ||||
| ) { | ||||
|     std::vector<float> result; | ||||
|     result.reserve(tokens.size() * n_vocab); | ||||
| @@ -10,10 +10,12 @@ int main(int argc, char ** argv) { | ||||
|     gpt_params params; | ||||
|  | ||||
|     if (argc == 1 || argv[1][0] == '-') { | ||||
|         printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); | ||||
|         printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]); | ||||
|         return 1 ; | ||||
|     } | ||||
|  | ||||
|     int n_parallel = 1; | ||||
|  | ||||
|     if (argc >= 2) { | ||||
|         params.model = argv[1]; | ||||
|     } | ||||
| @@ -22,6 +24,10 @@ int main(int argc, char ** argv) { | ||||
|         params.prompt = argv[2]; | ||||
|     } | ||||
|  | ||||
|     if (argc >= 4) { | ||||
|         n_parallel = std::atoi(argv[3]); | ||||
|     } | ||||
|  | ||||
|     if (params.prompt.empty()) { | ||||
|         params.prompt = "Hello my name is"; | ||||
|     } | ||||
| @@ -134,7 +134,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|         while (true) { | ||||
|             // sample from the target model | ||||
|             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); | ||||
|             llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); | ||||
|  | ||||
|             // remember which tokens were sampled - used for repetition penalties during sampling | ||||
|             last_tokens.erase(last_tokens.begin()); | ||||
llama.cpp (30 lines changed)
							| @@ -7356,7 +7356,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi | ||||
|  | ||||
| int llama_eval( | ||||
|         struct llama_context * ctx, | ||||
|            const llama_token * tokens, | ||||
|                  llama_token * tokens, | ||||
|                     uint32_t   n_tokens, | ||||
|                          int   n_past, | ||||
|                          int   n_threads) { | ||||
| @@ -7376,7 +7376,7 @@ int llama_eval( | ||||
|  | ||||
| int llama_eval_embd( | ||||
|             struct llama_context * ctx, | ||||
|                      const float * embd, | ||||
|                            float * embd, | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads) { | ||||
| @@ -7397,7 +7397,7 @@ int llama_eval_embd( | ||||
| } | ||||
|  | ||||
| struct llama_batch llama_batch_get_one( | ||||
|        const llama_token * tokens, | ||||
|              llama_token * tokens, | ||||
|                 uint32_t   n_tokens, | ||||
|                llama_pos   pos_0, | ||||
|             llama_seq_id   seq_id) { | ||||
| @@ -7414,6 +7414,30 @@ struct llama_batch llama_batch_get_one( | ||||
|     }; | ||||
| } | ||||
|  | ||||
| struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) { | ||||
|     llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; | ||||
|  | ||||
|     if (embd) { | ||||
|         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); | ||||
|     } else { | ||||
|         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); | ||||
|     } | ||||
|  | ||||
|     batch.pos    = (llama_pos *)    malloc(sizeof(llama_pos)    * n_tokens); | ||||
|     batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); | ||||
|     batch.logits = (int8_t *)       malloc(sizeof(int8_t)       * n_tokens); | ||||
|  | ||||
|     return batch; | ||||
| } | ||||
|  | ||||
| void llama_batch_free(struct llama_batch batch) { | ||||
|     if (batch.token)  free(batch.token); | ||||
|     if (batch.embd)   free(batch.embd); | ||||
|     if (batch.pos)    free(batch.pos); | ||||
|     if (batch.seq_id) free(batch.seq_id); | ||||
|     if (batch.logits) free(batch.logits); | ||||
| } | ||||
|  | ||||
| int llama_decode( | ||||
|         struct llama_context * ctx, | ||||
|           struct llama_batch   batch, | ||||
llama.h (32 lines changed)
							| @@ -70,11 +70,11 @@ extern "C" { | ||||
|     typedef struct llama_batch { | ||||
|         uint32_t n_tokens; | ||||
|  | ||||
|         const llama_token  * token; | ||||
|         const float        * embd; | ||||
|         const llama_pos    * pos; | ||||
|         const llama_seq_id * seq_id; | ||||
|         const int8_t       * logits; // if 0, do not extract logits for that token | ||||
|         llama_token  * token; | ||||
|         float        * embd; | ||||
|         llama_pos    * pos; | ||||
|         llama_seq_id * seq_id; | ||||
|         int8_t       * logits; // if 0, do not extract logits for that token | ||||
|  | ||||
|         // NOTE: helpers for smooth API transition - can be deprecated in the future | ||||
|         //       for future-proof code, use the above fields instead and ignore everything below | ||||
| @@ -84,7 +84,7 @@ extern "C" { | ||||
|         llama_pos    all_pos_0;  // used if pos == NULL | ||||
|         llama_pos    all_pos_1;  // used if pos == NULL | ||||
|         llama_seq_id all_seq_id; // used if seq_id == NULL | ||||
|     } llama_seq; | ||||
|     } llama_batch; | ||||
|  | ||||
|     enum llama_log_level { | ||||
|         LLAMA_LOG_LEVEL_ERROR = 2, | ||||
| @@ -366,34 +366,46 @@ extern "C" { | ||||
|     // tokens + n_tokens is the provided batch of new tokens to process | ||||
|     // n_past is the number of tokens to use from previous eval calls | ||||
|     // Returns 0 on success | ||||
|     // DEPRECATED: use llama_decode() instead | ||||
|     LLAMA_API DEPRECATED(int llama_eval( | ||||
|             struct llama_context * ctx, | ||||
|                const llama_token * tokens, | ||||
|                      llama_token * tokens, | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads), | ||||
|             "please use llama_decode() instead"); | ||||
|  | ||||
|     // Same as llama_eval, but use float matrix input directly. | ||||
|     // DEPRECATED: use llama_decode() instead | ||||
|     LLAMA_API DEPRECATED(int llama_eval_embd( | ||||
|             struct llama_context * ctx, | ||||
|                      const float * embd, | ||||
|                            float * embd, | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads), | ||||
|             "please use llama_decode() instead"); | ||||
|  | ||||
|     // Return batch for single sequence of tokens starting at pos_0 | ||||
|     // If pos_0 == 0, the clear_kv flag will be auto set to true | ||||
|     // | ||||
|     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it | ||||
|     // | ||||
|     LLAMA_API struct llama_batch llama_batch_get_one( | ||||
|             const llama_token * tokens, | ||||
|                   llama_token * tokens, | ||||
|                      uint32_t   n_tokens, | ||||
|                     llama_pos   pos_0, | ||||
|                  llama_seq_id   seq_id); | ||||
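For code still on the deprecated entry points, this helper makes the migration mechanical. A sketch of the equivalent of an old llama_eval() call (the eval_compat wrapper and its name are illustrative, not part of the API):

    #include "llama.h"

    // Hypothetical wrapper: same inputs as llama_eval(), routed through the new batch API.
    static int eval_compat(llama_context * ctx, llama_token * tokens, uint32_t n_tokens, int n_past, int n_threads) {
        // wraps the caller's token buffer; positions start at n_past, everything goes to sequence 0
        llama_batch batch = llama_batch_get_one(tokens, n_tokens, n_past, 0);

        // no llama_batch_free() here: only batches created with llama_batch_init() own their arrays
        return llama_decode(ctx, batch, n_threads);
    }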
|  | ||||
|     // Allocates a batch of tokens on the heap | ||||
|     // The batch needs to be freed with llama_batch_free() | ||||
|     // If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) | ||||
|     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | ||||
|     // The rest of the llama_batch members are allocated with size n_tokens | ||||
|     // All members are left uninitialized | ||||
|     LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd); | ||||
|  | ||||
|     // Frees a batch of tokens allocated with llama_batch_init() | ||||
|     LLAMA_API void llama_batch_free(struct llama_batch batch); | ||||
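The embd parameter selects between the two input modes. A sketch, not code from the commit, of the embd > 0 path, i.e. feeding precomputed embeddings as a replacement for the deprecated llama_eval_embd() (the decode_embeddings helper and the n_embd/n_past parameters are illustrative names):

    #include "llama.h"

    #include <cstring>

    // Sketch: decode n_tokens rows of n_embd floats, starting at position n_past, sequence 0.
    static int decode_embeddings(llama_context * ctx, const float * embeddings,
                                 uint32_t n_tokens, int n_embd, int n_past, int n_threads) {
        llama_batch batch = llama_batch_init(n_tokens, n_embd); // allocates batch.embd; batch.token stays NULL

        batch.n_tokens = n_tokens;
        memcpy(batch.embd, embeddings, sizeof(float) * n_tokens * n_embd);

        for (uint32_t i = 0; i < n_tokens; ++i) {
            batch.pos   [i] = n_past + i;
            batch.seq_id[i] = 0;
            batch.logits[i] = i == n_tokens - 1; // logits only for the last position
        }

        const int ret = llama_decode(ctx, batch, n_threads);

        llama_batch_free(batch);

        return ret;
    }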
|  | ||||
|     // Positive return values does not mean a fatal error, but rather a warning. | ||||
|     //   0 - success | ||||
|     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) | ||||