mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Add --jinja and --chat-template-file flags
This commit is contained in:
		
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @@ -1361,7 +1361,9 @@ llama-server: \ | ||||
| 	examples/server/httplib.h \ | ||||
| 	examples/server/index.html.hpp \ | ||||
| 	examples/server/loading.html.hpp \ | ||||
| 	common/chat-template.hpp \ | ||||
| 	common/json.hpp \ | ||||
| 	common/minja.hpp \ | ||||
| 	$(OBJ_ALL) | ||||
| 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) | ||||
| 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) | ||||
|   | ||||
| @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC | ||||
|     arg.cpp | ||||
|     arg.h | ||||
|     base64.hpp | ||||
|     chat-template.hpp | ||||
|     common.cpp | ||||
|     common.h | ||||
|     console.cpp | ||||
| @@ -64,6 +65,7 @@ add_library(${TARGET} STATIC | ||||
|     json.hpp | ||||
|     log.cpp | ||||
|     log.h | ||||
|     minja.hpp | ||||
|     ngram-cache.cpp | ||||
|     ngram-cache.h | ||||
|     sampling.cpp | ||||
|   | ||||
| @@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"--jinja"}, | ||||
|         "use jinja template for chat (default: disabled)", | ||||
|         [](common_params & params) { | ||||
|             params.use_jinja = true; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"--chat-template"}, "JINJA_TEMPLATE", | ||||
|         string_format( | ||||
|             "set custom jinja chat template (default: template taken from model's metadata)\n" | ||||
|             "if suffix/prefix are specified, template will be disabled\n" | ||||
|             "only commonly used templates are accepted (unless --jinja is set before this flag):\n" | ||||
|             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() | ||||
|         ), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             if (!common_chat_verify_template(value)) { | ||||
|             if (!common_chat_verify_template(value, params.use_jinja)) { | ||||
|                 throw std::runtime_error(string_format( | ||||
|                     "error: the supplied chat template is not supported: %s\n" | ||||
|                     "note: llama.cpp does not use jinja parser, we only support commonly used templates\n", | ||||
|                     value.c_str() | ||||
|                     "error: the supplied chat template is not supported: %s%s\n", | ||||
|                     value.c_str(), | ||||
|                     params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" | ||||
|                 )); | ||||
|             } | ||||
|             params.chat_template = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); | ||||
|     add_opt(common_arg( | ||||
|         {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", | ||||
|         "set custom jinja chat template file (default: template taken from model's metadata)\n" | ||||
|         "if suffix/prefix are specified, template will be disabled\n" | ||||
|         "only commonly used templates are accepted (unless --jinja is set before this flag):\n" | ||||
|         "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             std::ifstream file(value); | ||||
|             if (!file) { | ||||
|                 throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); | ||||
|             } | ||||
|             std::string chat_template; | ||||
|             std::copy( | ||||
|                 std::istreambuf_iterator<char>(file), | ||||
|                 std::istreambuf_iterator<char>(), | ||||
|                 std::back_inserter(chat_template) | ||||
|             ); | ||||
|             if (!common_chat_verify_template(chat_template, params.use_jinja)) { | ||||
|                 throw std::runtime_error(string_format( | ||||
|                     "error: the supplied chat template is not supported: %s%s\n", | ||||
|                     value.c_str(), | ||||
|                     params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" | ||||
|                 )); | ||||
|             } | ||||
|             params.chat_template = chat_template; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); | ||||
|     add_opt(common_arg( | ||||
|         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", | ||||
|         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), | ||||
|   | ||||
| @@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize( | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { | ||||
| static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) { | ||||
|     std::string piece; | ||||
|     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n' | ||||
|     const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); | ||||
|     const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); | ||||
|     if (n_chars < 0) { | ||||
|         piece.resize(-n_chars); | ||||
|         int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); | ||||
|         int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); | ||||
|         GGML_ASSERT(check == -n_chars); | ||||
|     } | ||||
|     else { | ||||
| @@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token | ||||
|     return piece; | ||||
| } | ||||
|  | ||||
| std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { | ||||
|     return _common_token_to_piece(llama_get_model(ctx), token, special); | ||||
| } | ||||
|  | ||||
| std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) { | ||||
|     std::string text; | ||||
|     text.resize(std::max(text.capacity(), tokens.size())); | ||||
| @@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token | ||||
| // Chat template utils | ||||
| // | ||||
|  | ||||
| bool common_chat_verify_template(const std::string & tmpl) { | ||||
| bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { | ||||
|     if (use_jinja) { | ||||
|         try { | ||||
|             auto chat_template = minja::chat_template(tmpl, "<s>", "</s>"); | ||||
|             chat_template.apply({{ | ||||
|                 {"role", "user"}, | ||||
|                 {"content", "test"}, | ||||
|             }}, json(), true); | ||||
|             return true; | ||||
|         } catch (const std::exception & e) { | ||||
|             LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     llama_chat_message chat[] = {{"user", "test"}}; | ||||
|     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); | ||||
|     return res >= 0; | ||||
| @@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model, | ||||
|     return common_chat_apply_template(model, tmpl, msgs, true); | ||||
| } | ||||
|  | ||||
| static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) { | ||||
|     int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0); | ||||
|     if (tlen > 0) { | ||||
|         std::vector<char> curr_tmpl_buf(tlen + 1, 0); | ||||
|         if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { | ||||
|             return std::string(curr_tmpl_buf.data(), tlen); | ||||
|         } | ||||
|     } | ||||
|     return ""; | ||||
| } | ||||
|  | ||||
| llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) | ||||
| { | ||||
|     auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true); | ||||
|     auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true); | ||||
|     std::string default_template_src = chat_template_override; | ||||
|     std::string tool_use_template_src = chat_template_override; | ||||
|     if (chat_template_override.empty()) { | ||||
|         default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template"); | ||||
|         tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use"); | ||||
|     } | ||||
|     if (default_template_src.empty() || default_template_src == "chatml") { | ||||
|         if (!tool_use_template_src.empty()) { | ||||
|             default_template_src = tool_use_template_src; | ||||
|         } else { | ||||
|             default_template_src = R"( | ||||
|                 {%- for message in messages -%} | ||||
|                     {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} | ||||
|                 {%- endfor -%} | ||||
|                 {%- if add_generation_prompt -%} | ||||
|                     {{- "<|im_start|>assistant\n" -}} | ||||
|                 {%- endif -%} | ||||
|             )"; | ||||
|         } | ||||
|     } | ||||
|     return { | ||||
|         .default_template = { default_template_src, bos_token, eos_token }, | ||||
|         .tool_use_template = tool_use_template_src.empty() ? std::nullopt | ||||
|             : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }), | ||||
|     }; | ||||
| } | ||||
|  | ||||
| // | ||||
| // KV cache utils | ||||
| // | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "llama.h" | ||||
| #include "chat-template.hpp" | ||||
|  | ||||
| #include <string> | ||||
| #include <vector> | ||||
| @@ -324,6 +325,7 @@ struct common_params { | ||||
|     std::string hostname      = "127.0.0.1"; | ||||
|     std::string public_path   = "";                                                                         // NOLINT | ||||
|     std::string chat_template = "";                                                                         // NOLINT | ||||
|     bool use_jinja = false;                                                                                 // NOLINT | ||||
|     bool enable_chat_template = true; | ||||
|  | ||||
|     std::vector<std::string> api_keys; | ||||
| @@ -571,8 +573,8 @@ struct common_chat_msg { | ||||
|     std::string content; | ||||
| }; | ||||
|  | ||||
| // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid | ||||
| bool common_chat_verify_template(const std::string & tmpl); | ||||
| // Check if the template is supported or not. Returns true if it's valid | ||||
| bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); | ||||
|  | ||||
| // CPP wrapper for llama_chat_apply_template | ||||
| // If the built-in template is not supported, we default to chatml | ||||
| @@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model, | ||||
| std::string common_chat_format_example(const struct llama_model * model, | ||||
|         const std::string & tmpl); | ||||
|  | ||||
|  | ||||
| struct llama_chat_templates { | ||||
|     minja::chat_template default_template; | ||||
|     std::optional<minja::chat_template> tool_use_template; | ||||
| }; | ||||
|  | ||||
| llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); | ||||
|  | ||||
| // | ||||
| // KV cache utils | ||||
| // | ||||
|   | ||||
| @@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co | ||||
| | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | ||||
| | `--grammar-file FNAME` | file to read grammar from | | ||||
| | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | | ||||
|  | ||||
| | `--jinja` | Enable experimental Jinja templating engine (needed for tool use) | | ||||
|  | ||||
| **Example-specific params** | ||||
|  | ||||
|   | ||||
| @@ -1623,15 +1623,35 @@ struct server_context { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     bool validate_model_chat_template() const { | ||||
|         std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes | ||||
|         std::string template_key = "tokenizer.chat_template"; | ||||
|         int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); | ||||
|         if (res >= 0) { | ||||
|             llama_chat_message chat[] = {{"user", "test"}}; | ||||
|             std::string tmpl = std::string(model_template.data(), model_template.size()); | ||||
|             int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); | ||||
|             return chat_res > 0; | ||||
|     bool validate_model_chat_template(bool use_jinja) const { | ||||
|         llama_chat_message chat[] = {{"user", "test"}}; | ||||
|  | ||||
|         if (use_jinja) { | ||||
|             auto templates = llama_chat_templates_from_model(model, ""); | ||||
|             try { | ||||
|                 templates.default_template.apply({{ | ||||
|                     {"role", "user"}, | ||||
|                     {"content", "test"}, | ||||
|                 }}, json(), true); | ||||
|                 if (templates.tool_use_template) { | ||||
|                     templates.tool_use_template->apply({{ | ||||
|                         {"role", "user"}, | ||||
|                         {"content", "test"}, | ||||
|                     }}, json(), true); | ||||
|                 } | ||||
|                 return true; | ||||
|             } catch (const std::exception & e) { | ||||
|                 SRV_ERR("failed to apply template: %s\n", e.what()); | ||||
|             } | ||||
|         } else { | ||||
|             std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes | ||||
|             std::string template_key = "tokenizer.chat_template"; | ||||
|             int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); | ||||
|             if (res >= 0) { | ||||
|                 std::string tmpl = std::string(model_template.data(), model_template.size()); | ||||
|                 int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); | ||||
|                 return chat_res > 0; | ||||
|             } | ||||
|         } | ||||
|         return false; | ||||
|     } | ||||
| @@ -3476,15 +3496,30 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { | ||||
|     std::mutex chat_templates_mutex; | ||||
|     std::optional<llama_chat_templates> chat_templates; | ||||
|  | ||||
|     auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & { | ||||
|         std::lock_guard<std::mutex> lock(chat_templates_mutex); | ||||
|         if (!chat_templates) { | ||||
|             chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); | ||||
|         } | ||||
|         return *chat_templates; | ||||
|     }; | ||||
|  | ||||
|     const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) { | ||||
|         // this endpoint is publicly available, please only return what is safe to be exposed | ||||
|         const auto & templates = get_chat_templates(); | ||||
|         json data = { | ||||
|             { "default_generation_settings", ctx_server.default_generation_settings_for_props }, | ||||
|             { "total_slots",                 ctx_server.params_base.n_parallel }, | ||||
|             { "model_path",                  ctx_server.params_base.model }, | ||||
|             { "chat_template",               llama_get_chat_template(ctx_server.model) }, | ||||
|             { "chat_template",               templates.default_template.source() }, | ||||
|             { "build_info",                  build_info }, | ||||
|         }; | ||||
|         if (ctx_server.params_base.use_jinja && templates.tool_use_template) { | ||||
|             data["chat_template_tool_use"] = templates.tool_use_template->source(); | ||||
|         } | ||||
|  | ||||
|         res_ok(res, data); | ||||
|     }; | ||||
| @@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) { | ||||
|         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res); | ||||
|     }; | ||||
|  | ||||
|     const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { | ||||
|     const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) { | ||||
|         if (ctx_server.params_base.embedding) { | ||||
|             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); | ||||
|         auto body = json::parse(req.body); | ||||
|         const auto & templates = get_chat_templates(); | ||||
|         const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template; | ||||
|         json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja); | ||||
|  | ||||
|         return handle_completions_generic( | ||||
|             SERVER_TASK_TYPE_COMPLETION, | ||||
|             data, | ||||
| @@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // if a custom chat template is not supplied, we will use the one that comes with the model (if any) | ||||
|     if (params.chat_template.empty()) { | ||||
|         if (!ctx_server.validate_model_chat_template()) { | ||||
|         if (!ctx_server.validate_model_chat_template(params.use_jinja)) { | ||||
|             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); | ||||
|             params.chat_template = "chatml"; | ||||
|         } | ||||
|   | ||||
| @@ -4,22 +4,24 @@ from utils import * | ||||
|  | ||||
| server = ServerPreset.tinyllama2() | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="module", autouse=True) | ||||
| @pytest.fixture(autouse=True) | ||||
| def create_server(): | ||||
|     global server | ||||
|     server = ServerPreset.tinyllama2() | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", | ||||
|     "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja", | ||||
|     [ | ||||
|         (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), | ||||
|         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), | ||||
|         (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False), | ||||
|         (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True), | ||||
|         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False), | ||||
|         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True), | ||||
|     ] | ||||
| ) | ||||
| def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): | ||||
| def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja): | ||||
|     global server | ||||
|     server.jinja = jinja | ||||
|     server.start() | ||||
|     res = server.make_request("POST", "/chat/completions", data={ | ||||
|         "model": model, | ||||
| @@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library(): | ||||
|  | ||||
| @pytest.mark.parametrize("response_format,n_predicted,re_content", [ | ||||
|     ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), | ||||
|     ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""), | ||||
|     ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), | ||||
|     ({"type": "json_object"}, 10, "(\\{|John)+"), | ||||
|     ({"type": "sound"}, 0, None), | ||||
|   | ||||
| @@ -68,8 +68,9 @@ class ServerProcess: | ||||
|     pooling: str | None = None | ||||
|     draft: int | None = None | ||||
|     api_key: str | None = None | ||||
|     response_format: str | None = None | ||||
|     lora_files: List[str] | None = None | ||||
|     chat_template_file: str | None = None | ||||
|     jinja: bool | None = None | ||||
|     disable_ctx_shift: int | None = False | ||||
|     draft_min: int | None = None | ||||
|     draft_max: int | None = None | ||||
| @@ -154,6 +155,10 @@ class ServerProcess: | ||||
|         if self.lora_files: | ||||
|             for lora_file in self.lora_files: | ||||
|                 server_args.extend(["--lora", lora_file]) | ||||
|         if self.chat_template_file: | ||||
|             server_args.extend(["--chat-template-file", self.chat_template_file]) | ||||
|         if self.jinja: | ||||
|             server_args.append("--jinja") | ||||
|         if self.disable_ctx_shift: | ||||
|             server_args.extend(["--no-context-shift"]) | ||||
|         if self.api_key: | ||||
|   | ||||
| @@ -16,6 +16,8 @@ | ||||
| // Change JSON_ASSERT from assert() to GGML_ASSERT: | ||||
| #define JSON_ASSERT GGML_ASSERT | ||||
| #include "json.hpp" | ||||
| #include "minja.hpp" | ||||
| #include "chat-template.hpp" | ||||
|  | ||||
| #include <random> | ||||
| #include <sstream> | ||||
| @@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri | ||||
|     return formatted_chat; | ||||
| } | ||||
|  | ||||
| static std::string llama_get_chat_template(const struct llama_model * model) { | ||||
|     std::string template_key = "tokenizer.chat_template"; | ||||
|     // call with NULL buffer to get the total size of the string | ||||
|     int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0); | ||||
|     if (res < 2) { | ||||
|         return ""; | ||||
|     } else { | ||||
|         std::vector<char> model_template(res + 1, 0); | ||||
|         llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); | ||||
|         return std::string(model_template.data(), model_template.size() - 1); | ||||
|     } | ||||
| } | ||||
|  | ||||
| // | ||||
| // base64 utils (TODO: move to common in the future) | ||||
| // | ||||
| @@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons | ||||
| static json oaicompat_completion_params_parse( | ||||
|     const struct llama_model * model, | ||||
|     const json & body, /* openai api json semantics */ | ||||
|     const std::string & chat_template) { | ||||
|     const minja::chat_template & tmpl, | ||||
|     bool use_jinja) | ||||
| { | ||||
|     json llama_params; | ||||
|  | ||||
|     // Apply chat template to the list of messages | ||||
|     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); | ||||
|     auto tools = json_value(body, "tools", json()); | ||||
|     auto has_tools = tools.is_array() && !tools.empty(); | ||||
|  | ||||
|     if (has_tools) { | ||||
|         if (use_jinja) { | ||||
|             LOG_WRN("tools param is not fully supported yet\n"); | ||||
|         } else { | ||||
|             throw std::runtime_error("tools param requires --jinja flag"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Handle "stop" field | ||||
|     if (body.contains("stop") && body.at("stop").is_string()) { | ||||
| @@ -579,6 +578,13 @@ static json oaicompat_completion_params_parse( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Apply chat template to the list of messages | ||||
|     if (use_jinja) { | ||||
|         llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); | ||||
|     } else { | ||||
|         llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages")); | ||||
|     } | ||||
|  | ||||
|     // Handle "n" field | ||||
|     int n_choices = json_value(body, "n", 1); | ||||
|     if (n_choices != 1) { | ||||
| @@ -594,7 +600,7 @@ static json oaicompat_completion_params_parse( | ||||
|     } | ||||
|  | ||||
|     // Params supported by OAI but unsupported by llama.cpp | ||||
|     static const std::vector<std::string> unsupported_params { "tools", "tool_choice" }; | ||||
|     static const std::vector<std::string> unsupported_params { "tool_choice" }; | ||||
|     for (const auto & param : unsupported_params) { | ||||
|         if (body.contains(param)) { | ||||
|             throw std::runtime_error("Unsupported param: " + param); | ||||
|   | ||||
							
								
								
									
										77
									
								
								scripts/get_hf_chat_template.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										77
									
								
								scripts/get_hf_chat_template.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,77 @@ | ||||
| #!/usr/bin/env python | ||||
| ''' | ||||
|   Fetches the Jinja chat template of a HuggingFace model. | ||||
|   If a model has multiple chat templates, you can specify the variant name. | ||||
|  | ||||
|   Syntax: | ||||
|     ./scripts/get_hf_chat_template.py model_id [variant] | ||||
|  | ||||
|   Examples: | ||||
|     ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct | ||||
|     ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use | ||||
|     ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct | ||||
| ''' | ||||
|  | ||||
| import json | ||||
| import re | ||||
| import sys | ||||
|  | ||||
|  | ||||
| def get_hf_chat_template(model_id, variant=None): | ||||
|     try: | ||||
|         # Use huggingface_hub library if available. | ||||
|         # Allows access to gated models if the user has access and ran `huggingface-cli login`. | ||||
|         from huggingface_hub import hf_hub_download | ||||
|         with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: | ||||
|             config_str = f.read() | ||||
|     except ImportError: | ||||
|         import requests | ||||
|         assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}" | ||||
|         response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json") | ||||
|         if response.status_code == 401: | ||||
|             raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`') | ||||
|         response.raise_for_status() | ||||
|         config_str = response.text | ||||
|  | ||||
|     try: | ||||
|         config = json.loads(config_str) | ||||
|     except json.JSONDecodeError: | ||||
|         # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json | ||||
|         # (Remove extra '}' near the end of the file) | ||||
|         config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) | ||||
|  | ||||
|     chat_template = config['chat_template'] | ||||
|     if isinstance(chat_template, str): | ||||
|         return chat_template | ||||
|     else: | ||||
|         variants = { | ||||
|             ct['name']: ct['template'] | ||||
|             for ct in chat_template | ||||
|         } | ||||
|  | ||||
|         def format_variants(): | ||||
|             return ', '.join(f'"{v}"' for v in variants.keys()) | ||||
|  | ||||
|         if variant is None: | ||||
|             if 'default' not in variants: | ||||
|                 raise Exception(f'Please specify a chat template variant (one of {format_variants()})') | ||||
|             variant = 'default' | ||||
|             print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr) | ||||
|         elif variant not in variants: | ||||
|             raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})") | ||||
|  | ||||
|         return variants[variant] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     if len(args) < 1: | ||||
|         raise ValueError("Please provide a model ID and an optional variant name") | ||||
|     model_id = args[0] | ||||
|     variant = None if len(args) < 2 else args[1] | ||||
|  | ||||
|     template = get_hf_chat_template(model_id, variant) | ||||
|     print(template, end=None) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main(sys.argv[1:]) | ||||
| @@ -17,7 +17,7 @@ add_library(llama | ||||
|             unicode-data.cpp | ||||
|             ) | ||||
|  | ||||
| target_include_directories(llama PUBLIC . ../include) | ||||
| target_include_directories(llama PUBLIC . ../include ../common) | ||||
| target_compile_features   (llama PUBLIC cxx_std_17) # don't bump | ||||
|  | ||||
| target_link_libraries(llama PUBLIC ggml) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 ochafik
					ochafik