mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Refactor common_chat_* functions to accept minja template + use_jinja option
This commit is contained in:
		| @@ -74,6 +74,15 @@ | |||||||
| #endif | #endif | ||||||
| #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 | #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 | ||||||
|  |  | ||||||
|  | const char * LLAMA_CHATML_TEMPLATE = R"( | ||||||
|  |     {%- for message in messages -%} | ||||||
|  |         {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} | ||||||
|  |     {%- endfor -%} | ||||||
|  |     {%- if add_generation_prompt -%} | ||||||
|  |         {{- "<|im_start|>assistant\n" -}} | ||||||
|  |     {%- endif -%} | ||||||
|  | )"; | ||||||
|  |  | ||||||
| // | // | ||||||
| // CURL utils | // CURL utils | ||||||
| // | // | ||||||
| @@ -1748,56 +1757,56 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { | |||||||
|     return res >= 0; |     return res >= 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| std::string common_chat_apply_template(const struct llama_model * model, | std::string common_chat_apply_template( | ||||||
|         const std::string & tmpl, |         const llama_chat_template & tmpl, | ||||||
|         const std::vector<common_chat_msg> & msgs, |         const std::vector<common_chat_msg> & msgs, | ||||||
|         bool add_ass) { |         bool add_ass, | ||||||
|  |         bool use_jinja) { | ||||||
|  |     if (use_jinja) { | ||||||
|  |         auto messages = json::array(); | ||||||
|  |         for (const auto & msg : msgs) { | ||||||
|  |             messages.push_back({{"role", msg.role}, {"content", msg.content}}); | ||||||
|  |         } | ||||||
|  |         return tmpl.apply(messages, /* tools= */ json(), add_ass); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     int alloc_size = 0; |     int alloc_size = 0; | ||||||
|     bool fallback = false; // indicate if we must fallback to default chatml |  | ||||||
|     std::vector<llama_chat_message> chat; |     std::vector<llama_chat_message> chat; | ||||||
|     for (const auto & msg : msgs) { |     for (const auto & msg : msgs) { | ||||||
|         chat.push_back({msg.role.c_str(), msg.content.c_str()}); |         chat.push_back({msg.role.c_str(), msg.content.c_str()}); | ||||||
|         alloc_size += (msg.role.size() + msg.content.size()) * 1.25; |         alloc_size += (msg.role.size() + msg.content.size()) * 1.25; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str(); |  | ||||||
|     std::vector<char> buf(alloc_size); |     std::vector<char> buf(alloc_size); | ||||||
|  |  | ||||||
|     // run the first time to get the total output length |     // run the first time to get the total output length | ||||||
|     int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); |     int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); | ||||||
|  |  | ||||||
|     // error: chat template is not supported |     // error: chat template is not supported | ||||||
|     if (res < 0) { |     if (res < 0) { | ||||||
|         if (ptr_tmpl != nullptr) { |  | ||||||
|         // if the custom "tmpl" is not supported, we throw an error |         // if the custom "tmpl" is not supported, we throw an error | ||||||
|         // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() |         // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() | ||||||
|         throw std::runtime_error("this custom template is not supported"); |         throw std::runtime_error("this custom template is not supported"); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|         // If the built-in template is not supported, we default to chatml |  | ||||||
|         res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); |  | ||||||
|         fallback = true; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // if it turns out that our buffer is too small, we resize it |     // if it turns out that our buffer is too small, we resize it | ||||||
|     if ((size_t) res > buf.size()) { |     if ((size_t) res > buf.size()) { | ||||||
|         buf.resize(res); |         buf.resize(res); | ||||||
|         res = llama_chat_apply_template( |         res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); | ||||||
|             fallback ? "chatml" : ptr_tmpl, |  | ||||||
|             chat.data(), chat.size(), add_ass, buf.data(), buf.size()); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     std::string formatted_chat(buf.data(), res); |     std::string formatted_chat(buf.data(), res); | ||||||
|     return formatted_chat; |     return formatted_chat; | ||||||
| } | } | ||||||
|  |  | ||||||
| std::string common_chat_format_single(const struct llama_model * model, | std::string common_chat_format_single( | ||||||
|         const std::string & tmpl, |         const llama_chat_template & tmpl, | ||||||
|         const std::vector<common_chat_msg> & past_msg, |         const std::vector<common_chat_msg> & past_msg, | ||||||
|         const common_chat_msg & new_msg, |         const common_chat_msg & new_msg, | ||||||
|         bool add_ass) { |         bool add_ass, | ||||||
|  |         bool use_jinja) { | ||||||
|     std::ostringstream ss; |     std::ostringstream ss; | ||||||
|     auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); |     auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); | ||||||
|     std::vector<common_chat_msg> chat_new(past_msg); |     std::vector<common_chat_msg> chat_new(past_msg); | ||||||
|     // if the past_msg ends with a newline, we must preserve it in the formatted version |     // if the past_msg ends with a newline, we must preserve it in the formatted version | ||||||
|     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { |     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { | ||||||
| @@ -1805,29 +1814,20 @@ std::string common_chat_format_single(const struct llama_model * model, | |||||||
|     }; |     }; | ||||||
|     // format chat with new_msg |     // format chat with new_msg | ||||||
|     chat_new.push_back(new_msg); |     chat_new.push_back(new_msg); | ||||||
|     auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); |     auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); | ||||||
|     // get the diff part |     // get the diff part | ||||||
|     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); |     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); | ||||||
|     return ss.str(); |     return ss.str(); | ||||||
| } | } | ||||||
|  |  | ||||||
| std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) { | std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) { | ||||||
|     std::vector<common_chat_msg> msgs = { |     std::vector<common_chat_msg> msgs = { | ||||||
|         {"system",    "You are a helpful assistant"}, |         {"system",    "You are a helpful assistant"}, | ||||||
|         {"user",      "Hello"}, |         {"user",      "Hello"}, | ||||||
|         {"assistant", "Hi there"}, |         {"assistant", "Hi there"}, | ||||||
|         {"user",      "How are you?"}, |         {"user",      "How are you?"}, | ||||||
|     }; |     }; | ||||||
|     const auto add_generation_prompt = true; |     return common_chat_apply_template(tmpl, msgs, true, use_jinja); | ||||||
|     if (use_jinja) { |  | ||||||
|         auto messages = json::array(); |  | ||||||
|         for (const auto & msg : msgs) { |  | ||||||
|             messages.push_back({{"role", msg.role}, {"content", msg.content}}); |  | ||||||
|         } |  | ||||||
|         return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt); |  | ||||||
|     } else { |  | ||||||
|         return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt); |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) | llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) | ||||||
| @@ -1847,14 +1847,7 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model * | |||||||
|         if (!tool_use_template_src.empty()) { |         if (!tool_use_template_src.empty()) { | ||||||
|             default_template_src = tool_use_template_src; |             default_template_src = tool_use_template_src; | ||||||
|         } else { |         } else { | ||||||
|             default_template_src = R"( |             default_template_src = LLAMA_CHATML_TEMPLATE; | ||||||
|                 {%- for message in messages -%} |  | ||||||
|                     {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} |  | ||||||
|                 {%- endfor -%} |  | ||||||
|                 {%- if add_generation_prompt -%} |  | ||||||
|                     {{- "<|im_start|>assistant\n" -}} |  | ||||||
|                 {%- endif -%} |  | ||||||
|             )"; |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     return { |     return { | ||||||
|   | |||||||
| @@ -26,6 +26,8 @@ | |||||||
|  |  | ||||||
| #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" | #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" | ||||||
|  |  | ||||||
|  | extern const char * LLAMA_CHATML_TEMPLATE; | ||||||
|  |  | ||||||
| struct common_adapter_lora_info { | struct common_adapter_lora_info { | ||||||
|     std::string path; |     std::string path; | ||||||
|     float scale; |     float scale; | ||||||
| @@ -602,29 +604,32 @@ struct common_chat_msg { | |||||||
| // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid | // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid | ||||||
| bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); | bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); | ||||||
|  |  | ||||||
|  | typedef minja::chat_template llama_chat_template; | ||||||
|  |  | ||||||
| // CPP wrapper for llama_chat_apply_template | // CPP wrapper for llama_chat_apply_template | ||||||
| // If the built-in template is not supported, we default to chatml | // If the built-in template is not supported, we default to chatml | ||||||
| // If the custom "tmpl" is not supported, we throw an error | // If the custom "tmpl" is not supported, we throw an error | ||||||
| std::string common_chat_apply_template(const struct llama_model * model, | std::string common_chat_apply_template( | ||||||
|         const std::string & tmpl, |         const llama_chat_template & tmpl, | ||||||
|         const std::vector<common_chat_msg> & chat, |         const std::vector<common_chat_msg> & chat, | ||||||
|         bool add_ass); |         bool add_ass, | ||||||
|  |         bool use_jinja); | ||||||
|  |  | ||||||
| // Format single message, while taking into account the position of that message in chat history | // Format single message, while taking into account the position of that message in chat history | ||||||
| std::string common_chat_format_single(const struct llama_model * model, | std::string common_chat_format_single( | ||||||
|         const std::string & tmpl, |         const llama_chat_template & tmpl, | ||||||
|         const std::vector<common_chat_msg> & past_msg, |         const std::vector<common_chat_msg> & past_msg, | ||||||
|         const common_chat_msg & new_msg, |         const common_chat_msg & new_msg, | ||||||
|         bool add_ass); |         bool add_ass, | ||||||
|  |         bool use_jinja); | ||||||
|  |  | ||||||
| // Returns an example of formatted chat | // Returns an example of formatted chat | ||||||
| std::string common_chat_format_example(const struct llama_model * model, | std::string common_chat_format_example( | ||||||
|     const minja::chat_template & tmpl, bool use_jinja); |     const llama_chat_template & tmpl, bool use_jinja); | ||||||
|  |  | ||||||
|  |  | ||||||
| struct llama_chat_templates { | struct llama_chat_templates { | ||||||
|     minja::chat_template default_template; |     llama_chat_template default_template; | ||||||
|     std::optional<minja::chat_template> tool_use_template; |     std::optional<llama_chat_template> tool_use_template; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); | llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); | ||||||
|   | |||||||
| @@ -84,14 +84,6 @@ static void sigint_handler(int signo) { | |||||||
| } | } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) { |  | ||||||
|     common_chat_msg new_msg{role, content}; |  | ||||||
|     auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); |  | ||||||
|     chat_msgs.push_back({role, content}); |  | ||||||
|     LOG_DBG("formatted: '%s'\n", formatted.c_str()); |  | ||||||
|     return formatted; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| int main(int argc, char ** argv) { | int main(int argc, char ** argv) { | ||||||
|     common_params params; |     common_params params; | ||||||
|     g_params = ¶ms; |     g_params = ¶ms; | ||||||
| @@ -226,7 +218,7 @@ int main(int argc, char ** argv) { | |||||||
|     // print chat template example in conversation mode |     // print chat template example in conversation mode | ||||||
|     if (params.conversation_mode) { |     if (params.conversation_mode) { | ||||||
|         if (params.enable_chat_template) { |         if (params.enable_chat_template) { | ||||||
|             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str()); |             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str()); | ||||||
|         } else { |         } else { | ||||||
|             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); |             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); | ||||||
|         } |         } | ||||||
| @@ -270,10 +262,18 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     std::vector<llama_token> embd_inp; |     std::vector<llama_token> embd_inp; | ||||||
|  |  | ||||||
|  |     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { | ||||||
|  |         common_chat_msg new_msg{role, content}; | ||||||
|  |         auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja); | ||||||
|  |         chat_msgs.push_back({role, content}); | ||||||
|  |         LOG_DBG("formatted: '%s'\n", formatted.c_str()); | ||||||
|  |         return formatted; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         auto prompt = (params.conversation_mode && params.enable_chat_template) |         auto prompt = (params.conversation_mode && params.enable_chat_template) | ||||||
|             // format the system prompt in conversation mode (fallback to default if empty) |             // format the system prompt in conversation mode (fallback to default if empty) | ||||||
|             ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) |             ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) | ||||||
|             // otherwise use the prompt as is |             // otherwise use the prompt as is | ||||||
|             : params.prompt; |             : params.prompt; | ||||||
|         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { |         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { | ||||||
| @@ -780,7 +780,7 @@ int main(int argc, char ** argv) { | |||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     if (params.enable_chat_template) { |                     if (params.enable_chat_template) { | ||||||
|                         chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); |                         chat_add_and_format("assistant", assistant_ss.str()); | ||||||
|                     } |                     } | ||||||
|                     is_interacting = true; |                     is_interacting = true; | ||||||
|                     LOG("\n"); |                     LOG("\n"); | ||||||
| @@ -845,7 +845,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                     bool format_chat = params.conversation_mode && params.enable_chat_template; |                     bool format_chat = params.conversation_mode && params.enable_chat_template; | ||||||
|                     std::string user_inp = format_chat |                     std::string user_inp = format_chat | ||||||
|                         ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) |                         ? chat_add_and_format("user", std::move(buffer)) | ||||||
|                         : std::move(buffer); |                         : std::move(buffer); | ||||||
|                     // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) |                     // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) | ||||||
|                     const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); |                     const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); | ||||||
|   | |||||||
| @@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData & | |||||||
| } | } | ||||||
|  |  | ||||||
| // Function to apply the chat template and resize `formatted` if needed | // Function to apply the chat template and resize `formatted` if needed | ||||||
| static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { | static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { | ||||||
|     if (use_jinja) { |     if (use_jinja) { | ||||||
|         json messages = json::array(); |         json messages = json::array(); | ||||||
|         for (const auto & msg : llama_data.messages) { |         for (const auto & msg : llama_data.messages) { | ||||||
| @@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, | |||||||
| } | } | ||||||
|  |  | ||||||
| // Helper function to apply the chat template and handle errors | // Helper function to apply the chat template and handle errors | ||||||
| static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { | static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { | ||||||
|     const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); |     const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); | ||||||
|     if (new_len < 0) { |     if (new_len < 0) { | ||||||
|         printe("failed to apply the chat template\n"); |         printe("failed to apply the chat template\n"); | ||||||
|   | |||||||
| @@ -3869,7 +3869,7 @@ int main(int argc, char ** argv) { | |||||||
|         auto body = json::parse(req.body); |         auto body = json::parse(req.body); | ||||||
|         const auto & templates = get_chat_templates(); |         const auto & templates = get_chat_templates(); | ||||||
|         const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template; |         const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template; | ||||||
|         json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja); |         json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); | ||||||
|  |  | ||||||
|         return handle_completions_impl( |         return handle_completions_impl( | ||||||
|             SERVER_TASK_TYPE_COMPLETION, |             SERVER_TASK_TYPE_COMPLETION, | ||||||
| @@ -4288,7 +4288,7 @@ int main(int argc, char ** argv) { | |||||||
|     // print sample chat example to make it clear which template is used |     // print sample chat example to make it clear which template is used | ||||||
|     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, |     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, | ||||||
|         get_chat_templates().default_template.source().c_str(), |         get_chat_templates().default_template.source().c_str(), | ||||||
|         common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); |         common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); | ||||||
|  |  | ||||||
|     ctx_server.queue_tasks.on_new_task(std::bind( |     ctx_server.queue_tasks.on_new_task(std::bind( | ||||||
|                 &server_context::process_single_task, &ctx_server, std::placeholders::_1)); |                 &server_context::process_single_task, &ctx_server, std::placeholders::_1)); | ||||||
|   | |||||||
| @@ -351,7 +351,7 @@ static llama_tokens format_infill( | |||||||
| } | } | ||||||
|  |  | ||||||
| // Format given chat. If tmpl is empty, we take the template from model metadata | // Format given chat. If tmpl is empty, we take the template from model metadata | ||||||
| inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { | inline std::string format_chat(const llama_chat_template & tmpl, const std::vector<json> & messages) { | ||||||
|     std::vector<common_chat_msg> chat; |     std::vector<common_chat_msg> chat; | ||||||
|  |  | ||||||
|     for (size_t i = 0; i < messages.size(); ++i) { |     for (size_t i = 0; i < messages.size(); ++i) { | ||||||
| @@ -379,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri | |||||||
|         chat.push_back({role, content}); |         chat.push_back({role, content}); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true); |     const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); | ||||||
|     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); |     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); | ||||||
|  |  | ||||||
|     return formatted_chat; |     return formatted_chat; | ||||||
| @@ -579,9 +579,8 @@ static json oaicompat_completion_params_parse(const json & body) { | |||||||
| } | } | ||||||
|  |  | ||||||
| static json oaicompat_completion_params_parse( | static json oaicompat_completion_params_parse( | ||||||
|     const struct llama_model * model, |  | ||||||
|     const json & body, /* openai api json semantics */ |     const json & body, /* openai api json semantics */ | ||||||
|     const minja::chat_template & tmpl, |     const llama_chat_template & tmpl, | ||||||
|     bool use_jinja) |     bool use_jinja) | ||||||
| { | { | ||||||
|     json llama_params; |     json llama_params; | ||||||
| @@ -622,7 +621,7 @@ static json oaicompat_completion_params_parse( | |||||||
|     if (use_jinja) { |     if (use_jinja) { | ||||||
|         llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); |         llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); | ||||||
|     } else { |     } else { | ||||||
|         llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages")); |         llama_params["prompt"] = format_chat(tmpl, body.at("messages")); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Handle "n" field |     // Handle "n" field | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ | |||||||
| #include "llama.h" | #include "llama.h" | ||||||
| #include "common.h" | #include "common.h" | ||||||
| #include "chat-template.hpp" | #include "chat-template.hpp" | ||||||
|  | #include "llama-chat.h" | ||||||
|  |  | ||||||
| int main(void) { | int main(void) { | ||||||
|     std::vector<llama_chat_message> conversation { |     std::vector<llama_chat_message> conversation { | ||||||
| @@ -319,9 +320,10 @@ int main(void) { | |||||||
|     std::vector<common_chat_msg> chat2; |     std::vector<common_chat_msg> chat2; | ||||||
|     common_chat_msg sys_msg{"system", "You are a helpful assistant"}; |     common_chat_msg sys_msg{"system", "You are a helpful assistant"}; | ||||||
|  |  | ||||||
|     auto fmt_sys = [&](std::string tmpl) { |     auto fmt_sys = [&](std::string tmpl_str) { | ||||||
|         auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); |         minja::chat_template tmpl(tmpl_str, "", ""); | ||||||
|         printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); |         auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); | ||||||
|  |         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); | ||||||
|         printf("-------------------------\n"); |         printf("-------------------------\n"); | ||||||
|         return output; |         return output; | ||||||
|     }; |     }; | ||||||
| @@ -345,9 +347,10 @@ int main(void) { | |||||||
|     chat2.push_back({"assistant", "I am assistant"}); |     chat2.push_back({"assistant", "I am assistant"}); | ||||||
|     common_chat_msg new_msg{"user", "How are you"}; |     common_chat_msg new_msg{"user", "How are you"}; | ||||||
|  |  | ||||||
|     auto fmt_single = [&](std::string tmpl) { |     auto fmt_single = [&](std::string tmpl_str) { | ||||||
|         auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true); |         minja::chat_template tmpl(tmpl_str, "", ""); | ||||||
|         printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); |         auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); | ||||||
|  |         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); | ||||||
|         printf("-------------------------\n"); |         printf("-------------------------\n"); | ||||||
|         return output; |         return output; | ||||||
|     }; |     }; | ||||||
| @@ -362,5 +365,7 @@ int main(void) { | |||||||
|     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); |     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); | ||||||
|     assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); |     assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); | ||||||
|  |  | ||||||
|  |     assert(llm_chat_detect_template(LLAMA_CHATML_TEMPLATE) == LLM_CHAT_TEMPLATE_CHATML); | ||||||
|  |  | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 ochafik
					ochafik