	Server: Use multi-task for embeddings endpoint (#6001)
* use multi-task for the embeddings endpoint
* specify types
* remove redundant {"n_predict", 0}
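
For context, a minimal client-side sketch of the request shape this change targets. This is a hedged illustration, not part of the commit: the base address and input strings are placeholders, and the OpenAI-compatible route is assumed to be /v1/embeddings; cpp-httplib and nlohmann/json are the libraries the server itself uses.

#include <httplib.h>
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    httplib::Client cli("http://localhost:8080"); // placeholder address

    // OpenAI-style body: "input" may be a single string or an array.
    // After this commit an array is forwarded as one multi-task instead
    // of being looped over prompt by prompt on the server side.
    json body = {
        {"input", {"first sentence", "second sentence"}}
    };

    auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
    if (res && res->status == 200) {
        json reply = json::parse(res->body);
        std::cout << reply["data"].size() << " embeddings returned\n";
    }
    return 0;
}

The native body shape, {"content": "..."}, still carries a single prompt, which the handler below wraps into a one-element list.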
			
			
@@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods",     "POST");
         res.set_header("Access-Control-Allow-Headers",     "*");
+        return res.set_content("", "application/json; charset=utf-8");
     });
 
     svr->set_logger(log_server_request);
@@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool is_openai = false;
 
-        // an input prompt can string or a list of tokens (integer)
-        std::vector<json> prompts;
+        // an input prompt can be a string or a list of tokens (integer)
+        json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            if (body["input"].is_array()) {
-                // support multiple prompts
-                for (const json & elem : body["input"]) {
-                    prompts.push_back(elem);
-                }
-            } else {
-                // single input prompt
-                prompts.push_back(body["input"]);
-            }
+            prompt = body["input"];
         } else if (body.count("content") != 0) {
-            // only support single prompt here
-            std::string content = body["content"];
-            prompts.push_back(content);
+            // with "content", we only support single prompt
+            prompt = std::vector<std::string>{body["content"]};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        // process all prompts
-        json responses = json::array();
-        for (auto & prompt : prompts) {
-            // TODO @ngxson : maybe support multitask for this endpoint?
+        // create and queue the task
+        json responses;
+        {
             const int id_task = ctx_server.queue_tasks.get_new_id();
 
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
             if (!result.error) {
-                // append to the responses
-                responses.push_back(result.data);
+                if (result.data.count("results")) {
+                    // result for multi-task
+                    responses = result.data["results"];
+                } else {
+                    // result for single task
+                    responses = std::vector<json>{result.data};
+                }
             } else {
                 // error received, ignore everything else
                 res_error(res, result.data);
@@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root;
-        if (is_openai) {
-            json res_oai = json::array();
-            int i = 0;
-            for (auto & elem : responses) {
-                res_oai.push_back(json{
-                    {"embedding", json_value(elem, "embedding", json::array())},
-                    {"index",     i++},
-                    {"object",    "embedding"}
-                });
-            }
-            root = format_embeddings_response_oaicompat(body, res_oai);
-        } else {
-            root = responses[0];
-        }
+        json root = is_openai
+            ? format_embeddings_response_oaicompat(body, responses)
+            : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
 
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
     //
     // Router
     //
@@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
     }
 
     // using embedded static files
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
-        return res.set_content("", "application/json; charset=utf-8");
-    });
     svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
     svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
     svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
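
The behavioral core of the hunk above is the unpacking rule: a multi-task result arrives with its per-prompt payloads under "results", while a single task returns the payload directly, and both shapes are normalized into an array. A minimal standalone sketch of that rule, assuming plain nlohmann/json and detached from the server's task types (an illustration, not the server code):

#include <nlohmann/json.hpp>
#include <cassert>

using json = nlohmann::json;

// Mirrors the handler above: a multi-task result wraps its per-task
// payloads under "results"; a single-task result is the payload itself.
static json normalize(const json & data) {
    if (data.count("results")) {
        return data.at("results");   // multi-task: already an array
    }
    return json::array({data});      // single task: wrap into an array
}

int main() {
    json single_task = { {"embedding", {0.1, 0.2}} };
    json multi_task  = { {"results", {
        json{ {"embedding", {0.1, 0.2}} },
        json{ {"embedding", {0.3, 0.4}} }
    }} };

    assert(normalize(single_task).size() == 1);
    assert(normalize(multi_task).size() == 2);
    return 0;
}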

@@ -529,6 +529,16 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
 }
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json data = json::array();
+    int i = 0;
+    for (auto & elem : embeddings) {
+        data.push_back(json{
+            {"embedding", json_value(elem, "embedding", json::array())},
+            {"index",     i++},
+            {"object",    "embedding"}
+        });
+    }
+
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
@@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
-        {"data", embeddings}
+        {"data", data}
     };
 
     return res;
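
To make the helper's output shape concrete, a standalone usage sketch. Assumptions are labeled: json_value here is a simplified stand-in for the server's helper of the same name, the model/usage fields are omitted, and the embedding values are made up.

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

// Simplified stand-in for the server's json_value(): key lookup with default.
static json json_value(const json & j, const std::string & key, const json & def) {
    return j.count(key) ? j.at(key) : def;
}

// Same wrapping as the helper added above, minus the model/usage fields.
static json wrap_embeddings(const json & embeddings) {
    json data = json::array();
    int i = 0;
    for (auto & elem : embeddings) {
        data.push_back(json{
            {"embedding", json_value(elem, "embedding", json::array())},
            {"index",     i++},
            {"object",    "embedding"}
        });
    }
    return json{ {"object", "list"}, {"data", data} };
}

int main() {
    // Two per-task payloads, as unpacked from a multi-task result.
    json embeddings = json::array({
        json{ {"embedding", {0.1, 0.2}} },
        json{ {"embedding", {0.3, 0.4}} }
    });
    // Prints a list object whose "data"[k]["index"] runs 0, 1, ...
    std::cout << wrap_embeddings(embeddings).dump(2) << "\n";
    return 0;
}

Note that, as in the hunk above, the real helper reports usage.prompt_tokens and usage.total_tokens as 0.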