	server : embeddings compatibility for OpenAI (#5190)
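This change adds an OpenAI-compatible /v1/embeddings endpoint to the server example. It introduces a helper, format_embeddings_response_oaicompat, that wraps computed embeddings in the OpenAI "list" envelope, and a route handler that accepts either a single "input" value or a batch (JSON array) of inputs, queuing one embedding task per input.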
```diff
@@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
 
     return std::vector<json>({ret});
 }
+
+inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
+{
+    json res =
+        json{
+            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+            {"object", "list"},
+            {"usage",
+                json{{"prompt_tokens", 0},
+                     {"total_tokens", 0}}},
+            {"data", embeddings}
+        };
+    return res;
+}
```
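For reference, here is a minimal standalone sketch of the envelope this helper produces, assuming nlohmann::json. The server's own json_value helper and the DEFAULT_OAICOMPAT_MODEL constant come from its headers, so they are stubbed below with a simplified reader and a placeholder model name:

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Simplified stand-in for the server's json_value helper:
// read a string key from the body, falling back to a default.
static std::string json_value(const json &body, const std::string &key, const std::string &def) {
    return body.contains(key) && body[key].is_string() ? body[key].get<std::string>() : def;
}

int main() {
    const json request    = {{"model", "text-embedding-ada-002"}, {"input", "hello"}};
    const json embeddings = json::array({
        {{"embedding", {0.1, 0.2, 0.3}}, {"index", 0}, {"object", "embedding"}}
    });

    // Same envelope format_embeddings_response_oaicompat builds;
    // "default-model" is a placeholder for DEFAULT_OAICOMPAT_MODEL.
    const json res = {
        {"model", json_value(request, "model", std::string("default-model"))},
        {"object", "list"},
        {"usage", {{"prompt_tokens", 0}, {"total_tokens", 0}}},
        {"data", embeddings}
    };
    std::cout << res.dump(2) << std::endl;
}
```

Note that the commit hard-codes prompt_tokens and total_tokens to 0, so token usage is not actually reported for embeddings.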
```diff
@@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });
 
+    svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+
+                json prompt;
+                if (body.count("input") != 0)
+                {
+                    prompt = body["input"];
+                    // batch
+                    if(prompt.is_array()) {
+                        json data = json::array();
+                        int i = 0;
+                        for (const json &elem : prompt) {
+                            const int task_id = llama.queue_tasks.get_new_id();
+                            llama.queue_results.add_waiting_task_id(task_id);
+                            llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1);
+
+                            // get the result
+                            task_result result = llama.queue_results.recv(task_id);
+                            llama.queue_results.remove_waiting_task_id(task_id);
+
+                            json embedding = json{
+                                {"embedding", json_value(result.result_json, "embedding", json::array())},
+                                {"index", i++},
+                                {"object", "embedding"}
+                            };
+                            data.push_back(embedding);
+                        }
+                        json result = format_embeddings_response_oaicompat(body, data);
+                        return res.set_content(result.dump(), "application/json; charset=utf-8");
+                    }
+                }
+                else
+                {
+                    prompt = "";
+                }
+
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                json data = json::array({json{
+                        {"embedding", json_value(result.result_json, "embedding", json::array())},
+                        {"index", 0},
+                        {"object", "embedding"}
+                    }}
+                );
+
+                json root = format_embeddings_response_oaicompat(body, data);
+
+                // send the result
+                return res.set_content(root.dump(), "application/json; charset=utf-8");
+            });
+
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     //     "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
```
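To exercise the new route, here is a hedged client-side sketch using cpp-httplib, the same library the server is built on. The host, port (8080 is the example server's default), and prompt strings are assumptions; the array "input" takes the batch path shown in the handler above:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>
#include "httplib.h"

using json = nlohmann::json;

int main() {
    // Assumes a llama.cpp server running locally on the default port.
    httplib::Client cli("localhost", 8080);

    // Batch request: an array "input" yields one result per element.
    const json body = {{"input", json::array({"first prompt", "second prompt"})}};

    auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
    if (res && res->status == 200) {
        const json root = json::parse(res->body);
        for (const auto &item : root["data"]) {
            std::cout << "index " << item["index"]
                      << ": " << item["embedding"].size() << " dims\n";
        }
    } else {
        std::cerr << "request failed\n";
    }
}
```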