mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	simple : add parallel decoding support
This commit is contained in:
		| @@ -956,11 +956,11 @@ llama_token llama_sample_token( | |||||||
|         if (mirostat == 1) { |         if (mirostat == 1) { | ||||||
|             static float mirostat_mu = 2.0f * mirostat_tau; |             static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|             const int mirostat_m = 100; |             const int mirostat_m = 100; | ||||||
|             llama_sample_temperature(ctx, &cur_p, temp); |             llama_sample_temp(ctx, &cur_p, temp); | ||||||
|             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); |             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); | ||||||
|         } else if (mirostat == 2) { |         } else if (mirostat == 2) { | ||||||
|             static float mirostat_mu = 2.0f * mirostat_tau; |             static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|             llama_sample_temperature(ctx, &cur_p, temp); |             llama_sample_temp(ctx, &cur_p, temp); | ||||||
|             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); |             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); | ||||||
|         } else { |         } else { | ||||||
|             // Temperature sampling |             // Temperature sampling | ||||||
| @@ -968,7 +968,7 @@ llama_token llama_sample_token( | |||||||
|             llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1); |             llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1); | ||||||
|             llama_sample_typical    (ctx, &cur_p, typical_p, 1); |             llama_sample_typical    (ctx, &cur_p, typical_p, 1); | ||||||
|             llama_sample_top_p      (ctx, &cur_p, top_p, 1); |             llama_sample_top_p      (ctx, &cur_p, top_p, 1); | ||||||
|             llama_sample_temperature(ctx, &cur_p, temp); |             llama_sample_temp(ctx, &cur_p, temp); | ||||||
|  |  | ||||||
|             { |             { | ||||||
|                 const int n_top = 10; |                 const int n_top = 10; | ||||||
|   | |||||||
| @@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){ | |||||||
|         if (n_eval > n_batch) { |         if (n_eval > n_batch) { | ||||||
|             n_eval = n_batch; |             n_eval = n_batch; | ||||||
|         } |         } | ||||||
|         llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; |         llama_batch batch = {  int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||||
|         if (llama_decode(ctx, batch, params.n_threads)) { |         if (llama_decode(ctx, batch, params.n_threads)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return false; |             return false; | ||||||
| @@ -183,11 +183,11 @@ llama_token sampling_id(struct MyModel* mymodel) { | |||||||
|             if (mirostat == 1) { |             if (mirostat == 1) { | ||||||
|                 static float mirostat_mu = 2.0f * mirostat_tau; |                 static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|                 const int mirostat_m = 100; |                 const int mirostat_m = 100; | ||||||
|                 llama_sample_temperature(ctx, &candidates_p, temp); |                 llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                 id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); |                 id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); | ||||||
|             } else if (mirostat == 2) { |             } else if (mirostat == 2) { | ||||||
|                 static float mirostat_mu = 2.0f * mirostat_tau; |                 static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|                 llama_sample_temperature(ctx, &candidates_p, temp); |                 llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                 id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); |                 id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); | ||||||
|             } else { |             } else { | ||||||
|                 // Temperature sampling |                 // Temperature sampling | ||||||
| @@ -195,7 +195,7 @@ llama_token sampling_id(struct MyModel* mymodel) { | |||||||
|                 llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); |                 llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); | ||||||
|                 llama_sample_typical(ctx, &candidates_p, typical_p, 1); |                 llama_sample_typical(ctx, &candidates_p, typical_p, 1); | ||||||
|                 llama_sample_top_p(ctx, &candidates_p, top_p, 1); |                 llama_sample_top_p(ctx, &candidates_p, top_p, 1); | ||||||
|                 llama_sample_temperature(ctx, &candidates_p, temp); |                 llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                 id = llama_sample_token(ctx, &candidates_p); |                 id = llama_sample_token(ctx, &candidates_p); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -123,7 +123,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     std::vector<llama_token> tokens_system; |     std::vector<llama_token> tokens_system; | ||||||
|     tokens_system = ::llama_tokenize(ctx, k_system, true); |     tokens_system = ::llama_tokenize(ctx, k_system, true); | ||||||
|     const uint32_t n_tokens_system = tokens_system.size(); |     const int32_t n_tokens_system = tokens_system.size(); | ||||||
|  |  | ||||||
|     llama_seq_id g_seq_id = 0; |     llama_seq_id g_seq_id = 0; | ||||||
|  |  | ||||||
| @@ -144,7 +144,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|         batch.n_tokens = n_tokens_system; |         batch.n_tokens = n_tokens_system; | ||||||
|  |  | ||||||
|         for (uint32_t i = 0; i < batch.n_tokens; ++i) { |         for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|             batch.token[i]  = tokens_system[i]; |             batch.token[i]  = tokens_system[i]; | ||||||
|             batch.pos[i]    = i; |             batch.pos[i]    = i; | ||||||
|             batch.seq_id[i] = 0; |             batch.seq_id[i] = 0; | ||||||
| @@ -156,7 +156,7 @@ int main(int argc, char ** argv) { | |||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // assign the system KV cachce to all parallel sequences |         // assign the system KV cache to all parallel sequences | ||||||
|         for (int32_t i = 1; i < n_clients; ++i) { |         for (int32_t i = 1; i < n_clients; ++i) { | ||||||
|             llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); |             llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); | ||||||
|         } |         } | ||||||
| @@ -248,7 +248,7 @@ int main(int argc, char ** argv) { | |||||||
|         int32_t n_batch = params.n_batch; |         int32_t n_batch = params.n_batch; | ||||||
|  |  | ||||||
|         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { |         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { | ||||||
|             const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); |             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); | ||||||
|  |  | ||||||
|             llama_batch batch_view = { |             llama_batch batch_view = { | ||||||
|                 n_tokens, |                 n_tokens, | ||||||
|   | |||||||
| @@ -523,13 +523,13 @@ struct llama_server_context | |||||||
|                 { |                 { | ||||||
|                     static float mirostat_mu = 2.0f * mirostat_tau; |                     static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|                     const int mirostat_m = 100; |                     const int mirostat_m = 100; | ||||||
|                     llama_sample_temperature(ctx, &candidates_p, temp); |                     llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                     result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); |                     result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); | ||||||
|                 } |                 } | ||||||
|                 else if (mirostat == 2) |                 else if (mirostat == 2) | ||||||
|                 { |                 { | ||||||
|                     static float mirostat_mu = 2.0f * mirostat_tau; |                     static float mirostat_mu = 2.0f * mirostat_tau; | ||||||
|                     llama_sample_temperature(ctx, &candidates_p, temp); |                     llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                     result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); |                     result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); | ||||||
|                 } |                 } | ||||||
|                 else |                 else | ||||||
| @@ -540,7 +540,7 @@ struct llama_server_context | |||||||
|                     llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); |                     llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); | ||||||
|                     llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); |                     llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); | ||||||
|                     llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); |                     llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); | ||||||
|                     llama_sample_temperature(ctx, &candidates_p, temp); |                     llama_sample_temp(ctx, &candidates_p, temp); | ||||||
|                     result.tok = llama_sample_token(ctx, &candidates_p); |                     result.tok = llama_sample_token(ctx, &candidates_p); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -32,12 +32,18 @@ int main(int argc, char ** argv) { | |||||||
|         params.prompt = "Hello my name is"; |         params.prompt = "Hello my name is"; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // total length of the sequences including the prompt | ||||||
|  |     const int n_len = 32; | ||||||
|  |  | ||||||
|     // init LLM |     // init LLM | ||||||
|  |  | ||||||
|     llama_backend_init(params.numa); |     llama_backend_init(params.numa); | ||||||
|  |  | ||||||
|     llama_context_params ctx_params = llama_context_default_params(); |     llama_context_params ctx_params = llama_context_default_params(); | ||||||
|  |  | ||||||
|  |     ctx_params.seed  = 1234; | ||||||
|  |     ctx_params.n_ctx = 2048; | ||||||
|  |  | ||||||
|     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); |     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); | ||||||
|  |  | ||||||
|     if (model == NULL) { |     if (model == NULL) { | ||||||
| @@ -47,20 +53,29 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     llama_context * ctx = llama_new_context_with_model(model, ctx_params); |     llama_context * ctx = llama_new_context_with_model(model, ctx_params); | ||||||
|  |  | ||||||
|  |     if (ctx == NULL) { | ||||||
|  |         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); | ||||||
|  |         return 1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // tokenize the prompt |     // tokenize the prompt | ||||||
|  |  | ||||||
|     std::vector<llama_token> tokens_list; |     std::vector<llama_token> tokens_list; | ||||||
|     tokens_list = ::llama_tokenize(ctx, params.prompt, true); |     tokens_list = ::llama_tokenize(ctx, params.prompt, true); | ||||||
|  |  | ||||||
|     const int max_context_size     = llama_n_ctx(ctx); |     const int n_ctx    = llama_n_ctx(ctx); | ||||||
|     const int max_tokens_list_size = max_context_size - 4; |     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; | ||||||
|  |  | ||||||
|     if ((int) tokens_list.size() > max_tokens_list_size) { |     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req); | ||||||
|         fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); |  | ||||||
|  |     // make sure wi | ||||||
|  |     if (n_kv_req > n_ctx) { | ||||||
|  |         LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); | ||||||
|  |         LOG_TEE("%s:        either reduce n_parallel or increase n_ctx\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fprintf(stderr, "\n\n"); |     fprintf(stderr, "\n"); | ||||||
|  |  | ||||||
|     for (auto id : tokens_list) { |     for (auto id : tokens_list) { | ||||||
|         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); |         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); | ||||||
| @@ -68,66 +83,157 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     fflush(stderr); |     fflush(stderr); | ||||||
|  |  | ||||||
|  |     // create a llama_batch with size 512 | ||||||
|  |     // we use this object to submit token data for decoding | ||||||
|  |  | ||||||
|  |     llama_batch batch = llama_batch_init(512, 0); | ||||||
|  |  | ||||||
|  |     // evaluate the initial prompt | ||||||
|  |     batch.n_tokens = tokens_list.size(); | ||||||
|  |  | ||||||
|  |     for (int32_t i = 0; i < batch.n_tokens; i++) { | ||||||
|  |         batch.token[i]  = tokens_list[i]; | ||||||
|  |         batch.pos[i]    = i; | ||||||
|  |         batch.seq_id[i] = 0; | ||||||
|  |         batch.logits[i] = false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // llama_decode will output logits only for the last token of the prompt | ||||||
|  |     batch.logits[batch.n_tokens - 1] = true; | ||||||
|  |  | ||||||
|  |     if (llama_decode(ctx, batch, params.n_threads) != 0) { | ||||||
|  |         LOG_TEE("%s: llama_decode() failed\n", __func__); | ||||||
|  |         return 1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // assign the system KV cache to all parallel sequences | ||||||
|  |     for (int32_t i = 1; i < n_parallel; ++i) { | ||||||
|  |         llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (n_parallel > 1) { | ||||||
|  |         LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // main loop |     // main loop | ||||||
|  |  | ||||||
|     // The LLM keeps a contextual cache memory of previous token evaluation. |     // we will store the parallel decoded sequences in this vector | ||||||
|     // Usually, once this cache is full, it is required to recompute a compressed context based on previous |     std::vector<std::string> streams(n_parallel); | ||||||
|     // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist |  | ||||||
|     // example, we will just stop the loop once this cache is full or once an end of stream is detected. |  | ||||||
|  |  | ||||||
|     const int n_gen = std::min(32, max_context_size); |     // remember the batch index of the last tokenn for each parallel sequence | ||||||
|  |     // we will use this to know which logits to sample from | ||||||
|  |     std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); | ||||||
|  |  | ||||||
|     int n_cur = 0; |     int n_cur    = batch.n_tokens; | ||||||
|  |     int n_decode = 0; | ||||||
|  |  | ||||||
|     while (n_cur < n_gen) { |     const auto t_main_start = ggml_time_us(); | ||||||
|         // evaluate the transformer |  | ||||||
|  |  | ||||||
|         if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) { |     while (n_cur <= n_len) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |         // evaluate the current batch with the transformer model | ||||||
|  |         if (llama_decode(ctx, batch, params.n_threads)) { | ||||||
|  |             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         n_cur += tokens_list.size(); |         // prepare the next batch | ||||||
|         tokens_list.clear(); |         batch.n_tokens = 0; | ||||||
|  |  | ||||||
|         // sample the next token |         // sample the next token for each parallel sequence / stream | ||||||
|  |         for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
|  |             if (i_batch[i] < 0) { | ||||||
|  |                 // the stream has already finished | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|         llama_token new_token_id = 0; |             auto n_vocab = llama_n_vocab(ctx); | ||||||
|  |             auto logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab; | ||||||
|  |  | ||||||
|         auto logits  = llama_get_logits(ctx); |             std::vector<llama_token_data> candidates; | ||||||
|         auto n_vocab = llama_n_vocab(ctx); |             candidates.reserve(n_vocab); | ||||||
|  |  | ||||||
|         std::vector<llama_token_data> candidates; |             for (llama_token token_id = 0; token_id < n_vocab; token_id++) { | ||||||
|         candidates.reserve(n_vocab); |                 candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); | ||||||
|  |             } | ||||||
|  |  | ||||||
|         for (llama_token token_id = 0; token_id < n_vocab; token_id++) { |             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; | ||||||
|             candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); |  | ||||||
|  |             const int   top_k = 40; | ||||||
|  |             const float top_p = 0.9f; | ||||||
|  |             const float temp  = 0.4f; | ||||||
|  |  | ||||||
|  |             llama_sample_top_k(ctx, &candidates_p, top_k, 1); | ||||||
|  |             llama_sample_top_p(ctx, &candidates_p, top_p, 1); | ||||||
|  |             llama_sample_temp (ctx, &candidates_p, temp); | ||||||
|  |  | ||||||
|  |             const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); | ||||||
|  |  | ||||||
|  |             //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); | ||||||
|  |  | ||||||
|  |             // is it an end of stream ? | ||||||
|  |             // mark this stream as finished | ||||||
|  |             if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) { | ||||||
|  |                 i_batch[i] = -1; | ||||||
|  |                 LOG_TEE("\n"); | ||||||
|  |                 if (n_parallel > 1) { | ||||||
|  |                     LOG_TEE("%s: stream %d finished", __func__, i); | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             if (n_parallel == 1) { | ||||||
|  |                 // print the new token : | ||||||
|  |                 LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); | ||||||
|  |                 fflush(stdout); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             streams[i] += llama_token_to_piece(ctx, new_token_id); | ||||||
|  |  | ||||||
|  |             // push this new token for next evaluation | ||||||
|  |             batch.token [batch.n_tokens] = new_token_id; | ||||||
|  |             batch.pos   [batch.n_tokens] = n_cur; | ||||||
|  |             batch.seq_id[batch.n_tokens] = i; | ||||||
|  |             batch.logits[batch.n_tokens] = true; | ||||||
|  |  | ||||||
|  |             i_batch[i] = batch.n_tokens; | ||||||
|  |  | ||||||
|  |             batch.n_tokens += 1; | ||||||
|  |  | ||||||
|  |             n_decode += 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; |         if (batch.n_tokens == 0) { | ||||||
|  |             // all streams are finished | ||||||
|         new_token_id = llama_sample_token_greedy(ctx , &candidates_p); |  | ||||||
|  |  | ||||||
|         // is it an end of stream ? |  | ||||||
|         if (new_token_id == llama_token_eos(ctx)) { |  | ||||||
|             fprintf(stderr, " [end of text]\n"); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // print the new token : |         n_cur += 1; | ||||||
|         printf("%s", llama_token_to_piece(ctx, new_token_id).c_str()); |  | ||||||
|         fflush(stdout); |  | ||||||
|  |  | ||||||
|         // push this new token for next evaluation |  | ||||||
|         tokens_list.push_back(new_token_id); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     LOG_TEE("\n"); | ||||||
|  |  | ||||||
|  |     if (n_parallel > 1) { | ||||||
|  |         LOG_TEE("\n"); | ||||||
|  |  | ||||||
|  |         for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
|  |             LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const auto t_main_end = ggml_time_us(); | ||||||
|  |  | ||||||
|  |     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", | ||||||
|  |             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); | ||||||
|  |  | ||||||
|  |     llama_print_timings(ctx); | ||||||
|  |  | ||||||
|  |     fprintf(stderr, "\n"); | ||||||
|  |  | ||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|     llama_free_model(model); |     llama_free_model(model); | ||||||
|  |  | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
|  |  | ||||||
|     fprintf(stderr, "\n\n"); |  | ||||||
|  |  | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -4185,20 +4185,18 @@ static int llama_decode_internal( | |||||||
|     { |     { | ||||||
|         auto & logits_out = lctx.logits; |         auto & logits_out = lctx.logits; | ||||||
|  |  | ||||||
|         if (lctx.logits_all) { |         if (batch.logits) { | ||||||
|             logits_out.resize(n_vocab * n_tokens); |             logits_out.resize(n_vocab * n_tokens); | ||||||
|             if (batch.logits) { |             for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|                 for (uint32_t i = 0; i < n_tokens; i++) { |                 if (batch.logits[i] == 0) { | ||||||
|                     if (batch.logits[i] == 0) { |                     continue; | ||||||
|                         continue; |  | ||||||
|                     } |  | ||||||
|                     memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); |  | ||||||
|                 } |                 } | ||||||
|             } else { |                 memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); | ||||||
|                 memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); |  | ||||||
|             } |             } | ||||||
|  |         } else if (lctx.logits_all) { | ||||||
|  |             logits_out.resize(n_vocab * n_tokens); | ||||||
|  |             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); | ||||||
|         } else { |         } else { | ||||||
|             // return result for just the last token |  | ||||||
|             logits_out.resize(n_vocab); |             logits_out.resize(n_vocab); | ||||||
|             memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); |             memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); | ||||||
|         } |         } | ||||||
| @@ -5269,7 +5267,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { | void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { | ||||||
|     const int64_t t_start_sample_us = ggml_time_us(); |     const int64_t t_start_sample_us = ggml_time_us(); | ||||||
|  |  | ||||||
|     for (size_t i = 0; i < candidates_p->size; ++i) { |     for (size_t i = 0; i < candidates_p->size; ++i) { | ||||||
| @@ -5281,6 +5279,10 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { | ||||||
|  |     llama_sample_temp(ctx, candidates_p, temp); | ||||||
|  | } | ||||||
|  |  | ||||||
| void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { | void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { | ||||||
|     if (last_tokens_size == 0 || penalty == 1.0f) { |     if (last_tokens_size == 0 || penalty == 1.0f) { | ||||||
|         return; |         return; | ||||||
| @@ -7357,7 +7359,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi | |||||||
| int llama_eval( | int llama_eval( | ||||||
|         struct llama_context * ctx, |         struct llama_context * ctx, | ||||||
|                  llama_token * tokens, |                  llama_token * tokens, | ||||||
|                     uint32_t   n_tokens, |                      int32_t   n_tokens, | ||||||
|                          int   n_past, |                          int   n_past, | ||||||
|                          int   n_threads) { |                          int   n_threads) { | ||||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); |     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); | ||||||
| @@ -7377,7 +7379,7 @@ int llama_eval( | |||||||
| int llama_eval_embd( | int llama_eval_embd( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                            float * embd, |                            float * embd, | ||||||
|                         uint32_t   n_tokens, |                          int32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|                              int   n_threads) { |                              int   n_threads) { | ||||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); |     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); | ||||||
| @@ -7398,7 +7400,7 @@ int llama_eval_embd( | |||||||
|  |  | ||||||
| struct llama_batch llama_batch_get_one( | struct llama_batch llama_batch_get_one( | ||||||
|              llama_token * tokens, |              llama_token * tokens, | ||||||
|                 uint32_t   n_tokens, |                  int32_t   n_tokens, | ||||||
|                llama_pos   pos_0, |                llama_pos   pos_0, | ||||||
|             llama_seq_id   seq_id) { |             llama_seq_id   seq_id) { | ||||||
|     return { |     return { | ||||||
| @@ -7414,8 +7416,8 @@ struct llama_batch llama_batch_get_one( | |||||||
|     }; |     }; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) { | struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { | ||||||
|     llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; |     llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; | ||||||
|  |  | ||||||
|     if (embd) { |     if (embd) { | ||||||
|         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); |         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); | ||||||
|   | |||||||
							
								
								
									
										15
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								llama.h
									
									
									
									
									
								
							| @@ -68,7 +68,7 @@ extern "C" { | |||||||
|  |  | ||||||
|     // data used for batch inference |     // data used for batch inference | ||||||
|     typedef struct llama_batch { |     typedef struct llama_batch { | ||||||
|         uint32_t n_tokens; |         int32_t n_tokens; | ||||||
|  |  | ||||||
|         llama_token  * token; |         llama_token  * token; | ||||||
|         float        * embd; |         float        * embd; | ||||||
| @@ -370,7 +370,7 @@ extern "C" { | |||||||
|     LLAMA_API DEPRECATED(int llama_eval( |     LLAMA_API DEPRECATED(int llama_eval( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                      llama_token * tokens, |                      llama_token * tokens, | ||||||
|                         uint32_t   n_tokens, |                          int32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|                              int   n_threads), |                              int   n_threads), | ||||||
|             "please use llama_decode() instead"); |             "please use llama_decode() instead"); | ||||||
| @@ -380,7 +380,7 @@ extern "C" { | |||||||
|     LLAMA_API DEPRECATED(int llama_eval_embd( |     LLAMA_API DEPRECATED(int llama_eval_embd( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                            float * embd, |                            float * embd, | ||||||
|                         uint32_t   n_tokens, |                          int32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|                              int   n_threads), |                              int   n_threads), | ||||||
|             "please use llama_decode() instead"); |             "please use llama_decode() instead"); | ||||||
| @@ -391,7 +391,7 @@ extern "C" { | |||||||
|     // |     // | ||||||
|     LLAMA_API struct llama_batch llama_batch_get_one( |     LLAMA_API struct llama_batch llama_batch_get_one( | ||||||
|                   llama_token * tokens, |                   llama_token * tokens, | ||||||
|                      uint32_t   n_tokens, |                       int32_t   n_tokens, | ||||||
|                     llama_pos   pos_0, |                     llama_pos   pos_0, | ||||||
|                  llama_seq_id   seq_id); |                  llama_seq_id   seq_id); | ||||||
|  |  | ||||||
| @@ -401,7 +401,7 @@ extern "C" { | |||||||
|     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token |     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | ||||||
|     // The rest of the llama_batch members are allocated with size n_tokens |     // The rest of the llama_batch members are allocated with size n_tokens | ||||||
|     // All members are left uninitialized |     // All members are left uninitialized | ||||||
|     LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd); |     LLAMA_API struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd); | ||||||
|  |  | ||||||
|     // Frees a batch of tokens allocated with llama_batch_init() |     // Frees a batch of tokens allocated with llama_batch_init() | ||||||
|     LLAMA_API void llama_batch_free(struct llama_batch batch); |     LLAMA_API void llama_batch_free(struct llama_batch batch); | ||||||
| @@ -531,7 +531,10 @@ extern "C" { | |||||||
|  |  | ||||||
|     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. |     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. | ||||||
|     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); |     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); | ||||||
|     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); |     LLAMA_API void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates, float temp); | ||||||
|  |  | ||||||
|  |     LLAMA_API DEPRECATED(void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp), | ||||||
|  |             "Use llama_sample_temp instead"); | ||||||
|  |  | ||||||
|     /// @details Apply constraints from grammar |     /// @details Apply constraints from grammar | ||||||
|     LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); |     LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov