mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	rework, targeting llama-server
@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch batch_spec = {};
+    llama_batch_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    llama_batch_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1820,11 +1820,7 @@ struct server_context {
 
             common_speculative_free(slot.spec);
             slot.spec = nullptr;
-
-            llama_batch_free(slot.batch_spec);
         }
-
-        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
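With ownership moved into `llama_batch_ptr`, the explicit `llama_batch_free` calls in the destructor become unnecessary: the wrapper frees the batch when the owning object is destroyed. The commit does not show the wrapper's definition; a minimal sketch consistent with the `.get()`/`.reset()` calls in the hunks below, assuming the reworked `llama_batch_init()` returns a heap-allocated `llama_batch *` and `llama_batch_free()` takes that pointer (both signatures are assumptions, not confirmed by this diff):

```cpp
#include <memory>

// Hypothetical sketch, not shown in the commit.
struct llama_batch;                          // opaque handle in this sketch
void llama_batch_free(llama_batch * batch);  // assumed pointer-taking signature

struct llama_batch_deleter {
    void operator()(llama_batch * batch) const {
        llama_batch_free(batch);  // unique_ptr never invokes the deleter on null
    }
};

using llama_batch_ptr = std::unique_ptr<llama_batch, llama_batch_deleter>;
```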
@@ -1944,7 +1940,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@ struct server_context {
 
             slot.reset();
 
-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
         }
 
         default_generation_settings_for_props = slots[0].to_json();
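A `std::unique_ptr`-style member makes `server_slot` move-only, so copying it into the `slots` vector no longer compiles; hence the switch to `push_back(std::move(slot))`. A tiny standalone illustration with a hypothetical struct:

```cpp
#include <memory>
#include <vector>

// A struct holding a unique_ptr has a deleted copy constructor, so it must
// be moved, not copied, into the vector -- same reason as the hunk above.
struct slot_like {
    std::unique_ptr<int> batch_spec;
};

int main() {
    std::vector<slot_like> slots;
    slot_like s;
    s.batch_spec = std::make_unique<int>(42);
    // slots.push_back(s);          // error: copy constructor is deleted
    slots.push_back(std::move(s));  // ok: ownership is transferred
    return 0;
}
```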
@@ -1980,7 +1976,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
         }
 
         metrics.init();
@@ -2098,9 +2094,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, const llama_batch & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id        = slot.id_task;
         res->index     = slot.index;
@@ -2419,18 +2413,19 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2451,24 +2446,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id    = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                 res->score = -1e6;
                 continue;
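Instead of reading `batch.token[i]`, `batch.logits[i]`, and `batch.seq_id[i][0]` directly, both senders now go through `llama_batch_get_token_info`. Only three of its fields appear in this diff; a hedged reconstruction of the visible surface, with stand-in typedefs (the real struct presumably carries more, e.g. the token position and the seq_id count):

```cpp
#include <cstdint>

typedef int32_t llama_token;   // stand-ins for the real llama.h typedefs
typedef int32_t llama_seq_id;

// only the fields this diff actually reads are listed here
struct llama_batch_token_info {
    llama_token    token;    // token id at batch index i
    llama_seq_id * seq_id;   // sequence ids; seq_id[0] is compared to slot.id
    bool           logits;   // whether this token outputs logits/embeddings
};
```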
@@ -2859,7 +2855,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        common_batch_clear(batch.get());
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = llama_batch_get_n_tokens(batch.get());
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2900,7 +2896,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3086,11 +3082,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3116,27 +3112,27 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens - 1] = true;
+                        llama_batch_set_logits_last(batch.get());
 
                         slot.n_decoded = 0;
-                        slot.i_batch   = batch.n_tokens - 1;
+                        slot.i_batch   = llama_batch_get_n_tokens(batch.get()) - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
                     }
                 }
 
-                if (batch.n_tokens >= n_batch) {
+                if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (batch.n_tokens == 0) {
+        if (llama_batch_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
 
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
+            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
 
-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view.get());
             metrics.on_decoded(slots);
 
             if (ret != 0) {
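Previously a sub-range was decoded by hand-building an aggregate `llama_batch` whose pointers offset into the parent's arrays; once the batch struct goes opaque, that initializer can no longer exist, so `llama_batch_get_view` takes over. For reference, the removed aggregate is equivalent to this helper over the classic struct layout (the commit's real `llama_batch_get_view` presumably allocates an object of its own, since the result is wrapped in a `llama_batch_ptr` and freed like any other batch):

```cpp
// Equivalent of the deleted code, written against the classic (non-opaque)
// llama_batch layout: the view aliases the parent's arrays at offset i.
static llama_batch make_batch_view(const llama_batch & parent, int32_t i, int32_t n_tokens) {
    llama_batch view = {
        n_tokens,
        parent.token    + i,
        nullptr,            // embd is unused when token ids are set
        parent.pos      + i,
        parent.n_seq_id + i,
        parent.seq_id   + i,
        parent.logits   + i,
    };
    return view;
}
```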
@@ -3294,16 +3282,16 @@ struct server_context {
                 }
 
                 // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+                common_batch_clear(slot.batch_spec.get());
+                common_batch_add  (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
 
                 for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                    common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
                 }
 
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+                SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
 
-                llama_decode(ctx, slot.batch_spec);
+                llama_decode(ctx, slot.batch_spec.get());
 
                 // the accepted tokens from the speculation
                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
Xuan Son Nguyen