mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : extend batch API to select which logits to output
This commit is contained in:
		| @@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){ | ||||
|         if (n_eval > n_batch) { | ||||
|             n_eval = n_batch; | ||||
|         } | ||||
|         llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, }; | ||||
|         llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||
|         if (llama_decode(ctx, batch, params.n_threads)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return false; | ||||
|   | ||||
| @@ -82,6 +82,9 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     const int n_clients = 4; | ||||
|  | ||||
|     // insert new requests as soon as the previous one is done | ||||
|     const bool hot_swap = true; | ||||
|  | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     log_set_target(log_filename_generator("parallel", "log")); | ||||
|     LOG_TEE("Log start\n"); | ||||
| @@ -121,14 +124,23 @@ int main(int argc, char ** argv) { | ||||
|     std::vector<llama_token>  batch_token; | ||||
|     std::vector<llama_pos>    batch_pos; | ||||
|     std::vector<llama_seq_id> batch_seq_id; | ||||
|     std::vector<int8_t>       batch_logits; | ||||
|     std::vector<client *>     batch_clients; | ||||
|  | ||||
|     while (true) { | ||||
|     int32_t n_total_prompt = 0; | ||||
|     int32_t n_total_gen    = 0; | ||||
|  | ||||
|     float t_avg = 0.0f; | ||||
|  | ||||
|     const int32_t n_seq = 128; | ||||
|  | ||||
|     while (g_seq_id < n_seq + n_clients) { | ||||
|         uint32_t n_tokens = 0; | ||||
|  | ||||
|         batch_token.clear(); | ||||
|         batch_pos.clear(); | ||||
|         batch_seq_id.clear(); | ||||
|         batch_logits.clear(); | ||||
|  | ||||
|         for (auto & client : clients) { | ||||
|             if (client.seq_id == -1) { | ||||
| @@ -138,6 +150,7 @@ int main(int argc, char ** argv) { | ||||
|             batch_token.push_back(client.sampled); | ||||
|             batch_pos.push_back(client.n_decoded); | ||||
|             batch_seq_id.push_back(client.seq_id); | ||||
|             batch_logits.push_back(true); | ||||
|             batch_clients.push_back(&client); | ||||
|             client.n_decoded += 1; | ||||
|             client.i_batch = batch_token.size() - 1; | ||||
| @@ -146,7 +159,9 @@ int main(int argc, char ** argv) { | ||||
|         if (batch_token.empty()) { | ||||
|             // all sequences have ended - clear the entire KV cache | ||||
|             llama_kv_cache_rm_tokens(ctx, -1, -1); | ||||
|         } | ||||
|  | ||||
|         if (hot_swap || batch_token.empty()) { | ||||
|             for (auto & client : clients) { | ||||
|                 if (client.seq_id == -1) { | ||||
|                     client.seq_id = g_seq_id; | ||||
| @@ -166,7 +181,10 @@ int main(int argc, char ** argv) { | ||||
|                         batch_pos.push_back(i); | ||||
|                         batch_seq_id.push_back(client.seq_id); | ||||
|                         batch_clients.push_back(&client); | ||||
|                         batch_logits.push_back(false); | ||||
|                     } | ||||
|                     batch_logits.back() = true; | ||||
|  | ||||
|                     client.n_prompt  = prompt_tokens.size(); | ||||
|                     client.n_decoded = prompt_tokens.size(); | ||||
|                     client.i_batch   = batch_token.size() - 1; | ||||
| @@ -186,6 +204,7 @@ int main(int argc, char ** argv) { | ||||
|                 nullptr, | ||||
|                 batch_pos.data() + i, | ||||
|                 batch_seq_id.data() + i, | ||||
|                 batch_logits.data() + i, | ||||
|                 0, 0, 0, // unused | ||||
|             }; | ||||
|  | ||||
| @@ -232,14 +251,20 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                     const auto t_main_end = ggml_time_us(); | ||||
|  | ||||
|                     printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput:    %s\nResponse: %s\n\n", | ||||
|                     printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput:    %s\nResponse: %s\n\n", | ||||
|                             client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt, | ||||
|                             (t_main_end - client.t_start_prompt) / 1e6, | ||||
|                             (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6, | ||||
|                             (double) (client.n_decoded - client.n_prompt) / (t_main_end         - client.t_start_gen)    * 1e6, | ||||
|                             (double) (client.n_decoded                  ) / (t_main_end         - client.t_start_prompt) * 1e6, | ||||
|                             ::trim(client.input).c_str(), | ||||
|                             ::trim(client.response).c_str()); | ||||
|  | ||||
|                     n_total_prompt += client.n_prompt; | ||||
|                     n_total_gen    += client.n_decoded - client.n_prompt; | ||||
|  | ||||
|                     t_avg += (t_main_end - client.t_start_prompt) / 1e6; | ||||
|  | ||||
|                     client.seq_id = -1; | ||||
|                 } | ||||
|  | ||||
| @@ -248,6 +273,11 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     LOG_TEE("\n\n"); | ||||
|     LOG_TEE("Total prompt tokens: %d\n", n_total_prompt); | ||||
|     LOG_TEE("Total gen tokens:    %d\n", n_total_gen); | ||||
|     LOG_TEE("Avg time per seq:    %.2f s\n", t_avg / n_seq); | ||||
|  | ||||
|     LOG_TEE("\n\n"); | ||||
|  | ||||
|     llama_print_timings(ctx); | ||||
|   | ||||
							
								
								
									
										14
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -4140,7 +4140,16 @@ static bool llama_eval_internal( | ||||
|  | ||||
|         if (lctx.logits_all) { | ||||
|             logits_out.resize(n_vocab * n_tokens); | ||||
|             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); | ||||
|             if (batch.logits) { | ||||
|                 for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|                     if (batch.logits[i] == 0) { | ||||
|                         continue; | ||||
|                     } | ||||
|                     memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); | ||||
|                 } | ||||
|             } else { | ||||
|                 memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); | ||||
|             } | ||||
|         } else { | ||||
|             // return result for just the last token | ||||
|             logits_out.resize(n_vocab); | ||||
| @@ -7318,7 +7327,7 @@ int llama_eval_embd( | ||||
|                              int   n_threads) { | ||||
|     llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1); | ||||
|  | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, }; | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||
|  | ||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||
| @@ -7346,6 +7355,7 @@ struct llama_batch llama_batch_get_one( | ||||
|         /*embd        =*/ nullptr, | ||||
|         /*pos         =*/ nullptr, | ||||
|         /*seq_id      =*/ nullptr, | ||||
|         /*logits      =*/ nullptr, | ||||
|         /*all_pos_0   =*/ pos_0, | ||||
|         /*all_pos_1   =*/ 1, | ||||
|         /*all_seq_id  =*/ seq_id, | ||||
|   | ||||
							
								
								
									
										2
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								llama.h
									
									
									
									
									
								
							| @@ -70,11 +70,11 @@ extern "C" { | ||||
|     typedef struct llama_batch { | ||||
|         uint32_t n_tokens; | ||||
|  | ||||
|         // TODO: not sure about these consts - might just get in the way all the time with no benefit | ||||
|         const llama_token  * token; | ||||
|         const float        * embd; | ||||
|         const llama_pos    * pos; | ||||
|         const llama_seq_id * seq_id; | ||||
|         const int8_t       * logits; // if 0, do not extract logits for that token | ||||
|  | ||||
|         // NOTE: helpers for smooth API transition - can be deprecated in the future | ||||
|         //       for future-proof code, use the above fields instead and ignore everything below | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov