	server : (embeddings) using same format for "input" and "content" (#10872)
* server : (embeddings) using same format for "input" and "content"
* fix test case
* handle empty input case
* fix test
@@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;
 
-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id            = ctx_server.queue_tasks.get_new_id();
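With this change both "input" and "content" flow through the same tokenize_input_prompts() call, and a prompt that tokenizes to an empty token list is rejected before any task is queued. A minimal client-side sketch of the resulting behaviour, assuming a llama-server instance on http://localhost:8080 and the Python requests package; the URL, the example payloads, and the expected 400 for the empty prompt are illustrative assumptions, not part of the diff:

import requests

BASE_URL = "http://localhost:8080"  # assumed local llama-server address

# "content" now accepts the same shapes as the OpenAI-style "input":
# a single string, a list of token ids, a mixed list, or a list of prompts.
payloads = [
    {"content": "I believe the meaning of life is"},
    {"content": [12, 34, 56]},               # pre-tokenized prompt
    {"content": ["string1", [12, 34, 56]]},  # multiple prompts in one request
    {"input":   ["string1", "string2"]},     # OpenAI-compatible field, unchanged
]
for payload in payloads:
    r = requests.post(f"{BASE_URL}/embeddings", json=payload)
    print(r.status_code)  # expected 200 for each shape

# An empty prompt is now rejected up front; with a model whose tokenizer adds
# no BOS token this should return 400 ("Input content cannot be empty").
r = requests.post(f"{BASE_URL}/embeddings", json={"content": ""})
print(r.status_code)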
@@ -45,6 +45,35 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+@pytest.mark.parametrize(
+    "content,is_multi_prompt",
+    [
+        # single prompt
+        ("string", False),
+        ([12, 34, 56], False),
+        ([12, 34, "string", 56, 78], False),
+        # multiple prompts
+        (["string1", "string2"], True),
+        (["string1", [12, 34, 56]], True),
+        ([[12, 34, 56], [12, 34, 56]], True),
+        ([[12, 34, 56], [12, "string", 34, 56]], True),
+    ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    if is_multi_prompt:
+        assert len(res.body) == len(content)
+        for d in res.body:
+            assert 'embedding' in d
+            assert len(d['embedding']) > 1
+    else:
+        assert 'embedding' in res.body
+        assert len(res.body['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
 @pytest.mark.parametrize(
     "content,n_tokens",
     [
-        ("I believe the meaning of life is", 7),
-        ("This is a test", 4),
+        ("I believe the meaning of life is", 9),
+        ("This is a test", 6),
     ]
 )
 def test_embedding_usage_single(content, n_tokens):
@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
     })
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
-    assert res.body['usage']['prompt_tokens'] == 2 * 7
+    assert res.body['usage']['prompt_tokens'] == 2 * 9
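The expected usage figures rise by two tokens per prompt (7 -> 9, 4 -> 6, and 2 * 7 -> 2 * 9 for the two-prompt request), which lines up with the handler now calling tokenize_input_prompts() with add_special enabled, so special tokens are counted in prompt_tokens. A small sketch of that accounting; the assumption that exactly two special tokens (e.g. a BOS/CLS and EOS/SEP pair from the test model's tokenizer) are added per prompt is mine, not stated in the diff:

# Sketch of the usage accounting under the stated assumption.
def expected_prompt_tokens(plain_token_counts, specials_per_prompt=2):
    # prompt_tokens sums the tokenized length of every prompt in the request,
    # special tokens included once add_special is enabled
    return sum(n + specials_per_prompt for n in plain_token_counts)

assert expected_prompt_tokens([7]) == 9         # "I believe the meaning of life is"
assert expected_prompt_tokens([4]) == 6         # "This is a test"
assert expected_prompt_tokens([7, 7]) == 2 * 9  # two identical prompts in one request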
@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
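For reference, a toy Python model of the prompt shapes documented in this comment, showing how a single value maps to one or several token lists. The splitting rule is inferred from the documented shapes and the test cases above; the helper names and the stand-in tokenizer are assumptions, not the server's actual implementation:

from typing import List, Union

Token = int

def toy_tokenize(text: str) -> List[Token]:
    # stand-in for the real tokenizer: one fake token id per whitespace word
    return [len(word) for word in text.split()]

def tokenize_one(prompt) -> List[Token]:
    # a single prompt: a string, or a list mixing token ids and strings
    if isinstance(prompt, str):
        return toy_tokenize(prompt)
    out: List[Token] = []
    for part in prompt:
        out.extend(toy_tokenize(part) if isinstance(part, str) else [part])
    return out

def split_prompts(value: Union[str, list]) -> List[List[Token]]:
    # "string", [12, 34, 56], [12, 34, "string", 56, 78] -> one task;
    # ["string1", "string2"], ["string1", [12, 34, 56]], [[...], [...]] -> one task per element
    if isinstance(value, str):
        return [tokenize_one(value)]
    if all(isinstance(x, int) for x in value):
        return [list(value)]                 # plain token list, single task
    if any(isinstance(x, int) for x in value):
        return [tokenize_one(value)]         # mixed ids and strings, single task
    return [tokenize_one(p) for p in value]  # list of prompts, one task per element

assert len(split_prompts("string")) == 1
assert len(split_prompts([12, 34, "string", 56, 78])) == 1
assert len(split_prompts(["string1", [12, 34, 56]])) == 2
assert len(split_prompts([[12, 34, 56], [78, 90, 12]])) == 2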