Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)

Commit: Add --jinja and --chat-template-file flags
Author: ochafik
Makefile

@@ -1361,7 +1361,9 @@ llama-server: \
 	examples/server/httplib.h \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
+	common/chat-template.hpp \
 	common/json.hpp \
+	common/minja.hpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

common/CMakeLists.txt

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp

common/arg.cpp

@@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
+            if (!common_chat_verify_template(value, params.use_jinja)) {
                 throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                 ));
             }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        "set custom jinja chat template file (default: template taken from model's metadata)\n"
+        "if suffix/prefix are specified, template will be disabled\n"
+        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string chat_template;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(chat_template)
+            );
+            if (!common_chat_verify_template(chat_template, params.use_jinja)) {
+                throw std::runtime_error(string_format(
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
+                ));
+            }
+            params.chat_template = chat_template;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

common/common.cpp

@@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize(
     return result;
 }

-std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }

+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    return _common_token_to_piece(llama_get_model(ctx), token, special);
+}
+
 std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
@@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 // Chat template utils
 //

-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
@@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }

+static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
+    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
+    if (tlen > 0) {
+        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
+        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
+            return std::string(curr_tmpl_buf.data(), tlen);
+        }
+    }
+    return "";
+}
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
+    auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
+    std::string default_template_src = chat_template_override;
+    std::string tool_use_template_src = chat_template_override;
+    if (chat_template_override.empty()) {
+        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!tool_use_template_src.empty()) {
+            default_template_src = tool_use_template_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    return {
+        .default_template = { default_template_src, bos_token, eos_token },
+        .tool_use_template = tool_use_template_src.empty() ? std::nullopt
+            : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
+    };
+}
+
 //
 // KV cache utils
 //

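For reference, here is a minimal sketch of what the ChatML fallback above renders for a single user message. This snippet is not part of the commit; it only reuses headers and call signatures that appear in this diff (chat-template.hpp, the minja::chat_template(src, bos, eos) constructor, and apply(messages, tools, add_generation_prompt)).

    // Sketch: render the ChatML fallback template from llama_chat_templates_from_model().
    // Assumes common/chat-template.hpp and common/json.hpp as added/used by this commit.
    #include "chat-template.hpp"
    #include "json.hpp"

    #include <string>

    using json = nlohmann::json;

    // Same template source as the fallback in the diff above.
    static const char * CHATML_FALLBACK = R"(
        {%- for message in messages -%}
            {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
        {%- endfor -%}
        {%- if add_generation_prompt -%}
            {{- "<|im_start|>assistant\n" -}}
        {%- endif -%}
    )";

    static std::string render_chatml_fallback() {
        // Constructor and apply() are used the same way as in common_chat_verify_template().
        minja::chat_template tmpl(CHATML_FALLBACK, /* bos_token= */ "<s>", /* eos_token= */ "</s>");
        return tmpl.apply({{
            {"role", "user"},
            {"content", "Hello"},
        }}, json(), /* add_generation_prompt= */ true);
        // Expected result, with the {%- ... -%} markers trimming the indentation:
        //   "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
    }
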
common/common.h

@@ -3,6 +3,7 @@
 #pragma once

 #include "llama.h"
+#include "chat-template.hpp"

 #include <string>
 #include <vector>
@@ -324,6 +325,7 @@ struct common_params {
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
     std::string chat_template = "";                                                                         // NOLINT
+    bool use_jinja = false;                                                                                 // NOLINT
     bool enable_chat_template = true;

     std::vector<std::string> api_keys;
@@ -571,8 +573,8 @@ struct common_chat_msg {
     std::string content;
 };

-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+// Check if the template is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
@@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model,
 std::string common_chat_format_example(const struct llama_model * model,
         const std::string & tmpl);

+
+struct llama_chat_templates {
+    minja::chat_template default_template;
+    std::optional<minja::chat_template> tool_use_template;
+};
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
+
 //
 // KV cache utils
 //

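The declarations above are the whole public surface added to common.h. As a usage sketch (not part of this commit; it only combines the struct, the factory function, and the minja calls that appear elsewhere in this diff, plus nlohmann::json from common/json.hpp), a caller holding a llama_model pointer could select and render a template like this:

    #include "common.h"
    #include "json.hpp"

    #include <string>

    using json = nlohmann::json;

    // Render a chat prompt, preferring the tool-use template when the request
    // carries tools and the model metadata provides one.
    static std::string render_prompt(const llama_model * model, const json & messages, const json & tools) {
        // Empty override: fall back to tokenizer.chat_template(.tool_use) from the model metadata.
        llama_chat_templates tmpls = llama_chat_templates_from_model(model, "");

        const minja::chat_template & tmpl =
            (tools.is_array() && !tools.empty() && tmpls.tool_use_template)
                ? *tmpls.tool_use_template
                : tmpls.default_template;

        // apply(messages, tools, add_generation_prompt) renders the Jinja template via minja.
        return tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
    }

This is the same selection logic the server adopts below: the tool-use variant is only consulted when both the request and the model provide one.
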
examples/server/README.md

@@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
+| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |

 **Example-specific params**

examples/server/server.cpp

@@ -1623,16 +1623,36 @@ struct server_context {
         return true;
     }

-    bool validate_model_chat_template() const {
+    bool validate_model_chat_template(bool use_jinja) const {
+        llama_chat_message chat[] = {{"user", "test"}};
+
+        if (use_jinja) {
+            auto templates = llama_chat_templates_from_model(model, "");
+            try {
+                templates.default_template.apply({{
+                    {"role", "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                if (templates.tool_use_template) {
+                    templates.tool_use_template->apply({{
+                        {"role", "user"},
+                        {"content", "test"},
+                    }}, json(), true);
+                }
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+            }
+        } else {
             std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
             std::string template_key = "tokenizer.chat_template";
             int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
             if (res >= 0) {
-            llama_chat_message chat[] = {{"user", "test"}};
                 std::string tmpl = std::string(model_template.data(), model_template.size());
                 int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
                 return chat_res > 0;
             }
+        }
         return false;
     }

@@ -3476,15 +3496,30 @@ int main(int argc, char ** argv) {
         }
     };

-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    std::mutex chat_templates_mutex;
+    std::optional<llama_chat_templates> chat_templates;
+
+    auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & {
+        std::lock_guard<std::mutex> lock(chat_templates_mutex);
+        if (!chat_templates) {
+            chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template);
+        }
+        return *chat_templates;
+    };
+
+    const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
+        const auto & templates = get_chat_templates();
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               llama_get_chat_template(ctx_server.model) },
+            { "chat_template",               templates.default_template.source() },
             { "build_info",                  build_info },
         };
+        if (ctx_server.params_base.use_jinja && templates.tool_use_template) {
+            data["chat_template_tool_use"] = templates.tool_use_template->source();
+        }

         res_ok(res, data);
     };
@@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };

-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }

-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        auto body = json::parse(req.body);
+        const auto & templates = get_chat_templates();
+        const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+
         return handle_completions_generic(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) {

     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!ctx_server.validate_model_chat_template(params.use_jinja)) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }

examples/server/tests/unit/test_chat_completion.py

@@ -4,22 +4,24 @@ from utils import *

 server = ServerPreset.tinyllama2()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
     global server
+    server.jinja = jinja
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,
@@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library():

 @pytest.mark.parametrize("response_format,n_predicted,re_content", [
     ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
     ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
     ({"type": "json_object"}, 10, "(\\{|John)+"),
     ({"type": "sound"}, 0, None),

examples/server/tests/utils.py

@@ -68,8 +68,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
-    response_format: str | None = None
     lora_files: List[str] | None = None
+    chat_template_file: str | None = None
+    jinja: bool | None = None
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
@@ -154,6 +155,10 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.jinja:
+            server_args.append("--jinja")
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
         if self.api_key:

examples/server/utils.hpp

@@ -16,6 +16,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "minja.hpp"
+#include "chat-template.hpp"

 #include <random>
 #include <sstream>
@@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }

-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 2) {
-        return "";
-    } else {
-        std::vector<char> model_template(res + 1, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size() - 1);
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const minja::chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;

-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
+
+    if (has_tools) {
+        if (use_jinja) {
+            LOG_WRN("tools param is not fully supported yet\n");
+        } else {
+            throw std::runtime_error("tools param requires --jinja flag");
+        }
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -579,6 +578,13 @@ static json oaicompat_completion_params_parse(
         }
     }

+    // Apply chat template to the list of messages
+    if (use_jinja) {
+        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+    } else {
+        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+    }
+
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
@@ -594,7 +600,7 @@ static json oaicompat_completion_params_parse(
     }

     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    static const std::vector<std::string> unsupported_params { "tool_choice" };
     for (const auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);

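To tie the server pieces together, here is a hedged sketch of driving the reworked oaicompat_completion_params_parse() with a tools-bearing request. It is not part of the commit: the include path is an assumption, the tool definition is an illustrative placeholder, and the model/template arguments are assumed to come from the caller as they do in server.cpp.

    #include "utils.hpp"   // the server helpers shown above (assumed include path)

    // Sketch: build llama.cpp completion params from an OpenAI-style request with tools.
    static json parse_tool_request(const llama_model * model, const llama_chat_templates & templates) {
        json body = {
            {"messages", json::array({ json{{"role", "user"}, {"content", "What's the weather in Paris?"}} })},
            // Illustrative placeholder tool; any non-empty "tools" array triggers the gate below.
            {"tools",    json::array({ json{{"type", "function"}, {"function", json{{"name", "get_weather"}}}} })},
        };

        // Prefer the tool-use template, mirroring handle_chat_completions in server.cpp.
        const auto & tmpl = body.contains("tools") && templates.tool_use_template
            ? *templates.tool_use_template
            : templates.default_template;

        // With use_jinja == false this would throw "tools param requires --jinja flag";
        // with use_jinja == true it warns and renders messages (and tools) through the template.
        return oaicompat_completion_params_parse(model, body, tmpl, /* use_jinja= */ true);
    }
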
							
								
								
									
scripts/get_hf_chat_template.py (new executable file, 77 lines)

@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+'''
+  Fetches the Jinja chat template of a HuggingFace model.
+  If a model has multiple chat templates, you can specify the variant name.
+
+  Syntax:
+    ./scripts/get_hf_chat_template.py model_id [variant]
+
+  Examples:
+    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+'''
+
+import json
+import re
+import sys
+
+
+def get_hf_chat_template(model_id, variant=None):
+    try:
+        # Use huggingface_hub library if available.
+        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
+        from huggingface_hub import hf_hub_download
+        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
+            config_str = f.read()
+    except ImportError:
+        import requests
+        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
+        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
+        if response.status_code == 401:
+            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
+        response.raise_for_status()
+        config_str = response.text
+
+    try:
+        config = json.loads(config_str)
+    except json.JSONDecodeError:
+        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
+        # (Remove extra '}' near the end of the file)
+        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+    chat_template = config['chat_template']
+    if isinstance(chat_template, str):
+        return chat_template
+    else:
+        variants = {
+            ct['name']: ct['template']
+            for ct in chat_template
+        }
+
+        def format_variants():
+            return ', '.join(f'"{v}"' for v in variants.keys())
+
+        if variant is None:
+            if 'default' not in variants:
+                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
+            variant = 'default'
+            print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr)
+        elif variant not in variants:
+            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
+
+        return variants[variant]
+
+
+def main(args):
+    if len(args) < 1:
+        raise ValueError("Please provide a model ID and an optional variant name")
+    model_id = args[0]
+    variant = None if len(args) < 2 else args[1]
+
+    template = get_hf_chat_template(model_id, variant)
+    print(template, end=None)
+
+
+if __name__ == '__main__':
+    main(sys.argv[1:])

src/CMakeLists.txt

@@ -17,7 +17,7 @@ add_library(llama
             unicode-data.cpp
             )

-target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PUBLIC . ../include ../common)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump

 target_link_libraries(llama PUBLIC ggml)