	llama : add new llama_decode() API that works with llama_batch
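Every call site in the tree migrates the same way: the old llama_eval(ctx, tokens, n_tokens, n_past, n_threads) becomes llama_decode() on a llama_batch built by the transitional llama_batch_get_one() helper, as the hunks below show. A minimal sketch of the resulting call pattern, assuming an already-created context and a tokenized prompt (the eval_prompt helper name is illustrative, not part of this commit):

    #include "llama.h"

    #include <algorithm>
    #include <vector>

    // Illustrative helper: feed a tokenized prompt through the new API in
    // n_batch-sized chunks. llama_batch_get_one() wraps the token span in a
    // single-sequence batch (seq_id 0) whose positions start at n_past;
    // llama_decode() returns 0 on success.
    static bool eval_prompt(llama_context * ctx, const std::vector<llama_token> & tokens,
                            int & n_past, int n_batch, int n_threads) {
        for (int i = 0; i < (int) tokens.size(); i += n_batch) {
            const int n_eval = std::min((int) tokens.size() - i, n_batch);
            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i, n_eval, n_past, 0), n_threads)) {
                return false; // non-zero return means the eval failed
            }
            n_past += n_eval;
        }
        return true;
    }
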
		| @@ -780,7 +780,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par | ||||
|         LOG("warming up the model with an empty run\n"); | ||||
|  | ||||
|         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; | ||||
|         llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); | ||||
|         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); | ||||
|         llama_reset_timings(lctx); | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -160,7 +160,7 @@ int main(int argc, char ** argv) | ||||
|  | ||||
|     int n_past = 0; | ||||
|  | ||||
|     if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) | ||||
|     if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads)) | ||||
|     { | ||||
|         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); | ||||
|         return 1; | ||||
|   | ||||
| @@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){ | ||||
|         if (n_eval > n_batch) { | ||||
|             n_eval = n_batch; | ||||
|         } | ||||
|         if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { | ||||
|         llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false }; | ||||
|         if (llama_decode(ctx, batch, params.n_threads)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return false; | ||||
|         } | ||||
| @@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) { | ||||
|         if (n_eval > params.n_batch) { | ||||
|             n_eval = params.n_batch; | ||||
|         } | ||||
|         if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { | ||||
|         if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return false; | ||||
|         } | ||||
|   | ||||
| @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     while (!embd_inp.empty()) { | ||||
|         int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); | ||||
|         if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) { | ||||
|         if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
|   | ||||
| @@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat | ||||
|     int n_processed = 0; | ||||
|     while (n_processed < n_prompt) { | ||||
|         int n_tokens = std::min(n_prompt - n_processed, n_batch); | ||||
|         llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); | ||||
|         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads); | ||||
|         n_processed += n_tokens; | ||||
|     } | ||||
| } | ||||
| @@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat | ||||
| static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { | ||||
|     llama_token token = llama_token_bos(ctx); | ||||
|     for (int i = 0; i < n_gen; i++) { | ||||
|         llama_eval(ctx, &token, 1, n_past + i, n_threads); | ||||
|         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads); | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -571,7 +571,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                 for (int i = 0; i < input_size; i += params.n_batch) { | ||||
|                     int n_eval = std::min(input_size - i, params.n_batch); | ||||
|                     if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { | ||||
|                     if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) { | ||||
|                         LOG_TEE("%s : failed to eval\n", __func__); | ||||
|                         return 1; | ||||
|                     } | ||||
| @@ -588,7 +588,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); | ||||
|  | ||||
|                 if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { | ||||
|                 if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) { | ||||
|                     LOG_TEE("%s : failed to eval\n", __func__); | ||||
|                     return 1; | ||||
|                 } | ||||
|   | ||||
| @@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & | ||||
|             const int batch_size  = std::min(end - batch_start, n_batch); | ||||
|  | ||||
|             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); | ||||
|             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { | ||||
|             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { | ||||
|                 //fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|                 return {tokens, -1, logit_history, prob_history}; | ||||
|             } | ||||
| @@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par | ||||
|                 tokens[batch_start] = llama_token_bos(ctx); | ||||
|             } | ||||
|  | ||||
|             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { | ||||
|             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { | ||||
|                 fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|                 return {tokens, -1, logit_history, prob_history}; | ||||
|             } | ||||
| @@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens( | ||||
|     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { | ||||
|         size_t n_tokens = tokens.size() - i_chunk * n_batch; | ||||
|         n_tokens = std::min(n_tokens, size_t(n_batch)); | ||||
|         if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { | ||||
|         if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return {}; | ||||
|         } | ||||
|   | ||||
| @@ -34,11 +34,11 @@ int main(int argc, char ** argv) { | ||||
|     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); | ||||
|  | ||||
|     // init | ||||
|     auto model = llama_load_model_from_file(params.model.c_str(), lparams); | ||||
|     auto * model = llama_load_model_from_file(params.model.c_str(), lparams); | ||||
|     if (model == nullptr) { | ||||
|         return 1; | ||||
|     } | ||||
|     auto ctx = llama_new_context_with_model(model, lparams); | ||||
|     auto * ctx = llama_new_context_with_model(model, lparams); | ||||
|     if (ctx == nullptr) { | ||||
|         llama_free_model(model); | ||||
|         return 1; | ||||
| @@ -53,7 +53,7 @@ int main(int argc, char ** argv) { | ||||
|     } | ||||
|  | ||||
|     // evaluate prompt | ||||
|     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); | ||||
|     llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); | ||||
|  | ||||
|     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); | ||||
|     n_past += n_prompt_tokens; | ||||
| @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { | ||||
|     printf("\n%s", params.prompt.c_str()); | ||||
|  | ||||
|     for (auto i = 0; i < params.n_predict; i++) { | ||||
|         auto logits = llama_get_logits(ctx); | ||||
|         auto * logits = llama_get_logits(ctx); | ||||
|         auto n_vocab = llama_n_vocab(ctx); | ||||
|         std::vector<llama_token_data> candidates; | ||||
|         candidates.reserve(n_vocab); | ||||
| @@ -90,7 +90,7 @@ int main(int argc, char ** argv) { | ||||
|         last_n_tokens_data.push_back(next_token); | ||||
|  | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { | ||||
|         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_free(ctx); | ||||
|             llama_free_model(model); | ||||
| @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { | ||||
|     llama_free(ctx); | ||||
|  | ||||
|     // make new context | ||||
|     auto ctx2 = llama_new_context_with_model(model, lparams); | ||||
|     auto * ctx2 = llama_new_context_with_model(model, lparams); | ||||
|  | ||||
|     // Load state (rng, logits, embedding and kv_cache) from file | ||||
|     { | ||||
| @@ -137,7 +137,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // second run | ||||
|     for (auto i = 0; i < params.n_predict; i++) { | ||||
|         auto logits = llama_get_logits(ctx2); | ||||
|         auto * logits = llama_get_logits(ctx2); | ||||
|         auto n_vocab = llama_n_vocab(ctx2); | ||||
|         std::vector<llama_token_data> candidates; | ||||
|         candidates.reserve(n_vocab); | ||||
| @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { | ||||
|         last_n_tokens_data.push_back(next_token); | ||||
|  | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { | ||||
|         if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_free(ctx2); | ||||
|             llama_free_model(model); | ||||
|   | ||||
| @@ -434,7 +434,7 @@ struct llama_server_context | ||||
|             { | ||||
|                 n_eval = params.n_batch; | ||||
|             } | ||||
|             if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) | ||||
|             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) | ||||
|             { | ||||
|                 LOG_ERROR("failed to eval", { | ||||
|                                                 {"n_eval", n_eval}, | ||||
|   | ||||
| @@ -76,7 +76,7 @@ int main(int argc, char ** argv) { | ||||
|     while (n_cur < n_gen) { | ||||
|         // evaluate the transformer | ||||
|  | ||||
|         if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) { | ||||
|         if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) { | ||||
|             fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
|   | ||||
| @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { | ||||
|     const auto t_enc_start = ggml_time_us(); | ||||
|  | ||||
|     // eval the prompt with both models | ||||
|     llama_eval(ctx_tgt,  inp.data(), int(inp.size() - 1), 0, params.n_threads); | ||||
|     llama_eval(ctx_tgt, &inp.back(),      1, inp.size() - 1, params.n_threads); | ||||
|     llama_eval(ctx_dft,  inp.data(),     int(inp.size()), 0, params.n_threads); | ||||
|     llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0), params.n_threads); | ||||
|     llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0), params.n_threads); | ||||
|     llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0), params.n_threads); | ||||
|  | ||||
|     const auto t_enc_end = ggml_time_us(); | ||||
|  | ||||
| @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { | ||||
|                 LOG("out of drafted tokens\n"); | ||||
|             } | ||||
|  | ||||
|             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); | ||||
|             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); | ||||
|             ++n_past_dft; | ||||
|  | ||||
|             // heuristic for n_draft | ||||
| @@ -256,7 +256,7 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|  | ||||
|             // evaluate the drafted token on the draft model | ||||
|             llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); | ||||
|             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); | ||||
|             ++n_past_cur; | ||||
|  | ||||
|             if (grammar_dft != NULL) { | ||||
| @@ -265,7 +265,7 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|  | ||||
|         // evaluate the target model on the drafted tokens | ||||
|         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); | ||||
|         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); | ||||
|         ++n_past_tgt; | ||||
|  | ||||
|         // the first token is always proposed by the target model before the speculation loop | ||||
|   | ||||

llama.cpp

							| @@ -1265,7 +1265,7 @@ static bool llama_kv_cache_init( | ||||
| // updates the cache head | ||||
| static bool llama_kv_cache_find_slot( | ||||
|              struct llama_kv_cache & cache, | ||||
|                 struct llama_batch & batch) { | ||||
|           const struct llama_batch & batch) { | ||||
|     const uint32_t n_ctx    = cache.size; | ||||
|     const uint32_t n_tokens = batch.n_tokens; | ||||
|  | ||||
| @@ -2522,7 +2522,7 @@ static bool llama_model_load( | ||||
|  | ||||
| static struct ggml_cgraph * llm_build_llama( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch) { | ||||
|      const llama_batch & batch) { | ||||
|     const auto & model   = lctx.model; | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
| @@ -2876,7 +2876,7 @@ static struct ggml_cgraph * llm_build_llama( | ||||
|  | ||||
| static struct ggml_cgraph * llm_build_baichaun( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch) { | ||||
|      const llama_batch & batch) { | ||||
|     const auto & model   = lctx.model; | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
| @@ -3247,7 +3247,7 @@ static struct ggml_cgraph * llm_build_baichaun( | ||||
|  | ||||
| static struct ggml_cgraph * llm_build_falcon( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch) { | ||||
|      const llama_batch & batch) { | ||||
|     const auto & model   = lctx.model; | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
| @@ -3577,7 +3577,7 @@ static struct ggml_cgraph * llm_build_falcon( | ||||
|  | ||||
| static struct ggml_cgraph * llm_build_starcoder( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch) { | ||||
|      const llama_batch & batch) { | ||||
|     const auto & model   = lctx.model; | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
| @@ -3819,7 +3819,7 @@ static struct ggml_cgraph * llm_build_starcoder( | ||||
|  | ||||
| static struct ggml_cgraph * llama_build_graph( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch) { | ||||
|      const llama_batch & batch) { | ||||
|     const auto & model = lctx.model; | ||||
|  | ||||
|     struct ggml_cgraph * result = NULL; | ||||
| @@ -3856,7 +3856,7 @@ static struct ggml_cgraph * llama_build_graph( | ||||
| // | ||||
| static bool llama_eval_internal( | ||||
|          llama_context & lctx, | ||||
|            llama_batch & batch, | ||||
|            llama_batch   batch, | ||||
|                    int   n_threads) { | ||||
|     const uint32_t n_tokens = batch.n_tokens; | ||||
|  | ||||
| @@ -3886,6 +3886,31 @@ static bool llama_eval_internal( | ||||
|     const int64_t n_embd  = hparams.n_embd; | ||||
|     const int64_t n_vocab = hparams.n_vocab; | ||||
|  | ||||
|     std::vector<llama_pos>    pos; | ||||
|     std::vector<llama_seq_id> seq_id; | ||||
|  | ||||
|     if (batch.pos == nullptr) { | ||||
|         pos.resize(n_tokens); | ||||
|         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|             pos[i] = batch.all_pos_0 + i*batch.all_pos_1; | ||||
|         } | ||||
|  | ||||
|         batch.pos = pos.data(); | ||||
|     } | ||||
|  | ||||
|     if (batch.seq_id == nullptr) { | ||||
|         seq_id.resize(n_tokens); | ||||
|         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|             seq_id[i] = batch.all_seq_id; | ||||
|         } | ||||
|  | ||||
|         batch.seq_id = seq_id.data(); | ||||
|     } | ||||
|  | ||||
|     if (batch.clear_kv) { | ||||
|         llama_kv_cache_clear(kv_self, 0, -1); | ||||
|     } | ||||
|  | ||||
|     if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||
|         return false; | ||||
|     } | ||||
| @@ -4820,6 +4845,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) | ||||
| // sampling | ||||
| // | ||||
|  | ||||
| void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { | ||||
|     if (seed == LLAMA_DEFAULT_SEED) { | ||||
|         seed = time(NULL); | ||||
|     } | ||||
|     ctx->rng.seed(seed); | ||||
| } | ||||
|  | ||||
| void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { | ||||
|     GGML_ASSERT(candidates->size > 0); | ||||
|  | ||||
| @@ -5469,7 +5501,7 @@ struct llama_beam_search_data { | ||||
|         } else { | ||||
|             // beam is not at end-of-sentence, so branch with next top_k tokens. | ||||
|             if (!beam.tokens.empty()) { | ||||
|                 llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); | ||||
|                 llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); | ||||
|             } | ||||
|             llama_logit_info logit_info(ctx); | ||||
|             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams); | ||||
| @@ -5543,7 +5575,7 @@ struct llama_beam_search_data { | ||||
|             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length | ||||
|             update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed. | ||||
|             if (common_prefix_length) { | ||||
|                 llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); | ||||
|                 llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); | ||||
|                 n_past += common_prefix_length; | ||||
|             } | ||||
|             // Zero-out next_beam probabilities to place them last in following min-heap. | ||||
| @@ -6505,8 +6537,7 @@ struct llama_context * llama_new_context_with_model( | ||||
|             // build worst-case graph | ||||
|             uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch); | ||||
|             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph | ||||
|             llama_batch batch = { n_tokens, &token, nullptr, nullptr, nullptr }; | ||||
|             ggml_cgraph * gf = llama_build_graph(*ctx, batch); | ||||
|             ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0)); | ||||
|  | ||||
| #ifdef GGML_USE_METAL | ||||
|             if (params.n_gpu_layers > 0) { | ||||
| @@ -6714,15 +6745,6 @@ void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) { | ||||
|     llama_kv_cache_clear(ctx->kv_self, p0, p1); | ||||
| } | ||||
|  | ||||
| #define LLAMA_MAX_RNG_STATE (64*1024) | ||||
|  | ||||
| void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { | ||||
|     if (seed == LLAMA_DEFAULT_SEED) { | ||||
|         seed = time(NULL); | ||||
|     } | ||||
|     ctx->rng.seed(seed); | ||||
| } | ||||
|  | ||||
| // Returns the *maximum* size of the state | ||||
| size_t llama_get_state_size(const struct llama_context * ctx) { | ||||
|     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. | ||||
| @@ -7116,21 +7138,9 @@ int llama_eval( | ||||
|                     uint32_t   n_tokens, | ||||
|                          int   n_past, | ||||
|                          int   n_threads) { | ||||
|     std::vector<llama_pos> pos(n_tokens); | ||||
|     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         pos[i] = n_past + i; | ||||
|     } | ||||
|  | ||||
|     std::vector<llama_seq_id> seq_id(n_tokens); | ||||
|     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         seq_id[i] = 0; | ||||
|     } | ||||
|  | ||||
|     llama_batch batch = { n_tokens, tokens, nullptr, pos.data(), seq_id.data(), }; | ||||
|  | ||||
|     llama_kv_cache_clear(ctx->kv_self, n_past, -1); | ||||
|  | ||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||
|     if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -7151,18 +7161,47 @@ int llama_eval_embd( | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads) { | ||||
|     std::vector<llama_pos> pos(n_tokens); | ||||
|     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         pos[i] = n_past + i; | ||||
|     llama_kv_cache_clear(ctx->kv_self, n_past, -1); | ||||
|  | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, }; | ||||
|  | ||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     std::vector<llama_seq_id> seq_id(n_tokens); | ||||
|     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         seq_id[i] = 0; | ||||
|     // get a more accurate load time, upon first eval | ||||
|     // TODO: fix this | ||||
|     if (!ctx->has_evaluated_once) { | ||||
|         ctx->t_load_us = ggml_time_us() - ctx->t_start_us; | ||||
|         ctx->has_evaluated_once = true; | ||||
|     } | ||||
|  | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, pos.data(), seq_id.data(), }; | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| struct llama_batch llama_batch_get_one( | ||||
|        const llama_token * tokens, | ||||
|                 uint32_t   n_tokens, | ||||
|                llama_pos   pos_0, | ||||
|             llama_seq_id   seq_id) { | ||||
|     return { | ||||
|         /*n_tokens    =*/ n_tokens, | ||||
|         /*tokens      =*/ tokens, | ||||
|         /*embd        =*/ nullptr, | ||||
|         /*pos         =*/ nullptr, | ||||
|         /*seq_id      =*/ nullptr, | ||||
|         /*all_pos_0   =*/ pos_0, | ||||
|         /*all_pos_1   =*/ 1, | ||||
|         /*all_seq_id  =*/ seq_id, | ||||
|         /*clear_kv    =*/ pos_0 == 0, | ||||
|     }; | ||||
| } | ||||
|  | ||||
| int llama_decode( | ||||
|         struct llama_context * ctx, | ||||
|           struct llama_batch   batch, | ||||
|                          int   n_threads) { | ||||
|     if (!llama_eval_internal(*ctx, batch, n_threads)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); | ||||
|         return 1; | ||||
|   | ||||

llama.h

							| @@ -37,6 +37,8 @@ | ||||
|  | ||||
| #define LLAMA_DEFAULT_SEED 0xFFFFFFFF | ||||
|  | ||||
| #define LLAMA_MAX_RNG_STATE (64*1024) | ||||
|  | ||||
| #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' | ||||
|  | ||||
| #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN | ||||
| @@ -70,9 +72,20 @@ extern "C" { | ||||
|  | ||||
|         // TODO: not sure about these consts - might just get in the way all the time with no benefit | ||||
|         const llama_token  * token; | ||||
|         const       float  * embd; | ||||
|         const float        * embd; | ||||
|         const llama_pos    * pos; | ||||
|         const llama_seq_id * seq_id; | ||||
|  | ||||
|         // NOTE: helpers for smooth API transition - can be deprecated in the future | ||||
|         //       for future-proof code, use the above fields instead and ignore everything below | ||||
|         // | ||||
|         // pos[i] = all_pos_0 + i*all_pos_1 | ||||
|         // | ||||
|         llama_pos    all_pos_0;  // used if pos == NULL | ||||
|         llama_pos    all_pos_1;  // used if pos == NULL | ||||
|         llama_seq_id all_seq_id; // used if seq_id == NULL | ||||
|  | ||||
|         bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations | ||||
|     } llama_seq; | ||||
|  | ||||
|     enum llama_log_level { | ||||
| @@ -312,9 +325,6 @@ extern "C" { | ||||
|  | ||||
|     LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1); | ||||
|  | ||||
|     // Sets the current rng seed. | ||||
|     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); | ||||
|  | ||||
|     // Returns the maximum size in bytes of the state (rng, logits, embedding | ||||
|     // and kv_cache) - will often be smaller after compacting tokens | ||||
|     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); | ||||
| @@ -336,19 +346,37 @@ extern "C" { | ||||
|     // tokens + n_tokens is the provided batch of new tokens to process | ||||
|     // n_past is the number of tokens to use from previous eval calls | ||||
|     // Returns 0 on success | ||||
|     LLAMA_API int llama_eval( | ||||
|     LLAMA_API DEPRECATED(int llama_eval( | ||||
|             struct llama_context * ctx, | ||||
|                const llama_token * tokens, | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads); | ||||
|                              int   n_threads), | ||||
|             "please use llama_decode() instead"); | ||||
|  | ||||
|     // Same as llama_eval, but use float matrix input directly. | ||||
|     LLAMA_API int llama_eval_embd( | ||||
|     LLAMA_API DEPRECATED(int llama_eval_embd( | ||||
|             struct llama_context * ctx, | ||||
|                      const float * embd, | ||||
|                         uint32_t   n_tokens, | ||||
|                              int   n_past, | ||||
|                              int   n_threads), | ||||
|             "please use llama_decode() instead"); | ||||
|  | ||||
|     // Return batch for single sequence of tokens starting at pos_0 | ||||
|     // If pos_0 == 0, the clear_kv flag will be auto set to true | ||||
|     // | ||||
|     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it | ||||
|     // | ||||
|     LLAMA_API struct llama_batch llama_batch_get_one( | ||||
|             const llama_token * tokens, | ||||
|                      uint32_t   n_tokens, | ||||
|                     llama_pos   pos_0, | ||||
|                  llama_seq_id   seq_id); | ||||
|  | ||||
|     LLAMA_API int llama_decode( | ||||
|             struct llama_context * ctx, | ||||
|               struct llama_batch   batch, | ||||
|                              int   n_threads); | ||||
|  | ||||
|     // Token logits obtained from the last call to llama_eval() | ||||
| @@ -434,6 +462,9 @@ extern "C" { | ||||
|     // Sampling functions | ||||
|     // | ||||
|  | ||||
|     // Sets the current rng seed. | ||||
|     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); | ||||
|  | ||||
|     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | ||||
|     LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); | ||||
|  | ||||
|   | ||||
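Beyond the transitional helper, the new llama_batch fields let a caller supply explicit per-token positions and sequence ids instead of relying on the all_pos_0/all_pos_1/all_seq_id fallbacks. A minimal sketch of filling the struct directly, assuming the field order used by the llama_batch_get_one() initializer above (the decode_explicit helper is illustrative only):

    #include "llama.h"

    #include <cstdint>
    #include <vector>

    // Illustrative: build a batch with explicit pos[] and seq_id[] arrays for a
    // single sequence. Field order follows the llama_batch_get_one() initializer
    // in this commit; clear_kv mirrors its pos_0 == 0 behaviour.
    static int decode_explicit(llama_context * ctx, const std::vector<llama_token> & tokens,
                               llama_pos n_past, int n_threads) {
        const uint32_t n_tokens = (uint32_t) tokens.size();

        std::vector<llama_pos>    pos(n_tokens);
        std::vector<llama_seq_id> seq_id(n_tokens, 0); // everything belongs to sequence 0

        for (uint32_t i = 0; i < n_tokens; i++) {
            pos[i] = n_past + i;
        }

        llama_batch batch = {
            /*n_tokens   =*/ n_tokens,
            /*token      =*/ tokens.data(),
            /*embd       =*/ nullptr,
            /*pos        =*/ pos.data(),
            /*seq_id     =*/ seq_id.data(),
            /*all_pos_0  =*/ 0,          // unused: pos != NULL
            /*all_pos_1  =*/ 0,          // unused: pos != NULL
            /*all_seq_id =*/ 0,          // unused: seq_id != NULL
            /*clear_kv   =*/ n_past == 0,
        };

        return llama_decode(ctx, batch, n_threads); // 0 on success
    }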