	llama : add new llama_decode() API that works with llama_batch
		| @@ -780,7 +780,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par | |||||||
|         LOG("warming up the model with an empty run\n"); |         LOG("warming up the model with an empty run\n"); | ||||||
|  |  | ||||||
|         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; |         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; | ||||||
|         llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); |         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); | ||||||
|         llama_reset_timings(lctx); |         llama_reset_timings(lctx); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
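The change above is the pattern repeated across the examples that follow: every llama_eval(ctx, tokens, n_tokens, n_past, n_threads) call becomes a llama_decode() call, with the same arguments wrapped into a llama_batch by the new llama_batch_get_one() helper. A minimal sketch of the migration (variable names are illustrative, not taken from any one example):

    // before:
    //   llama_eval(ctx, tokens.data(), n_tokens, n_past, n_threads);

    // after: wrap the same span of tokens in a single-sequence batch (seq_id 0)
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past, 0), n_threads)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return 1;
    }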
| @@ -160,7 +160,7 @@ int main(int argc, char ** argv) | |||||||
|  |  | ||||||
|     int n_past = 0; |     int n_past = 0; | ||||||
|  |  | ||||||
|     if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) |     if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads)) | ||||||
|     { |     { | ||||||
|         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); |         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); | ||||||
|         return 1; |         return 1; | ||||||
|   | |||||||
| @@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){ | |||||||
|         if (n_eval > n_batch) { |         if (n_eval > n_batch) { | ||||||
|             n_eval = n_batch; |             n_eval = n_batch; | ||||||
|         } |         } | ||||||
|         if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { |         llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false }; | ||||||
|  |         if (llama_decode(ctx, batch, params.n_threads)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
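There is no llama_batch_get_one() counterpart for embedding input in this commit, so the batch above is built by aggregate initialization. The field order follows the extended llama_batch definition in the llama.h hunk further down; a sketch of the same construction with the fields named for readability (assumes input, i, n_emb, n_eval and n_past as in the example):

    llama_batch batch = {
        /*n_tokens   =*/ uint32_t(n_eval),
        /*tokens     =*/ nullptr,           // no token ids - raw embeddings are passed instead
        /*embd       =*/ input + i*n_emb,   // n_eval * n_emb floats for this chunk
        /*pos        =*/ nullptr,           // positions default to all_pos_0 + i*all_pos_1
        /*seq_id     =*/ nullptr,           // sequence ids default to all_seq_id
        /*all_pos_0  =*/ n_past,
        /*all_pos_1  =*/ 1,
        /*all_seq_id =*/ 0,
        /*clear_kv   =*/ false,
    };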
| @@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) { | |||||||
|         if (n_eval > params.n_batch) { |         if (n_eval > params.n_batch) { | ||||||
|             n_eval = params.n_batch; |             n_eval = params.n_batch; | ||||||
|         } |         } | ||||||
|         if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { |         if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     while (!embd_inp.empty()) { |     while (!embd_inp.empty()) { | ||||||
|         int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); |         int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); | ||||||
|         if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) { |         if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat | |||||||
|     int n_processed = 0; |     int n_processed = 0; | ||||||
|     while (n_processed < n_prompt) { |     while (n_processed < n_prompt) { | ||||||
|         int n_tokens = std::min(n_prompt - n_processed, n_batch); |         int n_tokens = std::min(n_prompt - n_processed, n_batch); | ||||||
|         llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); |         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads); | ||||||
|         n_processed += n_tokens; |         n_processed += n_tokens; | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat | |||||||
| static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { | static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { | ||||||
|     llama_token token = llama_token_bos(ctx); |     llama_token token = llama_token_bos(ctx); | ||||||
|     for (int i = 0; i < n_gen; i++) { |     for (int i = 0; i < n_gen; i++) { | ||||||
|         llama_eval(ctx, &token, 1, n_past + i, n_threads); |         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -571,7 +571,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                 for (int i = 0; i < input_size; i += params.n_batch) { |                 for (int i = 0; i < input_size; i += params.n_batch) { | ||||||
|                     int n_eval = std::min(input_size - i, params.n_batch); |                     int n_eval = std::min(input_size - i, params.n_batch); | ||||||
|                     if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { |                     if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) { | ||||||
|                         LOG_TEE("%s : failed to eval\n", __func__); |                         LOG_TEE("%s : failed to eval\n", __func__); | ||||||
|                         return 1; |                         return 1; | ||||||
|                     } |                     } | ||||||
| @@ -588,7 +588,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); |                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); | ||||||
|  |  | ||||||
|                 if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { |                 if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) { | ||||||
|                     LOG_TEE("%s : failed to eval\n", __func__); |                     LOG_TEE("%s : failed to eval\n", __func__); | ||||||
|                     return 1; |                     return 1; | ||||||
|                 } |                 } | ||||||
|   | |||||||
| @@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & | |||||||
|             const int batch_size  = std::min(end - batch_start, n_batch); |             const int batch_size  = std::min(end - batch_start, n_batch); | ||||||
|  |  | ||||||
|             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); |             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); | ||||||
|             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { |             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { | ||||||
|                 //fprintf(stderr, "%s : failed to eval\n", __func__); |                 //fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|                 return {tokens, -1, logit_history, prob_history}; |                 return {tokens, -1, logit_history, prob_history}; | ||||||
|             } |             } | ||||||
| @@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par | |||||||
|                 tokens[batch_start] = llama_token_bos(ctx); |                 tokens[batch_start] = llama_token_bos(ctx); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { |             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { | ||||||
|                 fprintf(stderr, "%s : failed to eval\n", __func__); |                 fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|                 return {tokens, -1, logit_history, prob_history}; |                 return {tokens, -1, logit_history, prob_history}; | ||||||
|             } |             } | ||||||
| @@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens( | |||||||
|     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { |     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { | ||||||
|         size_t n_tokens = tokens.size() - i_chunk * n_batch; |         size_t n_tokens = tokens.size() - i_chunk * n_batch; | ||||||
|         n_tokens = std::min(n_tokens, size_t(n_batch)); |         n_tokens = std::min(n_tokens, size_t(n_batch)); | ||||||
|         if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { |         if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return {}; |             return {}; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -34,11 +34,11 @@ int main(int argc, char ** argv) { | |||||||
|     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); |     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); | ||||||
|  |  | ||||||
|     // init |     // init | ||||||
|     auto model = llama_load_model_from_file(params.model.c_str(), lparams); |     auto * model = llama_load_model_from_file(params.model.c_str(), lparams); | ||||||
|     if (model == nullptr) { |     if (model == nullptr) { | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|     auto ctx = llama_new_context_with_model(model, lparams); |     auto * ctx = llama_new_context_with_model(model, lparams); | ||||||
|     if (ctx == nullptr) { |     if (ctx == nullptr) { | ||||||
|         llama_free_model(model); |         llama_free_model(model); | ||||||
|         return 1; |         return 1; | ||||||
| @@ -53,7 +53,7 @@ int main(int argc, char ** argv) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // evaluate prompt |     // evaluate prompt | ||||||
|     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); |     llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); | ||||||
|  |  | ||||||
|     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); |     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); | ||||||
|     n_past += n_prompt_tokens; |     n_past += n_prompt_tokens; | ||||||
| @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { | |||||||
|     printf("\n%s", params.prompt.c_str()); |     printf("\n%s", params.prompt.c_str()); | ||||||
|  |  | ||||||
|     for (auto i = 0; i < params.n_predict; i++) { |     for (auto i = 0; i < params.n_predict; i++) { | ||||||
|         auto logits = llama_get_logits(ctx); |         auto * logits = llama_get_logits(ctx); | ||||||
|         auto n_vocab = llama_n_vocab(ctx); |         auto n_vocab = llama_n_vocab(ctx); | ||||||
|         std::vector<llama_token_data> candidates; |         std::vector<llama_token_data> candidates; | ||||||
|         candidates.reserve(n_vocab); |         candidates.reserve(n_vocab); | ||||||
| @@ -90,7 +90,7 @@ int main(int argc, char ** argv) { | |||||||
|         last_n_tokens_data.push_back(next_token); |         last_n_tokens_data.push_back(next_token); | ||||||
|  |  | ||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { |         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { | ||||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); |             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||||
|             llama_free(ctx); |             llama_free(ctx); | ||||||
|             llama_free_model(model); |             llama_free_model(model); | ||||||
| @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { | |||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|  |  | ||||||
|     // make new context |     // make new context | ||||||
|     auto ctx2 = llama_new_context_with_model(model, lparams); |     auto * ctx2 = llama_new_context_with_model(model, lparams); | ||||||
|  |  | ||||||
|     // Load state (rng, logits, embedding and kv_cache) from file |     // Load state (rng, logits, embedding and kv_cache) from file | ||||||
|     { |     { | ||||||
| @@ -137,7 +137,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // second run |     // second run | ||||||
|     for (auto i = 0; i < params.n_predict; i++) { |     for (auto i = 0; i < params.n_predict; i++) { | ||||||
|         auto logits = llama_get_logits(ctx2); |         auto * logits = llama_get_logits(ctx2); | ||||||
|         auto n_vocab = llama_n_vocab(ctx2); |         auto n_vocab = llama_n_vocab(ctx2); | ||||||
|         std::vector<llama_token_data> candidates; |         std::vector<llama_token_data> candidates; | ||||||
|         candidates.reserve(n_vocab); |         candidates.reserve(n_vocab); | ||||||
| @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { | |||||||
|         last_n_tokens_data.push_back(next_token); |         last_n_tokens_data.push_back(next_token); | ||||||
|  |  | ||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { |         if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { | ||||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); |             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||||
|             llama_free(ctx2); |             llama_free(ctx2); | ||||||
|             llama_free_model(model); |             llama_free_model(model); | ||||||
|   | |||||||
| @@ -434,7 +434,7 @@ struct llama_server_context | |||||||
|             { |             { | ||||||
|                 n_eval = params.n_batch; |                 n_eval = params.n_batch; | ||||||
|             } |             } | ||||||
|             if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) |             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) | ||||||
|             { |             { | ||||||
|                 LOG_ERROR("failed to eval", { |                 LOG_ERROR("failed to eval", { | ||||||
|                                                 {"n_eval", n_eval}, |                                                 {"n_eval", n_eval}, | ||||||
|   | |||||||
| @@ -76,7 +76,7 @@ int main(int argc, char ** argv) { | |||||||
|     while (n_cur < n_gen) { |     while (n_cur < n_gen) { | ||||||
|         // evaluate the transformer |         // evaluate the transformer | ||||||
|  |  | ||||||
|         if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) { |         if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) { | ||||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); |             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { | |||||||
|     const auto t_enc_start = ggml_time_us(); |     const auto t_enc_start = ggml_time_us(); | ||||||
|  |  | ||||||
|     // eval the prompt with both models |     // eval the prompt with both models | ||||||
|     llama_eval(ctx_tgt,  inp.data(), int(inp.size() - 1), 0, params.n_threads); |     llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0), params.n_threads); | ||||||
|     llama_eval(ctx_tgt, &inp.back(),      1, inp.size() - 1, params.n_threads); |     llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0), params.n_threads); | ||||||
|     llama_eval(ctx_dft,  inp.data(),     int(inp.size()), 0, params.n_threads); |     llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0), params.n_threads); | ||||||
|  |  | ||||||
|     const auto t_enc_end = ggml_time_us(); |     const auto t_enc_end = ggml_time_us(); | ||||||
|  |  | ||||||
| @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { | |||||||
|                 LOG("out of drafted tokens\n"); |                 LOG("out of drafted tokens\n"); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); |             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); | ||||||
|             ++n_past_dft; |             ++n_past_dft; | ||||||
|  |  | ||||||
|             // heuristic for n_draft |             // heuristic for n_draft | ||||||
| @@ -256,7 +256,7 @@ int main(int argc, char ** argv) { | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             // evaluate the drafted token on the draft model |             // evaluate the drafted token on the draft model | ||||||
|             llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); |             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); | ||||||
|             ++n_past_cur; |             ++n_past_cur; | ||||||
|  |  | ||||||
|             if (grammar_dft != NULL) { |             if (grammar_dft != NULL) { | ||||||
| @@ -265,7 +265,7 @@ int main(int argc, char ** argv) { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // evaluate the target model on the drafted tokens |         // evaluate the target model on the drafted tokens | ||||||
|         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); |         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); | ||||||
|         ++n_past_tgt; |         ++n_past_tgt; | ||||||
|  |  | ||||||
|         // the first token is always proposed by the target model before the speculation loop |         // the first token is always proposed by the target model before the speculation loop | ||||||
|   | |||||||
							
								
								
									
llama.cpp (119 lines changed)
							| @@ -1265,7 +1265,7 @@ static bool llama_kv_cache_init( | |||||||
| // updates the cache head | // updates the cache head | ||||||
| static bool llama_kv_cache_find_slot( | static bool llama_kv_cache_find_slot( | ||||||
|              struct llama_kv_cache & cache, |              struct llama_kv_cache & cache, | ||||||
|                 struct llama_batch & batch) { |           const struct llama_batch & batch) { | ||||||
|     const uint32_t n_ctx    = cache.size; |     const uint32_t n_ctx    = cache.size; | ||||||
|     const uint32_t n_tokens = batch.n_tokens; |     const uint32_t n_tokens = batch.n_tokens; | ||||||
|  |  | ||||||
| @@ -2522,7 +2522,7 @@ static bool llama_model_load( | |||||||
|  |  | ||||||
| static struct ggml_cgraph * llm_build_llama( | static struct ggml_cgraph * llm_build_llama( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch) { |      const llama_batch & batch) { | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
| @@ -2876,7 +2876,7 @@ static struct ggml_cgraph * llm_build_llama( | |||||||
|  |  | ||||||
| static struct ggml_cgraph * llm_build_baichaun( | static struct ggml_cgraph * llm_build_baichaun( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch) { |      const llama_batch & batch) { | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
| @@ -3247,7 +3247,7 @@ static struct ggml_cgraph * llm_build_baichaun( | |||||||
|  |  | ||||||
| static struct ggml_cgraph * llm_build_falcon( | static struct ggml_cgraph * llm_build_falcon( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch) { |      const llama_batch & batch) { | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
| @@ -3577,7 +3577,7 @@ static struct ggml_cgraph * llm_build_falcon( | |||||||
|  |  | ||||||
| static struct ggml_cgraph * llm_build_starcoder( | static struct ggml_cgraph * llm_build_starcoder( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch) { |      const llama_batch & batch) { | ||||||
|     const auto & model   = lctx.model; |     const auto & model   = lctx.model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
| @@ -3819,7 +3819,7 @@ static struct ggml_cgraph * llm_build_starcoder( | |||||||
|  |  | ||||||
| static struct ggml_cgraph * llama_build_graph( | static struct ggml_cgraph * llama_build_graph( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch) { |      const llama_batch & batch) { | ||||||
|     const auto & model = lctx.model; |     const auto & model = lctx.model; | ||||||
|  |  | ||||||
|     struct ggml_cgraph * result = NULL; |     struct ggml_cgraph * result = NULL; | ||||||
| @@ -3856,7 +3856,7 @@ static struct ggml_cgraph * llama_build_graph( | |||||||
| // | // | ||||||
| static bool llama_eval_internal( | static bool llama_eval_internal( | ||||||
|          llama_context & lctx, |          llama_context & lctx, | ||||||
|            llama_batch & batch, |            llama_batch   batch, | ||||||
|                    int   n_threads) { |                    int   n_threads) { | ||||||
|     const uint32_t n_tokens = batch.n_tokens; |     const uint32_t n_tokens = batch.n_tokens; | ||||||
|  |  | ||||||
| @@ -3886,6 +3886,31 @@ static bool llama_eval_internal( | |||||||
|     const int64_t n_embd  = hparams.n_embd; |     const int64_t n_embd  = hparams.n_embd; | ||||||
|     const int64_t n_vocab = hparams.n_vocab; |     const int64_t n_vocab = hparams.n_vocab; | ||||||
|  |  | ||||||
|  |     std::vector<llama_pos>    pos; | ||||||
|  |     std::vector<llama_seq_id> seq_id; | ||||||
|  |  | ||||||
|  |     if (batch.pos == nullptr) { | ||||||
|  |         pos.resize(n_tokens); | ||||||
|  |         for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|  |             pos[i] = batch.all_pos_0 + i*batch.all_pos_1; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         batch.pos = pos.data(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (batch.seq_id == nullptr) { | ||||||
|  |         seq_id.resize(n_tokens); | ||||||
|  |         for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|  |             seq_id[i] = batch.all_seq_id; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         batch.seq_id = seq_id.data(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (batch.clear_kv) { | ||||||
|  |         llama_kv_cache_clear(kv_self, 0, -1); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     if (!llama_kv_cache_find_slot(kv_self, batch)) { |     if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
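The block above is what makes the transition helpers work: when batch.pos or batch.seq_id is null, llama_eval_internal() expands them from the all_* fields before looking for a KV cache slot. A small worked example of that expansion (values illustrative):

    // batch = llama_batch_get_one(tokens, /*n_tokens=*/4, /*pos_0=*/10, /*seq_id=*/0)
    // pos      -> { 10, 11, 12, 13 }   // all_pos_0 + i*all_pos_1, with all_pos_1 == 1
    // seq_id   -> {  0,  0,  0,  0 }   // all_seq_id for every token
    // clear_kv == false                // the helper sets it only when pos_0 == 0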
| @@ -4820,6 +4845,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) | |||||||
| // sampling | // sampling | ||||||
| // | // | ||||||
|  |  | ||||||
|  | void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { | ||||||
|  |     if (seed == LLAMA_DEFAULT_SEED) { | ||||||
|  |         seed = time(NULL); | ||||||
|  |     } | ||||||
|  |     ctx->rng.seed(seed); | ||||||
|  | } | ||||||
|  |  | ||||||
| void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { | void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { | ||||||
|     GGML_ASSERT(candidates->size > 0); |     GGML_ASSERT(candidates->size > 0); | ||||||
|  |  | ||||||
| @@ -5469,7 +5501,7 @@ struct llama_beam_search_data { | |||||||
|         } else { |         } else { | ||||||
|             // beam is not at end-of-sentence, so branch with next top_k tokens. |             // beam is not at end-of-sentence, so branch with next top_k tokens. | ||||||
|             if (!beam.tokens.empty()) { |             if (!beam.tokens.empty()) { | ||||||
|                 llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); |                 llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); | ||||||
|             } |             } | ||||||
|             llama_logit_info logit_info(ctx); |             llama_logit_info logit_info(ctx); | ||||||
|             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams); |             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams); | ||||||
| @@ -5543,7 +5575,7 @@ struct llama_beam_search_data { | |||||||
|             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length |             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length | ||||||
|             update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed. |             update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed. | ||||||
|             if (common_prefix_length) { |             if (common_prefix_length) { | ||||||
|                 llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); |                 llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); | ||||||
|                 n_past += common_prefix_length; |                 n_past += common_prefix_length; | ||||||
|             } |             } | ||||||
|             // Zero-out next_beam probabilities to place them last in following min-heap. |             // Zero-out next_beam probabilities to place them last in following min-heap. | ||||||
| @@ -6505,8 +6537,7 @@ struct llama_context * llama_new_context_with_model( | |||||||
|             // build worst-case graph |             // build worst-case graph | ||||||
|             uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch); |             uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch); | ||||||
|             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph |             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph | ||||||
|             llama_batch batch = { n_tokens, &token, nullptr, nullptr, nullptr }; |             ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0)); | ||||||
|             ggml_cgraph * gf = llama_build_graph(*ctx, batch); |  | ||||||
|  |  | ||||||
| #ifdef GGML_USE_METAL | #ifdef GGML_USE_METAL | ||||||
|             if (params.n_gpu_layers > 0) { |             if (params.n_gpu_layers > 0) { | ||||||
| @@ -6714,15 +6745,6 @@ void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) { | |||||||
|     llama_kv_cache_clear(ctx->kv_self, p0, p1); |     llama_kv_cache_clear(ctx->kv_self, p0, p1); | ||||||
| } | } | ||||||
|  |  | ||||||
| #define LLAMA_MAX_RNG_STATE (64*1024) |  | ||||||
|  |  | ||||||
| void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { |  | ||||||
|     if (seed == LLAMA_DEFAULT_SEED) { |  | ||||||
|         seed = time(NULL); |  | ||||||
|     } |  | ||||||
|     ctx->rng.seed(seed); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Returns the *maximum* size of the state | // Returns the *maximum* size of the state | ||||||
| size_t llama_get_state_size(const struct llama_context * ctx) { | size_t llama_get_state_size(const struct llama_context * ctx) { | ||||||
|     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. |     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. | ||||||
| @@ -7116,21 +7138,9 @@ int llama_eval( | |||||||
|                     uint32_t   n_tokens, |                     uint32_t   n_tokens, | ||||||
|                          int   n_past, |                          int   n_past, | ||||||
|                          int   n_threads) { |                          int   n_threads) { | ||||||
|     std::vector<llama_pos> pos(n_tokens); |  | ||||||
|     for (uint32_t i = 0; i < n_tokens; i++) { |  | ||||||
|         pos[i] = n_past + i; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     std::vector<llama_seq_id> seq_id(n_tokens); |  | ||||||
|     for (uint32_t i = 0; i < n_tokens; i++) { |  | ||||||
|         seq_id[i] = 0; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     llama_batch batch = { n_tokens, tokens, nullptr, pos.data(), seq_id.data(), }; |  | ||||||
|  |  | ||||||
|     llama_kv_cache_clear(ctx->kv_self, n_past, -1); |     llama_kv_cache_clear(ctx->kv_self, n_past, -1); | ||||||
|  |  | ||||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { |     if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) { | ||||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); |         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| @@ -7151,18 +7161,47 @@ int llama_eval_embd( | |||||||
|                         uint32_t   n_tokens, |                         uint32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|                              int   n_threads) { |                              int   n_threads) { | ||||||
|     std::vector<llama_pos> pos(n_tokens); |     llama_kv_cache_clear(ctx->kv_self, n_past, -1); | ||||||
|     for (uint32_t i = 0; i < n_tokens; i++) { |  | ||||||
|         pos[i] = n_past + i; |     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, }; | ||||||
|  |  | ||||||
|  |     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||||
|  |         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||||
|  |         return 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     std::vector<llama_seq_id> seq_id(n_tokens); |     // get a more accurate load time, upon first eval | ||||||
|     for (uint32_t i = 0; i < n_tokens; i++) { |     // TODO: fix this | ||||||
|         seq_id[i] = 0; |     if (!ctx->has_evaluated_once) { | ||||||
|  |         ctx->t_load_us = ggml_time_us() - ctx->t_start_us; | ||||||
|  |         ctx->has_evaluated_once = true; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch batch = { n_tokens, nullptr, embd, pos.data(), seq_id.data(), }; |     return 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct llama_batch llama_batch_get_one( | ||||||
|  |        const llama_token * tokens, | ||||||
|  |                 uint32_t   n_tokens, | ||||||
|  |                llama_pos   pos_0, | ||||||
|  |             llama_seq_id   seq_id) { | ||||||
|  |     return { | ||||||
|  |         /*n_tokens    =*/ n_tokens, | ||||||
|  |         /*tokens      =*/ tokens, | ||||||
|  |         /*embd        =*/ nullptr, | ||||||
|  |         /*pos         =*/ nullptr, | ||||||
|  |         /*seq_id      =*/ nullptr, | ||||||
|  |         /*all_pos_0   =*/ pos_0, | ||||||
|  |         /*all_pos_1   =*/ 1, | ||||||
|  |         /*all_seq_id  =*/ seq_id, | ||||||
|  |         /*clear_kv    =*/ pos_0 == 0, | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int llama_decode( | ||||||
|  |         struct llama_context * ctx, | ||||||
|  |           struct llama_batch   batch, | ||||||
|  |                          int   n_threads) { | ||||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { |     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); |         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|   | |||||||
							
								
								
									
llama.h (43 lines changed)
							| @@ -37,6 +37,8 @@ | |||||||
|  |  | ||||||
| #define LLAMA_DEFAULT_SEED 0xFFFFFFFF | #define LLAMA_DEFAULT_SEED 0xFFFFFFFF | ||||||
|  |  | ||||||
|  | #define LLAMA_MAX_RNG_STATE (64*1024) | ||||||
|  |  | ||||||
| #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' | #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' | ||||||
|  |  | ||||||
| #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN | #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN | ||||||
| @@ -73,6 +75,17 @@ extern "C" { | |||||||
|         const float        * embd; |         const float        * embd; | ||||||
|         const llama_pos    * pos; |         const llama_pos    * pos; | ||||||
|         const llama_seq_id * seq_id; |         const llama_seq_id * seq_id; | ||||||
|  |  | ||||||
|  |         // NOTE: helpers for smooth API transition - can be deprecated in the future | ||||||
|  |         //       for future-proof code, use the above fields instead and ignore everything below | ||||||
|  |         // | ||||||
|  |         // pos[i] = all_pos_0 + i*all_pos_1 | ||||||
|  |         // | ||||||
|  |         llama_pos    all_pos_0;  // used if pos == NULL | ||||||
|  |         llama_pos    all_pos_1;  // used if pos == NULL | ||||||
|  |         llama_seq_id all_seq_id; // used if seq_id == NULL | ||||||
|  |  | ||||||
|  |         bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations | ||||||
|     } llama_seq; |     } llama_seq; | ||||||
|  |  | ||||||
|     enum llama_log_level { |     enum llama_log_level { | ||||||
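As the note in the struct above says, future-proof code should prefer explicit pos and seq_id arrays over the all_* transition fields. A hedged sketch of building such a batch for a single sequence (it mirrors what the removed llama_eval() wrapper used to do internally; names are illustrative):

    std::vector<llama_pos>    pos(n_tokens);
    std::vector<llama_seq_id> seq_id(n_tokens, 0);   // every token belongs to sequence 0
    for (uint32_t i = 0; i < n_tokens; i++) {
        pos[i] = n_past + i;
    }

    llama_batch batch = {
        /*n_tokens   =*/ n_tokens,
        /*tokens     =*/ tokens,         // const llama_token *
        /*embd       =*/ nullptr,
        /*pos        =*/ pos.data(),
        /*seq_id     =*/ seq_id.data(),
        /*all_pos_0  =*/ 0,              // ignored when pos is set
        /*all_pos_1  =*/ 0,              // ignored when pos is set
        /*all_seq_id =*/ 0,              // ignored when seq_id is set
        /*clear_kv   =*/ false,
    };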
| @@ -312,9 +325,6 @@ extern "C" { | |||||||
|  |  | ||||||
|     LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1); |     LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1); | ||||||
|  |  | ||||||
|     // Sets the current rng seed. |  | ||||||
|     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); |  | ||||||
|  |  | ||||||
|     // Returns the maximum size in bytes of the state (rng, logits, embedding |     // Returns the maximum size in bytes of the state (rng, logits, embedding | ||||||
|     // and kv_cache) - will often be smaller after compacting tokens |     // and kv_cache) - will often be smaller after compacting tokens | ||||||
|     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); |     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); | ||||||
| @@ -336,19 +346,37 @@ extern "C" { | |||||||
|     // tokens + n_tokens is the provided batch of new tokens to process |     // tokens + n_tokens is the provided batch of new tokens to process | ||||||
|     // n_past is the number of tokens to use from previous eval calls |     // n_past is the number of tokens to use from previous eval calls | ||||||
|     // Returns 0 on success |     // Returns 0 on success | ||||||
|     LLAMA_API int llama_eval( |     LLAMA_API DEPRECATED(int llama_eval( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                const llama_token * tokens, |                const llama_token * tokens, | ||||||
|                         uint32_t   n_tokens, |                         uint32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|                              int   n_threads); |                              int   n_threads), | ||||||
|  |             "please use llama_decode() instead"); | ||||||
|  |  | ||||||
|     // Same as llama_eval, but use float matrix input directly. |     // Same as llama_eval, but use float matrix input directly. | ||||||
|     LLAMA_API int llama_eval_embd( |     LLAMA_API DEPRECATED(int llama_eval_embd( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                      const float * embd, |                      const float * embd, | ||||||
|                         uint32_t   n_tokens, |                         uint32_t   n_tokens, | ||||||
|                              int   n_past, |                              int   n_past, | ||||||
|  |                              int   n_threads), | ||||||
|  |             "please use llama_decode() instead"); | ||||||
|  |  | ||||||
|  |     // Return batch for single sequence of tokens starting at pos_0 | ||||||
|  |     // If pos_0 == 0, the clear_kv flag will be auto set to true | ||||||
|  |     // | ||||||
|  |     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it | ||||||
|  |     // | ||||||
|  |     LLAMA_API struct llama_batch llama_batch_get_one( | ||||||
|  |             const llama_token * tokens, | ||||||
|  |                      uint32_t   n_tokens, | ||||||
|  |                     llama_pos   pos_0, | ||||||
|  |                  llama_seq_id   seq_id); | ||||||
|  |  | ||||||
|  |     LLAMA_API int llama_decode( | ||||||
|  |             struct llama_context * ctx, | ||||||
|  |               struct llama_batch   batch, | ||||||
|                              int   n_threads); |                              int   n_threads); | ||||||
|  |  | ||||||
|     // Token logits obtained from the last call to llama_eval() |     // Token logits obtained from the last call to llama_eval() | ||||||
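Putting the two new entry points together, a minimal single-sequence decode loop looks roughly like the sketch below (greedy argmax sampling only, error handling trimmed; requires <algorithm>, and n_prompt_tokens, n_predict and n_threads are assumed to be set by the caller):

    // evaluate the prompt in one batch starting at position 0 (clear_kv is auto-set to true)
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, 0, 0), n_threads)) {
        return 1;
    }
    int n_past = n_prompt_tokens;

    for (int i = 0; i < n_predict; i++) {
        // pick the most likely token from the logits of the last decoded token
        const int     n_vocab = llama_n_vocab(ctx);
        const float * logits  = llama_get_logits(ctx);
        llama_token next = std::max_element(logits, logits + n_vocab) - logits;

        // feed it back as a batch of one token at the current position
        if (llama_decode(ctx, llama_batch_get_one(&next, 1, n_past, 0), n_threads)) {
            return 1;
        }
        n_past += 1;
    }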
| @@ -434,6 +462,9 @@ extern "C" { | |||||||
|     // Sampling functions |     // Sampling functions | ||||||
|     // |     // | ||||||
|  |  | ||||||
|  |     // Sets the current rng seed. | ||||||
|  |     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); | ||||||
|  |  | ||||||
|     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. |     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | ||||||
|     LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); |     LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); | ||||||
|  |  | ||||||
|   | |||||||