diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1f..63583fb224 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -9,8 +9,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; default: throw std::runtime_error("Unknown chat format"); } @@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } + +// Case-insensitive find +static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { + auto it = std::search( + haystack.begin() + pos, haystack.end(), + needle.begin(), needle.end(), + [](char a, char b) { return std::tolower(a) == std::tolower(b); } + ); + return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); +} + +static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + const auto is_json_schema_provided = !inputs.json_schema.is_null(); + const auto is_grammar_provided = !inputs.grammar.empty(); + const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); + + // the logic requires potentially modifying the messages + auto tweaked_messages = inputs.messages; + + auto replace_json_schema_marker = [](json & messages) -> bool { + static std::string marker1 = "force json schema.\n"; + static std::string marker2 = "force json schema."; + + if (messages.empty() || messages.at(0).at("role") != "system") { + return false; + } + + std::string content = messages.at(0).at("content"); + + for (const auto & marker : {marker1, marker2}) { + const auto pos = ifind_string(content, marker); + if (pos != std::string::npos) { + content.replace(pos, marker.length(), ""); + // inject modified content back into the messages + messages.at(0).at("content") = content; + return true; + } + } + + return false; + }; + + // Lfm2 model does not natively work with json, but can generally understand the tools structure + // + // Example of the pytorch dialog structure: + // <|startoftext|><|im_start|>system + // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> + // <|im_start|>user + // What is the current status of candidate ID 12345?<|im_end|> + // <|im_start|>assistant + // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> + // <|im_start|>tool + // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> + // <|im_start|>assistant + // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> + // + // For the llama server compatibility with json tools semantic, + // the client can add "Follow json schema." line into the system message prompt to force the json output. + // + if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { + // server/utils.hpp prohibits that branch for the custom grammar anyways + throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); + } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { + LOG_INF("%s: Using tools to build a grammar\n", __func__); + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + + builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); + }); + // model has no concept of tool selection mode choice, + // if the system prompt rendered correctly it will produce a tool call + // the grammar goes inside the tool call body + data.grammar_lazy = true; + data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; + } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { + LOG_INF("%s: Using tools without json schema or grammar\n", __func__); + // output those tokens + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + } else if (is_json_schema_provided) { + LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else if (is_grammar_provided) { + LOG_INF("%s: Using provided grammar\n", __func__); + data.grammar = inputs.grammar; + } else { + LOG_INF("%s: Using content relying on the template\n", __func__); + } + + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); + LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); + + return data; +} + static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_apertus(tmpl, params); } + // LFM2 (w/ tools) + if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && + src.find("]<|tool_list_end|>") != std::string::npos) { + return common_chat_params_init_lfm2(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_APERTUS: common_chat_parse_apertus(builder); break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index f7b36ec711..50efb0d4e5 100644 --- a/common/chat.h +++ b/common/chat.h @@ -116,6 +116,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_APERTUS, + COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/models/templates/llama-cpp-lfm2.jinja b/models/templates/llama-cpp-lfm2.jinja new file mode 100644 index 0000000000..b7921120bc --- /dev/null +++ b/models/templates/llama-cpp-lfm2.jinja @@ -0,0 +1,37 @@ +{{- bos_token -}} +{%- set system_prompt = "" -%} +{%- set ns = namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "tool" -%} + {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 52e23b5ac6..4a8ba849b3 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -16,6 +16,7 @@ #include #include +#include #include using json = nlohmann::ordered_json; @@ -2138,6 +2139,154 @@ static void test_template_output_parsers() { assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); } + { + // LFM2 format tests + auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs { + common_chat_templates_inputs inputs; + inputs.messages = { + std::invoke([&]() -> common_chat_msg { + common_chat_msg msg; + msg.role = "system"; + msg.content = "force json schema.\n"; + return msg; + }), + message_user, + }; + inputs.tools = {special_function_tool}; + return inputs; + }); + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(true, params.grammar.empty()); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema); + assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format); + assert_equals(true, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(false, params.grammar.empty()); + } + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test single tool call with JSON format + common_chat_msg msg_single_tool_call; + msg_single_tool_call.role = "assistant"; + msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); + assert_msg_equals( + msg_single_tool_call, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with string argument + common_chat_msg msg_tool_call_string; + msg_tool_call_string.role = "assistant"; + msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_string, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with multiple arguments + common_chat_msg msg_multi_args; + msg_multi_args.role = "assistant"; + msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); + assert_msg_equals( + msg_multi_args, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test multiple tool calls in single array + common_chat_msg msg_multiple_tools; + msg_multiple_tools.role = "assistant"; + msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); + assert_msg_equals( + msg_multiple_tools, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content before + common_chat_msg msg_content_before_tool; + msg_content_before_tool.role = "assistant"; + msg_content_before_tool.content = "Let me check the weather for you."; + msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_before_tool, + common_chat_parse( + "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content after + common_chat_msg msg_content_after_tool; + msg_content_after_tool.role = "assistant"; + msg_content_after_tool.content = "Here's the result."; + msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_after_tool, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with newlines (common in LLM output) + common_chat_msg msg_tool_call_newlines; + msg_tool_call_newlines.role = "assistant"; + msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_newlines, + common_chat_parse( + "<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}] + // Unlike other formats, LFM2 template does not render tool calls in conversation history, + // so we don't use test_templates() for tool call generation. Instead, the parsing tests + // above verify edge cases and format variations for the tool call output format. + } }