mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	rework, targeting llama-server
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -98,6 +98,7 @@ examples/server/*.css.hpp | ||||
| examples/server/*.html.hpp | ||||
| examples/server/*.js.hpp | ||||
| examples/server/*.mjs.hpp | ||||
| examples/server/*.gz.hpp | ||||
| !build_64.sh | ||||
| !examples/*.bat | ||||
| !examples/*/*.kts | ||||
|   | ||||
| @@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam | ||||
|     return buf.str(); | ||||
| } | ||||
|  | ||||
| /* | ||||
| std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { | ||||
|     std::stringstream buf; | ||||
|  | ||||
| @@ -614,6 +615,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat | ||||
|  | ||||
|     return buf.str(); | ||||
| } | ||||
| */ | ||||
|  | ||||
| void string_process_escapes(std::string & input) { | ||||
|     std::size_t input_len = input.length(); | ||||
| @@ -1608,27 +1610,20 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons | ||||
| // Batch utils | ||||
| // | ||||
|  | ||||
| void common_batch_clear(struct llama_batch & batch) { | ||||
|     batch.n_tokens = 0; | ||||
| void common_batch_clear(struct llama_batch * batch) { | ||||
|     llama_batch_clear(batch); | ||||
| } | ||||
|  | ||||
| void common_batch_add( | ||||
|                  struct llama_batch & batch, | ||||
|                  struct llama_batch * batch, | ||||
|                         llama_token   id, | ||||
|                           llama_pos   pos, | ||||
|     const std::vector<llama_seq_id> & seq_ids, | ||||
|                                bool   logits) { | ||||
|     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); | ||||
|  | ||||
|     batch.token   [batch.n_tokens] = id; | ||||
|     batch.pos     [batch.n_tokens] = pos; | ||||
|     batch.n_seq_id[batch.n_tokens] = seq_ids.size(); | ||||
|     for (size_t i = 0; i < seq_ids.size(); ++i) { | ||||
|         batch.seq_id[batch.n_tokens][i] = seq_ids[i]; | ||||
|     int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits); | ||||
|     if (res == -1) { | ||||
|         LOG_ERR("%s: llama_batch size exceeded\n", __func__); | ||||
|     } | ||||
|     batch.logits  [batch.n_tokens] = logits; | ||||
|  | ||||
|     batch.n_tokens++; | ||||
| } | ||||
|  | ||||
| // | ||||
|   | ||||
| @@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap | ||||
| // Batch utils | ||||
| // | ||||
|  | ||||
| void common_batch_clear(struct llama_batch & batch); | ||||
| void common_batch_clear(struct llama_batch * batch); | ||||
|  | ||||
| void common_batch_add( | ||||
|                  struct llama_batch & batch, | ||||
|                  struct llama_batch * batch, | ||||
|                         llama_token   id, | ||||
|                           llama_pos   pos, | ||||
|     const std::vector<llama_seq_id> & seq_ids, | ||||
|   | ||||
| @@ -13,7 +13,7 @@ struct common_speculative { | ||||
|     struct llama_context * ctx; | ||||
|     struct common_sampler * smpl; | ||||
|  | ||||
|     llama_batch batch; | ||||
|     llama_batch * batch; | ||||
|     llama_tokens prompt; | ||||
| }; | ||||
|  | ||||
| @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( | ||||
|     auto * result = new common_speculative { | ||||
|         /* .ctx    = */ ctx_dft, | ||||
|         /* .smpl   = */ nullptr, | ||||
|         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), | ||||
|         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1), | ||||
|         /* .prompt = */ {}, | ||||
|     }; | ||||
|  | ||||
| @@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft( | ||||
|     } | ||||
|  | ||||
|     // we should rarely end-up here during normal decoding | ||||
|     if (batch.n_tokens > 0) { | ||||
|     if (llama_batch_get_n_tokens(batch) > 0) { | ||||
|         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); | ||||
|  | ||||
|         llama_decode(ctx, batch); | ||||
|   | ||||
| @@ -1215,7 +1215,7 @@ struct server_slot { | ||||
|     // only used for completion/embedding/infill/rerank | ||||
|     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; | ||||
|  | ||||
|     llama_batch batch_spec = {}; | ||||
|     llama_batch_ptr batch_spec; | ||||
|  | ||||
|     llama_context * ctx = nullptr; | ||||
|     llama_context * ctx_dft = nullptr; | ||||
| @@ -1787,7 +1787,7 @@ struct server_context { | ||||
|  | ||||
|     llama_context_params cparams_dft; | ||||
|  | ||||
|     llama_batch batch = {}; | ||||
|     llama_batch_ptr batch; | ||||
|  | ||||
|     bool clean_kv_cache = true; | ||||
|     bool add_bos_token  = true; | ||||
| @@ -1820,11 +1820,7 @@ struct server_context { | ||||
|  | ||||
|             common_speculative_free(slot.spec); | ||||
|             slot.spec = nullptr; | ||||
|  | ||||
|             llama_batch_free(slot.batch_spec); | ||||
|         } | ||||
|  | ||||
|         llama_batch_free(batch); | ||||
|     } | ||||
|  | ||||
|     bool load_model(const common_params & params) { | ||||
| @@ -1944,7 +1940,7 @@ struct server_context { | ||||
|             slot.n_predict = params_base.n_predict; | ||||
|  | ||||
|             if (model_dft) { | ||||
|                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); | ||||
|                 slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1)); | ||||
|  | ||||
|                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); | ||||
|                 if (slot.ctx_dft == nullptr) { | ||||
| @@ -1969,7 +1965,7 @@ struct server_context { | ||||
|  | ||||
|             slot.reset(); | ||||
|  | ||||
|             slots.push_back(slot); | ||||
|             slots.push_back(std::move(slot)); | ||||
|         } | ||||
|  | ||||
|         default_generation_settings_for_props = slots[0].to_json(); | ||||
| @@ -1980,7 +1976,7 @@ struct server_context { | ||||
|             const int32_t n_batch = llama_n_batch(ctx); | ||||
|  | ||||
|             // only a single seq_id per token is needed | ||||
|             batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); | ||||
|             batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1)); | ||||
|         } | ||||
|  | ||||
|         metrics.init(); | ||||
| @@ -2098,9 +2094,7 @@ struct server_context { | ||||
|         } | ||||
|  | ||||
|         if (slot.ctx_dft) { | ||||
|             llama_batch_free(slot.batch_spec); | ||||
|  | ||||
|             slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); | ||||
|             slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1)); | ||||
|         } | ||||
|  | ||||
|         slot.state = SLOT_STATE_STARTED; | ||||
| @@ -2408,7 +2402,7 @@ struct server_context { | ||||
|         queue_results.send(std::move(res)); | ||||
|     } | ||||
|  | ||||
|     void send_embedding(const server_slot & slot, const llama_batch & batch) { | ||||
|     void send_embedding(const server_slot & slot, llama_batch_ptr & batch) { | ||||
|         auto res = std::make_unique<server_task_result_embd>(); | ||||
|         res->id        = slot.id_task; | ||||
|         res->index     = slot.index; | ||||
| @@ -2419,18 +2413,19 @@ struct server_context { | ||||
|  | ||||
|         std::vector<float> embd_res(n_embd, 0.0f); | ||||
|  | ||||
|         for (int i = 0; i < batch.n_tokens; ++i) { | ||||
|             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { | ||||
|         for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { | ||||
|             llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); | ||||
|             if (!tok.logits || tok.seq_id[0] != slot.id) { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); | ||||
|             const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); | ||||
|             if (embd == NULL) { | ||||
|                 embd = llama_get_embeddings_ith(ctx, i); | ||||
|             } | ||||
|  | ||||
|             if (embd == NULL) { | ||||
|                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); | ||||
|                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); | ||||
|  | ||||
|                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f)); | ||||
|                 continue; | ||||
| @@ -2451,24 +2446,25 @@ struct server_context { | ||||
|         queue_results.send(std::move(res)); | ||||
|     } | ||||
|  | ||||
|     void send_rerank(const server_slot & slot, const llama_batch & batch) { | ||||
|     void send_rerank(const server_slot & slot, llama_batch_ptr & batch) { | ||||
|         auto res = std::make_unique<server_task_result_rerank>(); | ||||
|         res->id    = slot.id_task; | ||||
|         res->index = slot.index; | ||||
|         res->n_tokens = slot.n_prompt_tokens; | ||||
|  | ||||
|         for (int i = 0; i < batch.n_tokens; ++i) { | ||||
|             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { | ||||
|         for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { | ||||
|             llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); | ||||
|             if (!tok.logits || tok.seq_id[0] != slot.id) { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); | ||||
|             const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); | ||||
|             if (embd == NULL) { | ||||
|                 embd = llama_get_embeddings_ith(ctx, i); | ||||
|             } | ||||
|  | ||||
|             if (embd == NULL) { | ||||
|                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); | ||||
|                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); | ||||
|  | ||||
|                 res->score = -1e6; | ||||
|                 continue; | ||||
| @@ -2859,7 +2855,7 @@ struct server_context { | ||||
|         } | ||||
|  | ||||
|         // start populating the batch for this iteration | ||||
|         common_batch_clear(batch); | ||||
|         common_batch_clear(batch.get()); | ||||
|  | ||||
|         // track if given slot can be batched with slots already in the batch | ||||
|         server_slot * slot_batched = nullptr; | ||||
| @@ -2881,9 +2877,9 @@ struct server_context { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             slot.i_batch = batch.n_tokens; | ||||
|             slot.i_batch = llama_batch_get_n_tokens(batch.get()); | ||||
|  | ||||
|             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); | ||||
|             common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true); | ||||
|  | ||||
|             slot.n_past += 1; | ||||
|  | ||||
| @@ -2900,7 +2896,7 @@ struct server_context { | ||||
|         int32_t n_ubatch = llama_n_ubatch(ctx); | ||||
|  | ||||
|         // next, batch any pending prompts without exceeding n_batch | ||||
|         if (params_base.cont_batching || batch.n_tokens == 0) { | ||||
|         if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) { | ||||
|             for (auto & slot : slots) { | ||||
|                 // check if we can batch this slot with the previous one | ||||
|                 if (slot.is_processing()) { | ||||
| @@ -3066,7 +3062,7 @@ struct server_context { | ||||
|                     // non-causal tasks require to fit the entire prompt in the physical batch | ||||
|                     if (slot.is_non_causal()) { | ||||
|                         // cannot fit the prompt in the current batch - will try next iter | ||||
|                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { | ||||
|                         if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { | ||||
|                             continue; | ||||
|                         } | ||||
|                     } | ||||
| @@ -3086,11 +3082,11 @@ struct server_context { | ||||
|                     slot.cache_tokens.resize(slot.n_past); | ||||
|  | ||||
|                     // add prompt tokens for processing in the current batch | ||||
|                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { | ||||
|                     while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) { | ||||
|                         // without pooling, we want to output the embeddings for all the tokens in the batch | ||||
|                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; | ||||
|  | ||||
|                         common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); | ||||
|                         common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); | ||||
|  | ||||
|                         if (slot.params.cache_prompt) { | ||||
|                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); | ||||
| @@ -3100,13 +3096,13 @@ struct server_context { | ||||
|                         slot.n_past++; | ||||
|                     } | ||||
|  | ||||
|                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); | ||||
|                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); | ||||
|  | ||||
|                     // entire prompt has been processed | ||||
|                     if (slot.n_past == slot.n_prompt_tokens) { | ||||
|                         slot.state = SLOT_STATE_DONE_PROMPT; | ||||
|  | ||||
|                         GGML_ASSERT(batch.n_tokens > 0); | ||||
|                         GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0); | ||||
|  | ||||
|                         common_sampler_reset(slot.smpl); | ||||
|  | ||||
| @@ -3116,27 +3112,27 @@ struct server_context { | ||||
|                         } | ||||
|  | ||||
|                         // extract the logits only for the last token | ||||
|                         batch.logits[batch.n_tokens - 1] = true; | ||||
|                         llama_batch_set_logits_last(batch.get()); | ||||
|  | ||||
|                         slot.n_decoded = 0; | ||||
|                         slot.i_batch   = batch.n_tokens - 1; | ||||
|                         slot.i_batch   = llama_batch_get_n_tokens(batch.get()) - 1; | ||||
|  | ||||
|                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); | ||||
|                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get())); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 if (batch.n_tokens >= n_batch) { | ||||
|                 if (llama_batch_get_n_tokens(batch.get()) >= n_batch) { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (batch.n_tokens == 0) { | ||||
|         if (llama_batch_get_n_tokens(batch.get()) == 0) { | ||||
|             SRV_WRN("%s", "no tokens to decode\n"); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); | ||||
|         SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get())); | ||||
|  | ||||
|         if (slot_batched) { | ||||
|             // make sure we're in the right embedding mode | ||||
| @@ -3146,20 +3142,12 @@ struct server_context { | ||||
|         } | ||||
|  | ||||
|         // process the created batch of tokens | ||||
|         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { | ||||
|             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); | ||||
|         for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) { | ||||
|             const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i); | ||||
|  | ||||
|             llama_batch batch_view = { | ||||
|                 n_tokens, | ||||
|                 batch.token    + i, | ||||
|                 nullptr, | ||||
|                 batch.pos      + i, | ||||
|                 batch.n_seq_id + i, | ||||
|                 batch.seq_id   + i, | ||||
|                 batch.logits   + i, | ||||
|             }; | ||||
|             llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens)); | ||||
|  | ||||
|             const int ret = llama_decode(ctx, batch_view); | ||||
|             const int ret = llama_decode(ctx, batch_view.get()); | ||||
|             metrics.on_decoded(slots); | ||||
|  | ||||
|             if (ret != 0) { | ||||
| @@ -3294,16 +3282,16 @@ struct server_context { | ||||
|                 } | ||||
|  | ||||
|                 // construct the speculation batch | ||||
|                 common_batch_clear(slot.batch_spec); | ||||
|                 common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true); | ||||
|                 common_batch_clear(slot.batch_spec.get()); | ||||
|                 common_batch_add  (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true); | ||||
|  | ||||
|                 for (size_t i = 0; i < draft.size(); ++i) { | ||||
|                     common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); | ||||
|                     common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true); | ||||
|                 } | ||||
|  | ||||
|                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); | ||||
|                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get())); | ||||
|  | ||||
|                 llama_decode(ctx, slot.batch_spec); | ||||
|                 llama_decode(ctx, slot.batch_spec.get()); | ||||
|  | ||||
|                 // the accepted tokens from the speculation | ||||
|                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); | ||||
|   | ||||
| @@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter { | ||||
|     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } | ||||
| }; | ||||
|  | ||||
| struct llama_batch_deleter { | ||||
|     void operator()(llama_batch * batch) { llama_batch_free(batch); } | ||||
| }; | ||||
|  | ||||
| typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; | ||||
| typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; | ||||
| typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; | ||||
| typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr; | ||||
| typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr; | ||||
|   | ||||
| @@ -233,6 +233,14 @@ extern "C" { | ||||
|  | ||||
|     struct llama_batch; | ||||
|  | ||||
|     struct llama_batch_token_info { | ||||
|         llama_token    token; | ||||
|         llama_pos      pos; | ||||
|         int32_t        n_seq_id; | ||||
|         llama_seq_id * seq_id; | ||||
|         int8_t         logits; | ||||
|     }; | ||||
|  | ||||
|     enum llama_model_kv_override_type { | ||||
|         LLAMA_KV_OVERRIDE_TYPE_INT, | ||||
|         LLAMA_KV_OVERRIDE_TYPE_FLOAT, | ||||
| @@ -837,34 +845,44 @@ extern "C" { | ||||
|             int32_t   pos0, | ||||
|             int32_t   seq_id); | ||||
|  | ||||
|     // Get the number of tokens in the batch | ||||
|     LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch); | ||||
|  | ||||
|     LLAMA_API struct llama_batch_token_info llama_batch_get_token_info( | ||||
|             struct llama_batch * batch, | ||||
|                        int32_t   i); | ||||
|  | ||||
|     // Add text tokens to the batch | ||||
|     // First token in the list starts at position pos0 | ||||
|     // Return values: | ||||
|     //  0 : success | ||||
|     // -1 : not enough space in the batch | ||||
|     // -2 : embd is already set, cannot add text tokens | ||||
|     LLAMA_API int32_t llama_batch_add_text( | ||||
|     LLAMA_API int32_t llama_batch_add_text_token( | ||||
|             struct llama_batch * batch, | ||||
|                    llama_token * tokens, | ||||
|                        size_t    n_tokens, | ||||
|                        int32_t   pos0, | ||||
|                        int32_t   seq_id); | ||||
|  | ||||
|     // Same as llama_batch_add_text, but accepts multiple sequences | ||||
|     LLAMA_API int32_t llama_batch_add_text( | ||||
|             struct llama_batch * batch, | ||||
|                    llama_token * tokens, | ||||
|                        size_t    n_tokens, | ||||
|                        int32_t   pos0, | ||||
|                        int32_t * seq_ids, | ||||
|                        size_t    n_seq_ids); | ||||
|                    llama_token   token, | ||||
|                      llama_pos   pos, | ||||
|             const llama_seq_id * seq_ids, | ||||
|                         size_t   n_seq_ids, | ||||
|                          float   logits); | ||||
|  | ||||
|     // Set logits for the token in the ith sequence | ||||
|     // If pos == -1, logits will be set for the all tokens | ||||
|     // Returns -1 if the token is not in the batch | ||||
|     LLAMA_API int32_t llama_batch_set_logits( | ||||
|             struct llama_batch * batch, | ||||
|                        int32_t   pos, | ||||
|                        int32_t   seq_id); | ||||
|                      llama_pos   pos, | ||||
|                   llama_seq_id   seq_id); | ||||
|  | ||||
|     // Set logits for the last added token | ||||
|     // Returns -1 if there is no tokens in the batch | ||||
|     LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch); | ||||
|  | ||||
|     // Get a "view" from a number of tokens offset | ||||
|     // Return returned batch must be freed with llama_batch_free() | ||||
|     LLAMA_API struct llama_batch * llama_batch_get_view( | ||||
|             struct llama_batch * batch, | ||||
|                        int32_t   offset, | ||||
|                        int32_t   n_tokens); | ||||
|  | ||||
|     // Remove everything from the batch | ||||
|     LLAMA_API void llama_batch_clear(struct llama_batch * batch); | ||||
| @@ -878,7 +896,7 @@ extern "C" { | ||||
|     // < 0 - error. the KV cache state is restored to the state before this call | ||||
|     LLAMA_API int32_t llama_encode( | ||||
|             struct llama_context * ctx, | ||||
|               struct llama_batch   batch); | ||||
|               struct llama_batch * batch); | ||||
|  | ||||
|     // Positive return values does not mean a fatal error, but rather a warning. | ||||
|     //   0 - success | ||||
| @@ -886,7 +904,7 @@ extern "C" { | ||||
|     // < 0 - error. the KV cache state is restored to the state before this call | ||||
|     LLAMA_API int32_t llama_decode( | ||||
|             struct llama_context * ctx, | ||||
|               struct llama_batch   batch); | ||||
|               struct llama_batch * batch); | ||||
|  | ||||
|     // Set the number of threads used for decoding | ||||
|     // n_threads is the number of threads used for generation (single token) | ||||
|   | ||||
| @@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one( | ||||
|                  int32_t   n_tokens) { | ||||
|     return new llama_batch{ | ||||
|         /*n_tokens       =*/ n_tokens, | ||||
|         /*max_tokens     =*/ n_tokens, | ||||
|         /*is_view        =*/ false, | ||||
|         /*tokens         =*/ tokens, | ||||
|         /*embd           =*/ nullptr, | ||||
|         /*pos            =*/ nullptr, | ||||
| @@ -326,6 +328,8 @@ struct llama_batch * llama_batch_get_one( | ||||
| static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { | ||||
|     llama_batch * batch = new llama_batch{ | ||||
|         /*n_tokens       =*/ 0, | ||||
|         /*max_tokens     =*/ n_tokens_alloc, | ||||
|         /*is_view        =*/ false, | ||||
|         /*tokens         =*/ nullptr, | ||||
|         /*embd           =*/ nullptr, | ||||
|         /*pos            =*/ nullptr, | ||||
| @@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd( | ||||
|             int32_t   seq_id) { | ||||
|     struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); | ||||
|     memcpy(batch->embd, embd, n_embd * sizeof(float)); | ||||
|     for (int32_t i = 0; i < n_embd; i++) { | ||||
|     for (size_t i = 0; i < n_embd; i++) { | ||||
|         batch->pos     [i] = pos0 + i; | ||||
|         batch->n_seq_id[i] = 1; | ||||
|         batch->seq_id  [i][0] = seq_id; | ||||
|     } | ||||
|     return batch; | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_add_text( | ||||
| int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) { | ||||
|     return batch->n_tokens; | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_add_text_token( | ||||
|         struct llama_batch * batch, | ||||
|                llama_token * tokens, | ||||
|                    size_t    n_tokens, | ||||
|                    int32_t   pos0, | ||||
|                    int32_t * seq_ids, | ||||
|                    size_t    n_seq_ids) { | ||||
|     if (batch->n_tokens + n_tokens > batch->n_tokens) { | ||||
|         return -1; | ||||
|                llama_token   token, | ||||
|                  llama_pos   pos, | ||||
|         const llama_seq_id * seq_ids, | ||||
|                     size_t   n_seq_ids, | ||||
|                      float   logits) { | ||||
|     if (batch->n_tokens + 1 > batch->max_tokens) { | ||||
|         return -1; // llama_batch size exceeded | ||||
|     } | ||||
|     if (batch->embd) { | ||||
|         return -2; | ||||
|         return -2; // embd is already set, cannot add text tokens | ||||
|     } | ||||
|     for (int32_t i = 0; i < n_tokens; i++) { | ||||
|         batch->token   [batch->n_tokens + i] = tokens[i]; | ||||
|         batch->pos     [batch->n_tokens + i] = pos0 + i; | ||||
|         batch->n_seq_id[batch->n_tokens + i] = n_seq_ids; | ||||
|         for (int32_t j = 0; j < n_seq_ids; j++) { | ||||
|             batch->seq_id[batch->n_tokens + i][j] = seq_ids[j]; | ||||
|         } | ||||
|     batch->token   [batch->n_tokens] = token; | ||||
|     batch->pos     [batch->n_tokens] = pos; | ||||
|     batch->n_seq_id[batch->n_tokens] = n_seq_ids; | ||||
|     for (size_t j = 0; j < n_seq_ids; j++) { | ||||
|         batch->seq_id[batch->n_tokens][j] = seq_ids[j]; | ||||
|     } | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_add_text( | ||||
|         struct llama_batch * batch, | ||||
|                llama_token * tokens, | ||||
|                    size_t    n_tokens, | ||||
|                    int32_t   pos0, | ||||
|                    int32_t   seq_id) { | ||||
|     std::array<int32_t, 1> seq_ids = { seq_id }; | ||||
|     return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size()); | ||||
|     batch->logits  [batch->n_tokens] = logits; | ||||
|     batch->n_tokens++; | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_set_logits( | ||||
|         struct llama_batch * batch, | ||||
|                    int32_t   pos, | ||||
|                    int32_t   seq_id) { | ||||
|                  llama_pos   pos, | ||||
|               llama_seq_id   seq_id) { | ||||
|     for (int32_t i = 0; i < batch->n_tokens; i++) { | ||||
|         // find the token having seq_id | ||||
|         for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { | ||||
| @@ -415,28 +415,74 @@ int32_t llama_batch_set_logits( | ||||
|                 // found the sequence | ||||
|                 if (pos == -1 || pos == batch->pos[i]) { | ||||
|                     batch->logits[i] = true; | ||||
|                     break; | ||||
|                     return 0; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return -1; // not found | ||||
| } | ||||
|  | ||||
| int32_t llama_batch_set_logits_last(struct llama_batch * batch) { | ||||
|     if (batch->n_tokens == 0) { | ||||
|         return -1; | ||||
|     } | ||||
|     batch->logits[batch->n_tokens - 1] = true; | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| void llama_batch_clear(struct llama_batch * batch) { | ||||
|     batch->n_tokens = 0; | ||||
| } | ||||
|  | ||||
| void llama_batch_free(struct llama_batch * batch) { | ||||
|     if (batch->token)    free(batch->token); | ||||
|     if (batch->embd)     free(batch->embd); | ||||
|     if (batch->pos)      free(batch->pos); | ||||
|     if (batch->n_seq_id) free(batch->n_seq_id); | ||||
|     if (batch->seq_id) { | ||||
|         for (int i = 0; batch->seq_id[i] != nullptr; ++i) { | ||||
|             free(batch->seq_id[i]); | ||||
|         } | ||||
|         free(batch->seq_id); | ||||
| struct llama_batch * llama_batch_get_view( | ||||
|         struct llama_batch * batch, | ||||
|                    int32_t   offset, | ||||
|                    int32_t   n_tokens) { | ||||
|     if (batch->embd) { | ||||
|         return nullptr; // not yet supported | ||||
|     } | ||||
|     llama_batch * batch_view = new llama_batch{ | ||||
|         /*n_tokens       =*/ n_tokens, | ||||
|         /*max_tokens     =*/ n_tokens, | ||||
|         /*is_view        =*/ true, | ||||
|         /*tokens         =*/ batch->token    + offset, | ||||
|         /*embd           =*/ nullptr, | ||||
|         /*pos            =*/ batch->pos      + offset, | ||||
|         /*n_seq_id       =*/ batch->n_seq_id + offset, | ||||
|         /*seq_id         =*/ batch->seq_id   + offset, | ||||
|         /*logits         =*/ batch->logits   + offset, | ||||
|     }; | ||||
|     return batch_view; | ||||
| } | ||||
|  | ||||
| struct llama_batch_token_info llama_batch_get_token_info( | ||||
|         struct llama_batch * batch, | ||||
|                    int32_t   i) { | ||||
|     GGML_ASSERT(i >= 0 && i < batch->n_tokens); | ||||
|     return llama_batch_token_info{ | ||||
|         /*token    =*/ batch->token   [i], | ||||
|         /*pos      =*/ batch->pos     [i], | ||||
|         /*n_seq_id =*/ batch->n_seq_id[i], | ||||
|         /*seq_id   =*/ batch->seq_id  [i], | ||||
|         /*logits   =*/ batch->logits  [i], | ||||
|     }; | ||||
| } | ||||
|  | ||||
| void llama_batch_free(struct llama_batch * batch) { | ||||
|     // do not free the members if it's a view | ||||
|     if (!batch->is_view) { | ||||
|         if (batch->token)    free(batch->token); | ||||
|         if (batch->embd)     free(batch->embd); | ||||
|         if (batch->pos)      free(batch->pos); | ||||
|         if (batch->n_seq_id) free(batch->n_seq_id); | ||||
|         if (batch->seq_id) { | ||||
|             for (int i = 0; batch->seq_id[i] != nullptr; ++i) { | ||||
|                 free(batch->seq_id[i]); | ||||
|             } | ||||
|             free(batch->seq_id); | ||||
|         } | ||||
|         if (batch->logits)   free(batch->logits); | ||||
|     } | ||||
|     if (batch->logits)   free(batch->logits); | ||||
|     delete batch; | ||||
| } | ||||
|   | ||||
| @@ -20,6 +20,8 @@ | ||||
| // | ||||
| struct llama_batch { | ||||
|     int32_t n_tokens; | ||||
|     int32_t max_tokens; | ||||
|     bool is_view; | ||||
|  | ||||
|     llama_token  *  token; | ||||
|     float        *  embd; | ||||
|   | ||||
| @@ -9978,8 +9978,8 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { | ||||
|  | ||||
| int32_t llama_encode( | ||||
|         struct llama_context * ctx, | ||||
|           struct llama_batch   batch) { | ||||
|     const int ret = llama_encode_impl(*ctx, batch); | ||||
|           struct llama_batch * batch) { | ||||
|     const int ret = llama_encode_impl(*ctx, *batch); | ||||
|     if (ret != 0) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); | ||||
|     } | ||||
| @@ -9989,8 +9989,8 @@ int32_t llama_encode( | ||||
|  | ||||
| int32_t llama_decode( | ||||
|         struct llama_context * ctx, | ||||
|           struct llama_batch   batch) { | ||||
|     const int ret = llama_decode_impl(*ctx, batch); | ||||
|           struct llama_batch * batch) { | ||||
|     const int ret = llama_decode_impl(*ctx, *batch); | ||||
|     if (ret != 0) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen