mirror of https://github.com/ggml-org/llama.cpp.git
	server : avoid common_batch
ggml-ci
@@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

//
// Batch utils
//

// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
struct common_batch {
    llama_batch_ext_ptr batch;
    struct batch_token {
        llama_token  token;
        llama_seq_id seq_id; // only support single seq for now
        bool         logits;
    };
    std::vector<batch_token> tokens;
    int n_outputs = 0;
    common_batch() = default;
    common_batch(int32_t n_tokens, int32_t n_seq_max) {
        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
        tokens.reserve(n_tokens);
    }
    void clear() {
        llama_batch_ext_clear(batch.get());
        tokens.clear();
    }
    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
        tokens.push_back({token, seq_id, logits});
        if (logits) {
            n_outputs++;
        }
    }
    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
        tokens.push_back({token, seq_ids[0], logits});
        if (logits) {
            n_outputs++;
        }
    }
    void set_logits_last() {
        if (!tokens.empty()) {
            llama_batch_ext_set_output_last(batch.get());
            tokens.back().logits = true;
        }
    }
    int32_t get_n_tokens() const {
        return (int32_t)tokens.size();
    }
    llama_batch_ext * get() {
        return batch.get();
    }
    common_batch get_view(int32_t offset, int32_t n_tokens) {
        common_batch view;
        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
        view.tokens.reserve(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            view.tokens.push_back(tokens[offset + i]);
            if (tokens[offset + i].logits) {
                view.n_outputs++;
            }
        }
        return view;
    }
};

//
// Token utils
//

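The server changes below replace this removed wrapper with direct calls into the `llama_batch_ext` API. A minimal sketch of that pattern, using only the calls that appear in this diff (`llama_batch_ext_init`, `llama_batch_ext_clear`, `llama_batch_ext_add_text`, `llama_batch_ext_set_output_last`, `llama_batch_ext_get_n_tokens`, `llama_decode_ext`) plus the `llama_batch_ext_ptr` smart pointer; the helper function, its parameters, and the header location of `llama_batch_ext_ptr` are assumptions for illustration, not part of the change:

```cpp
#include <vector>

#include "llama.h"
#include "llama-cpp.h" // assumed header for the llama_batch_ext_ptr RAII wrapper used in this diff

// Build a single-sequence batch and decode it, requesting output only for the last token,
// mirroring how the server now drives llama_batch_ext directly instead of through common_batch.
static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
    llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) tokens.size(), /*n_seq_max =*/ 1));

    llama_batch_ext_clear(batch.get());

    for (size_t i = 0; i < tokens.size(); ++i) {
        // last argument: whether this token should produce logits/embeddings
        llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, false);
    }

    // extract output only for the last token, as the server does once the prompt is fully batched
    llama_batch_ext_set_output_last(batch.get());

    const int32_t n_tokens = llama_batch_ext_get_n_tokens(batch.get());
    return n_tokens > 0 ? llama_decode_ext(ctx, batch.get()) : 0;
}
```

Note that the server keeps one long-lived batch and reinitializes it with `batch.reset(llama_batch_ext_init(...))` in the hunks below; the sketch allocates locally only to stay self-contained.
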
@@ -1224,7 +1224,7 @@ struct server_slot {
    // only used for completion/embedding/infill/rerank
    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

    common_batch batch_spec;
    llama_batch_ext_ptr batch_spec;

    llama_context * ctx = nullptr;
    llama_context * ctx_dft = nullptr;
@@ -1248,7 +1248,7 @@ struct server_slot {
    int32_t n_past      = 0;
    int32_t n_decoded   = 0;
    int32_t n_remaining = -1;
    int32_t i_batch     = -1;
    int32_t i_batch     = -1; // TODO: remove and use only sequence-based sampling
    int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict

    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
@@ -1796,7 +1796,7 @@ struct server_context {

    llama_context_params cparams_dft;

    common_batch batch;
    llama_batch_ext_ptr batch;

    bool clean_kv_cache = true;
    bool add_bos_token  = true;
@@ -1922,7 +1922,7 @@ struct server_context {
            slot.n_predict = params_base.n_predict;

            if (model_dft) {
                slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));

                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
@@ -1958,7 +1958,7 @@ struct server_context {
            const int32_t n_batch = llama_n_batch(ctx);

            // only a single seq_id per token is needed
            batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
        }

        metrics.init();
@@ -2093,7 +2093,7 @@ struct server_context {
        }

        if (slot.ctx_dft) {
            slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
        }

        slot.state = SLOT_STATE_STARTED;
@@ -2401,7 +2401,7 @@ struct server_context {
        queue_results.send(std::move(res));
    }

    void send_embedding(const server_slot & slot, common_batch & batch) {
    void send_embedding(const server_slot & slot) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id        = slot.id_task;
        res->index     = slot.index;
@@ -2410,34 +2410,40 @@ struct server_context {

        const int n_embd = llama_model_n_embd(model);

        const llama_seq_id seq_id = slot.id;

        std::vector<float> embd_res(n_embd, 0.0f);

        for (int i = 0; i < batch.get_n_tokens(); ++i) {
            auto tok = batch.tokens[i];
            if (!tok.logits || tok.seq_id != slot.id) {
                continue;
            }

            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }
        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
            const float * embd = llama_get_embeddings_seq(ctx, seq_id);

            if (embd == NULL) {
                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
                SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);

                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                continue;
            }

            // normalize only when there is pooling
            // TODO: configurable
            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
                res->embedding.push_back(embd_res);
            } else {
                res->embedding.push_back({ embd, embd + n_embd });
            }
            common_embd_normalize(embd, embd_res.data(), n_embd, 2);
            res->embedding.push_back(embd_res);
        } else {
            GGML_ABORT("embeddings without pooling is not supported yet");
            //for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
            //    auto tok = batch.tokens[i];
            //    if (!tok.logits || tok.seq_id != slot.id) {
            //        continue;
            //    }

            //    const float * embd = llama_get_embeddings_ith(ctx, tok.seq_id);
            //    if (embd == NULL) {
            //        SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);

            //        res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
            //        continue;
            //    }

            //    res->embedding.push_back({ embd, embd + n_embd });
            //}
        }

        SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2445,30 +2451,20 @@ struct server_context {
        queue_results.send(std::move(res));
    }

    void send_rerank(const server_slot & slot, common_batch & batch) {
    void send_rerank(const server_slot & slot) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id    = slot.id_task;
        res->index = slot.index;
        res->n_tokens = slot.n_prompt_tokens;

        for (int i = 0; i < batch.get_n_tokens(); ++i) {
            auto tok = batch.tokens[i];
            if (!tok.logits || tok.seq_id != slot.id) {
                continue;
            }
        const llama_seq_id seq_id = slot.id;

            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }

            if (embd == NULL) {
                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);

                res->score = -1e6;
                continue;
            }
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        if (embd == NULL) {
            SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);

            res->score = -1e6;
        } else {
            res->score = embd[0];
        }

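With the per-token bookkeeping gone, `send_embedding()` and `send_rerank()` both read a single pooled vector per sequence via `llama_get_embeddings_seq`. A rough sketch of that retrieval path under the same assumption the diff enforces (pooling type other than `LLAMA_POOLING_TYPE_NONE`); the helper name and its return convention are illustrative:

```cpp
#include <vector>

#include "common.h" // for common_embd_normalize(), as used by the embedding path in this diff
#include "llama.h"

// Fetch the pooled embedding for one sequence (one slot == one sequence in the server).
static std::vector<float> pooled_embedding(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
    const int n_embd = llama_model_n_embd(model);

    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd == NULL) {
        // no pooled output for this sequence -> fall back to zeros, like send_embedding() does
        return std::vector<float>(n_embd, 0.0f);
    }

    // the embedding endpoint L2-normalizes the pooled vector;
    // the rerank endpoint instead reads embd[0] directly as the relevance score
    std::vector<float> out(n_embd);
    common_embd_normalize(embd, out.data(), n_embd, 2);
    return out;
}
```
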
@@ -2854,7 +2850,7 @@ struct server_context {
        }

        // start populating the batch for this iteration
        batch.clear();
        llama_batch_ext_clear(batch.get());

        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
@@ -2876,9 +2872,9 @@ struct server_context {
                continue;
            }

            slot.i_batch = batch.get_n_tokens();
            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());

            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, &slot.id, 1, true);

            slot.n_past += 1;

@@ -2895,7 +2891,7 @@ struct server_context {
        int32_t n_ubatch = llama_n_ubatch(ctx);

        // next, batch any pending prompts without exceeding n_batch
        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3061,7 +3057,7 @@ struct server_context {
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
                        if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3081,11 +3077,12 @@ struct server_context {
                    slot.cache_tokens.resize(slot.n_past);

                    // add prompt tokens for processing in the current batch
                    while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

                        batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
                        //batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
                        llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, &slot.id, 1, need_embd);

                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3095,13 +3092,14 @@ struct server_context {
                        slot.n_past++;
                    }

                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
                            slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;

                        GGML_ASSERT(batch.get_n_tokens() > 0);
                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);

                        common_sampler_reset(slot.smpl);

@@ -3111,27 +3109,28 @@ struct server_context {
                        }

                        // extract the logits only for the last token
                        batch.set_logits_last();
                        //batch.set_logits_last();
                        llama_batch_ext_set_output_last(batch.get());

                        slot.n_decoded = 0;
                        slot.i_batch   = batch.get_n_tokens() - 1;
                        slot.i_batch   = llama_batch_ext_get_n_tokens(batch.get()) - 1;

                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
                    }
                }

                if (batch.get_n_tokens() >= n_batch) {
                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
                    break;
                }
            }
        }

        if (batch.get_n_tokens() == 0) {
        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }

        SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));

        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3141,10 +3140,10 @@ struct server_context {
        }

        // process the created batch of tokens
        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);

            common_batch batch_view = batch.get_view(i, n_tokens);
            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));

            const int ret = llama_decode_ext(ctx, batch_view.get());
            metrics.on_decoded(slots);
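When the accumulated batch exceeds `n_batch`, the loop above decodes it window by window through `llama_batch_ext_get_view`. A compact sketch of that chunked decode, assuming (as the diff does) that `llama_batch_ext_ptr` also owns and frees the view object; the server's retry and error handling is left out:

```cpp
#include <algorithm>

#include "llama.h"
#include "llama-cpp.h" // assumed header for llama_batch_ext_ptr

// Decode a populated batch in chunks of at most n_batch tokens.
static bool decode_in_chunks(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
    const int32_t n_total = llama_batch_ext_get_n_tokens(batch);

    for (int32_t i = 0; i < n_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_total - i);

        // the view references tokens [i, i + n_tokens) of the parent batch
        llama_batch_ext_ptr view(llama_batch_ext_get_view(batch, i, n_tokens));

        if (llama_decode_ext(ctx, view.get()) != 0) {
            return false; // error recovery omitted in this sketch
        }
    }

    return true;
}
```
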
@@ -3177,14 +3176,14 @@ struct server_context {
                if (slot.state == SLOT_STATE_DONE_PROMPT) {
                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                        // prompt evaluated for embedding
                        send_embedding(slot, batch_view);
                        send_embedding(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }

                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                        send_rerank(slot, batch_view);
                        send_rerank(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
@@ -3281,14 +3280,17 @@ struct server_context {
                }

                // construct the speculation batch
                slot.batch_spec.clear();
                slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
                //slot.batch_spec.clear();
                //slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
                llama_batch_ext_clear(slot.batch_spec.get());
                llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true);

                for (size_t i = 0; i < draft.size(); ++i) {
                    slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
                    //slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
                    llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true);
                }

                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
                SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));

                llama_decode_ext(ctx, slot.batch_spec.get());

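The speculative-decoding path follows the same pattern for `slot.batch_spec`: the accepted token goes in at `slot.n_past`, the draft tokens follow it, and every entry requests logits so they can all be verified in one `llama_decode_ext` call. A trimmed sketch of that construction (the helper and its parameter names are illustrative; producing the draft itself is out of scope here):

```cpp
#include <vector>

#include "llama.h"

// Rebuild the speculation batch: the sampled token first, then the draft tokens,
// all on the slot's sequence and all marked for output so their logits can be checked.
static void build_spec_batch(llama_batch_ext * batch_spec,
                             llama_token sampled, const std::vector<llama_token> & draft,
                             llama_pos n_past, llama_seq_id seq_id) {
    llama_batch_ext_clear(batch_spec);
    llama_batch_ext_add_text(batch_spec, sampled, n_past, &seq_id, 1, true);

    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_ext_add_text(batch_spec, draft[i], n_past + 1 + (llama_pos) i, &seq_id, 1, true);
    }

    // the caller then runs llama_decode_ext(ctx, batch_spec) on the target model,
    // exactly as the hunk above does with slot.batch_spec.get()
}
```
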
@@ -4147,6 +4149,11 @@ int main(int argc, char ** argv) {
            return;
        }

        if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
            res_error(res, format_error_response("Pooling type 'none' is not yet supported. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
@@ -4241,6 +4248,11 @@ int main(int argc, char ** argv) {
            return;
        }

        if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
            res_error(res, format_error_response("Pooling type 'none' cannot be used with reranking. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        const json body = json::parse(req.body);

        // TODO: implement

@@ -88,13 +88,19 @@ def test_embedding_pooling_none():
    res = server.make_request("POST", "/embeddings", data={
        "input": "hello hello hello",
    })
    assert res.status_code == 200
    assert 'embedding' in res.body[0]
    assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

    # make sure embedding vector is not normalized
    for x in res.body[0]['embedding']:
        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
    # /embeddings does not support pooling type 'none'
    assert res.status_code == 400
    assert "error" in res.body

    # TODO: re-enable when we figure out how to support pooling type 'none'
    #assert res.status_code == 200
    #assert 'embedding' in res.body[0]
    #assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

    ## make sure embedding vector is not normalized
    #for x in res.body[0]['embedding']:
    #    assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON


def test_embedding_pooling_none_oai():