	server : add llama2 chat template (#5425)
* server: add mistral chat template
* server: fix typo
* server: rename template mistral to llama2
* server: format_llama2: remove BOS
* server: validate "--chat-template" argument
* server: clean up using_chatml variable

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
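In practice the template is selected at server startup; for example (hypothetical model path), `./server -m models/llama-2-7b-chat.gguf --chat-template llama2` makes `/v1/chat/completions` render the incoming OpenAI-style `messages` with the Llama 2 `[INST]`/`<<SYS>>` markers instead of ChatML, which remains the default.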
@@ -15,9 +15,13 @@
 using json = nlohmann::json;

 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
+    std::string formatted_prompt = chat_template == "chatml"
+        ? format_chatml(body["messages"])  // OpenAI 'messages' to chatml (with <|im_start|>,...)
+        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)

     llama_params["__oaicompat"] = true;

@@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"]            = formatted_prompt;
     llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
     llama_params["temperature"]       = json_value(body, "temperature", 0.0);
     llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);

@@ -36,6 +36,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string chat_template = "chatml";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  --chat-template FORMAT_NAME");
+    printf("                            set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
     printf("\n");
 }

@@ -2290,6 +2293,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             log_set_target(stdout);
             LOG_INFO("logging to file is disabled.", {});
         }
+        else if (arg == "--chat-template")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            if (value != "chatml" && value != "llama2") {
+                fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = value;
+        }
         else if (arg == "--override-kv")
         {
             if (++i >= argc) {
@@ -2743,13 +2761,13 @@ int main(int argc, char **argv)


     // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = oaicompat_completion_params_parse(json::parse(req.body));
+                json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);

                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);

@@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }

+inline std::string format_llama2(std::vector<json> messages)
+{
+    std::ostringstream output;
+    bool is_inside_turn = false;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        if (!is_inside_turn) {
+            output << "[INST] ";
+        }
+        std::string role    = json_value(*it, "role", std::string("user"));
+        std::string content = json_value(*it, "content", std::string(""));
+        if (role == "system") {
+            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
+            is_inside_turn = true;
+        } else if (role == "user") {
+            output << content << " [/INST]";
+            is_inside_turn = true;
+        } else {
+            output << " " << content << " </s>";
+            is_inside_turn = false;
+        }
+    }
+
+    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+
+    return output.str();
+}
+
 inline std::string format_chatml(std::vector<json> messages)
 {
     std::ostringstream chatml_msgs;
@@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)

     chatml_msgs << "<|im_start|>assistant" << '\n';

+    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+
     return chatml_msgs.str();
 }
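For reference, below is a standalone sketch (not part of this commit) that mirrors the format_llama2 logic above using nlohmann::json, to show what the rendered prompt looks like for a short conversation. The helper name format_llama2_sketch and the sample messages are illustrative only.

// Standalone sketch mirroring the format_llama2 logic introduced by this commit.
// Requires nlohmann/json, which the server example already depends on.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static std::string format_llama2_sketch(const std::vector<json> & messages) {
    std::ostringstream output;
    bool is_inside_turn = false;

    for (const auto & msg : messages) {
        if (!is_inside_turn) {
            output << "[INST] ";
        }
        const std::string role    = msg.value("role", "user");
        const std::string content = msg.value("content", "");
        if (role == "system") {
            // system prompt is wrapped in <<SYS>> markers inside the first [INST] block
            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
            is_inside_turn = true;
        } else if (role == "user") {
            output << content << " [/INST]";
            is_inside_turn = true;
        } else {
            // assistant reply closes the turn with </s>
            output << " " << content << " </s>";
            is_inside_turn = false;
        }
    }
    return output.str();
}

int main() {
    const std::vector<json> messages = {
        {{"role", "system"},    {"content", "You are a helpful assistant."}},
        {{"role", "user"},      {"content", "Hello!"}},
        {{"role", "assistant"}, {"content", "Hi, how can I help?"}},
        {{"role", "user"},      {"content", "Tell me a joke."}},
    };
    // Prints:
    // [INST] <<SYS>>
    // You are a helpful assistant.
    // <<SYS>>
    //
    // Hello! [/INST] Hi, how can I help? </s>[INST] Tell me a joke. [/INST]
    std::cout << format_llama2_sketch(messages) << std::endl;
    return 0;
}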