server : ensure batches are either all embed or all completion (#8420)

* make sure batches are all embed or all non-embed

* non-embedding batch for sampled tokens; fix unused params warning
Author: Douglas Hanley
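Background for the diff below: the embeddings switch applies to an entire llama_decode() pass (it is toggled per context via llama_set_embeddings(), as the third hunk shows), not to individual sequences, so completion tokens, which need logits, and embedding tokens, which need embedding output, cannot share a batch. The change therefore tags each batch with a type, defers slots of the other type to a later batch, and rejects completion-style requests up front on a server started with `--embeddings`.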
@@ -2005,6 +2005,11 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
+        // track if this is an embedding or non-embedding batch
+        // if we've added sampled tokens above, we are in non-embedding mode
+        // -1: none, 0: non-embedding, 1: embedding
+        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
@@ -2175,6 +2180,14 @@ struct server_context {
                         }
                     }
 
+                    // check that we are in the right batch_type, if not defer the slot
+                    bool slot_type = slot.embedding ? 1 : 0;
+                    if (batch_type == -1) {
+                        batch_type = slot_type;
+                    } else if (batch_type != slot_type) {
+                        continue;
+                    }
+
                     // keep only the common part
                     int p0 = (int) system_tokens.size() + slot.n_past;
                     if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2276,6 +2289,9 @@ struct server_context {
             {"n_tokens", batch.n_tokens},
         });
 
+        // make sure we're in the right embedding mode
+        llama_set_embeddings(ctx, batch_type == 1);
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
@@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };
 
-    const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
 
         const json body = json::parse(req.body);
         bool is_openai = false;
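For readers skimming the diff: the slot-admission rule in the first two hunks reduces to a small first-come-decides state machine. A minimal standalone sketch, assuming a hypothetical slot_info struct and pick_slots helper (not the server's actual types):

#include <cstddef>
#include <cstdint>
#include <vector>

// batch_type values, as in the diff: -1 = undecided, 0 = non-embedding
// (completion), 1 = embedding.
struct slot_info {
    bool embedding; // true if this slot serves an embeddings request
};

// Returns the indices of the slots admitted into the current batch; slots of
// the other type are skipped here and picked up by a later batch.
std::vector<std::size_t> pick_slots(const std::vector<slot_info> & slots, bool has_sampled_tokens) {
    // sampled completion tokens already in the batch force non-embedding mode
    int32_t batch_type = has_sampled_tokens ? 0 : -1;

    std::vector<std::size_t> admitted;
    for (std::size_t i = 0; i < slots.size(); ++i) {
        const int32_t slot_type = slots[i].embedding ? 1 : 0;
        if (batch_type == -1) {
            batch_type = slot_type; // first admitted slot decides the batch type
        } else if (batch_type != slot_type) {
            continue;               // wrong type for this batch: defer the slot
        }
        admitted.push_back(i);
    }
    return admitted;
}

Once a batch is assembled, llama_set_embeddings(ctx, batch_type == 1) switches the context to the matching output mode before decoding, and deferred slots are served by a later scheduling pass. On the HTTP side the guard moves to the front door: completion-style endpoints now return ERROR_TYPE_NOT_SUPPORTED on a server started with `--embeddings`, while handle_embeddings drops its old 501 check, since embedding batches can now be scheduled alongside completion work, just never mixed into the same batch.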