mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Server: use llama_chat_apply_template (#5593)
* server: use llama_chat_apply_template * server: remove trailing space * server: fix format_chat * server: fix help message Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: fix formatted_chat --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		| @@ -15,13 +15,11 @@ | |||||||
| using json = nlohmann::json; | using json = nlohmann::json; | ||||||
|  |  | ||||||
| inline static json oaicompat_completion_params_parse( | inline static json oaicompat_completion_params_parse( | ||||||
|  |     const struct llama_model * model, | ||||||
|     const json &body, /* openai api json semantics */ |     const json &body, /* openai api json semantics */ | ||||||
|     const std::string &chat_template) |     const std::string &chat_template) | ||||||
| { | { | ||||||
|     json llama_params; |     json llama_params; | ||||||
|     std::string formatted_prompt = chat_template == "chatml" |  | ||||||
|         ? format_chatml(body["messages"])  // OpenAI 'messages' to chatml (with <|im_start|>,...) |  | ||||||
|         : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...) |  | ||||||
|  |  | ||||||
|     llama_params["__oaicompat"] = true; |     llama_params["__oaicompat"] = true; | ||||||
|  |  | ||||||
| @@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse( | |||||||
|     // https://platform.openai.com/docs/api-reference/chat/create |     // https://platform.openai.com/docs/api-reference/chat/create | ||||||
|     llama_sampling_params default_sparams; |     llama_sampling_params default_sparams; | ||||||
|     llama_params["model"]             = json_value(body, "model", std::string("unknown")); |     llama_params["model"]             = json_value(body, "model", std::string("unknown")); | ||||||
|     llama_params["prompt"]            = formatted_prompt; |     llama_params["prompt"]            = format_chat(model, chat_template, body["messages"]); | ||||||
|     llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false); |     llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false); | ||||||
|     llama_params["temperature"]       = json_value(body, "temperature", 0.0); |     llama_params["temperature"]       = json_value(body, "temperature", 0.0); | ||||||
|     llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k); |     llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k); | ||||||
|   | |||||||
| @@ -37,7 +37,7 @@ struct server_params | |||||||
|     std::string hostname = "127.0.0.1"; |     std::string hostname = "127.0.0.1"; | ||||||
|     std::vector<std::string> api_keys; |     std::vector<std::string> api_keys; | ||||||
|     std::string public_path = "examples/server/public"; |     std::string public_path = "examples/server/public"; | ||||||
|     std::string chat_template = "chatml"; |     std::string chat_template = ""; | ||||||
|     int32_t port = 8080; |     int32_t port = 8080; | ||||||
|     int32_t read_timeout = 600; |     int32_t read_timeout = 600; | ||||||
|     int32_t write_timeout = 600; |     int32_t write_timeout = 600; | ||||||
| @@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, | |||||||
|     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); |     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); | ||||||
|     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); |     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); | ||||||
|     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); |     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); | ||||||
|     printf("  --chat-template FORMAT_NAME"); |     printf("  --chat-template JINJA_TEMPLATE\n"); | ||||||
|     printf("                            set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str()); |     printf("                            set custom jinja chat template (default: template taken from model's metadata)\n"); | ||||||
|  |     printf("                            Note: only commonly used templates are accepted, since we don't have jinja parser\n"); | ||||||
|     printf("\n"); |     printf("\n"); | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -2389,13 +2390,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | |||||||
|                 invalid_param = true; |                 invalid_param = true; | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             std::string value(argv[i]); |             if (!verify_custom_template(argv[i])) { | ||||||
|             if (value != "chatml" && value != "llama2") { |                 fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); | ||||||
|                 fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str()); |                 fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); | ||||||
|                 invalid_param = true; |                 invalid_param = true; | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             sparams.chat_template = value; |             sparams.chat_template = argv[i]; | ||||||
|         } |         } | ||||||
|         else if (arg == "--override-kv") |         else if (arg == "--override-kv") | ||||||
|         { |         { | ||||||
| @@ -2913,7 +2914,7 @@ int main(int argc, char **argv) | |||||||
|                 if (!validate_api_key(req, res)) { |                 if (!validate_api_key(req, res)) { | ||||||
|                     return; |                     return; | ||||||
|                 } |                 } | ||||||
|                 json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template); |                 json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); | ||||||
|  |  | ||||||
|                 const int task_id = llama.queue_tasks.get_new_id(); |                 const int task_id = llama.queue_tasks.get_new_id(); | ||||||
|                 llama.queue_results.add_waiting_task_id(task_id); |                 llama.queue_results.add_waiting_task_id(task_id); | ||||||
|   | |||||||
| @@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v | |||||||
|         : default_value; |         : default_value; | ||||||
| } | } | ||||||
|  |  | ||||||
| inline std::string format_llama2(std::vector<json> messages) | // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid | ||||||
| { | inline bool verify_custom_template(const std::string & tmpl) { | ||||||
|     std::ostringstream output; |     llama_chat_message chat[] = {{"user", "test"}}; | ||||||
|     bool is_inside_turn = false; |     std::vector<char> buf(1); | ||||||
|  |     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size()); | ||||||
|     for (auto it = messages.begin(); it != messages.end(); ++it) { |     return res >= 0; | ||||||
|         if (!is_inside_turn) { |  | ||||||
|             output << "[INST] "; |  | ||||||
|         } |  | ||||||
|         std::string role    = json_value(*it, "role", std::string("user")); |  | ||||||
|         std::string content = json_value(*it, "content", std::string("")); |  | ||||||
|         if (role == "system") { |  | ||||||
|             output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n"; |  | ||||||
|             is_inside_turn = true; |  | ||||||
|         } else if (role == "user") { |  | ||||||
|             output << content << " [/INST]"; |  | ||||||
|             is_inside_turn = true; |  | ||||||
|         } else { |  | ||||||
|             output << " " << content << " </s>"; |  | ||||||
|             is_inside_turn = false; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     LOG_VERBOSE("format_llama2", {{"text", output.str()}}); |  | ||||||
|  |  | ||||||
|     return output.str(); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| inline std::string format_chatml(std::vector<json> messages) | // Format given chat. If tmpl is empty, we take the template from model metadata | ||||||
|  | inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) | ||||||
| { | { | ||||||
|     std::ostringstream chatml_msgs; |     size_t alloc_size = 0; | ||||||
|  |     // vector holding all allocated string to be passed to llama_chat_apply_template | ||||||
|  |     std::vector<std::string> str(messages.size() * 2); | ||||||
|  |     std::vector<llama_chat_message> chat(messages.size()); | ||||||
|  |  | ||||||
|     for (auto it = messages.begin(); it != messages.end(); ++it) { |     for (size_t i = 0; i < messages.size(); ++i) { | ||||||
|         chatml_msgs << "<|im_start|>" |         auto &curr_msg = messages[i]; | ||||||
|                     << json_value(*it, "role",    std::string("user")) << '\n'; |         str[i*2 + 0]    = json_value(curr_msg, "role",    std::string("")); | ||||||
|         chatml_msgs << json_value(*it, "content", std::string("")) |         str[i*2 + 1]    = json_value(curr_msg, "content", std::string("")); | ||||||
|                     << "<|im_end|>\n"; |         alloc_size     += str[i*2 + 1].length(); | ||||||
|  |         chat[i].role    = str[i*2 + 0].c_str(); | ||||||
|  |         chat[i].content = str[i*2 + 1].c_str(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     chatml_msgs << "<|im_start|>assistant" << '\n'; |     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); | ||||||
|  |     std::vector<char> buf(alloc_size * 2); | ||||||
|  |  | ||||||
|     LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}}); |     // run the first time to get the total output length | ||||||
|  |     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); | ||||||
|  |  | ||||||
|     return chatml_msgs.str(); |     // if it turns out that our buffer is too small, we resize it | ||||||
|  |     if ((size_t) res > buf.size()) { | ||||||
|  |         buf.resize(res); | ||||||
|  |         res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::string formatted_chat(buf.data(), res); | ||||||
|  |     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); | ||||||
|  |  | ||||||
|  |     return formatted_chat; | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // | ||||||
|   | |||||||
| @@ -12602,7 +12602,7 @@ LLAMA_API int32_t llama_chat_apply_template( | |||||||
|         // load template from model |         // load template from model | ||||||
|         std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes |         std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes | ||||||
|         std::string template_key = "tokenizer.chat_template"; |         std::string template_key = "tokenizer.chat_template"; | ||||||
|         int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size()); |         int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); | ||||||
|         if (res < 0) { |         if (res < 0) { | ||||||
|             // worst case: there is no information about template, we will use chatml by default |             // worst case: there is no information about template, we will use chatml by default | ||||||
|             curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal |             curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen