llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git)

chat: Add LFM2 tool handling (#16763)

* Add LFM2 tool handling
* fmt
* Apply suggestion from @ykhrustalev
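
In short, this change adds a dedicated LFM2 chat handler: when tools are supplied and the client opts in by placing a "force json schema." marker in the system message (matched case-insensitively and stripped before templating), common_chat_params_init_lfm2 builds a lazy grammar from the tool parameter schemas so generation is constrained to a <|tool_call_start|>[...]<|tool_call_end|> block, and common_chat_parse_lfm2 maps such blocks back to tool calls. A minimal sketch of a request that opts in follows; the payload itself is illustrative, only the marker wording and the tool shape come from the diff.

// Illustrative request only (not part of the commit): a client opting into
// the new JSON-tools path. The system message carries the marker, which
// common_chat_params_init_lfm2 finds case-insensitively and strips.
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    json messages = json::array();
    messages.push_back({
        {"role", "system"},
        {"content", "You are a recruiting assistant. Force json schema."},
    });
    messages.push_back({
        {"role", "user"},
        {"content", "What is the current status of candidate ID 12345?"},
    });

    json tools = json::array();
    tools.push_back({
        {"type", "function"},
        {"function", {
            {"name", "get_candidate_status"},
            {"description", "Retrieves the current status of a candidate"},
            {"parameters", {
                {"type", "object"},
                {"properties", {{"candidate_id", {{"type", "string"}}}}},
                {"required", json::array({"candidate_id"})},
            }},
        }},
    });

    // With this request the handler constrains decoding to a block like:
    //   <|tool_call_start|>[{"name": "get_candidate_status",
    //                        "arguments": {"candidate_id": "12345"}}]<|tool_call_end|>
    (void)messages;
    (void)tools;
    return 0;
}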
common/chat.cpp | 198 ++++++++++++++++++++++++++++++++++++++++++++++++++++
common/chat.h   |   1 +
2 files changed, 199 insertions(+)
common/chat.cpp:

@@ -9,8 +9,11 @@
 #include <minja/chat-template.hpp>
 #include <minja/minja.hpp>
 
+#include <algorithm>
 #include <cstdio>
+#include <cctype>
 #include <exception>
 #include <functional>
+#include <iostream>
 #include <optional>
 #include <stdexcept>
@@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
         case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
         case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     return data;
 }
 
+
+// Case-insensitive find
+static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
+    auto it = std::search(
+        haystack.begin() + pos, haystack.end(),
+        needle.begin(), needle.end(),
+        [](char a, char b) { return std::tolower(a) == std::tolower(b); }
+    );
+    return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
+}
+
+static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    const auto is_json_schema_provided = !inputs.json_schema.is_null();
+    const auto is_grammar_provided = !inputs.grammar.empty();
+    const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
+
+    // the logic below may need to modify the messages
+    auto tweaked_messages = inputs.messages;
+
+    auto replace_json_schema_marker = [](json & messages) -> bool {
+        static std::string marker1 = "force json schema.\n";
+        static std::string marker2 = "force json schema.";
+
+        if (messages.empty() || messages.at(0).at("role") != "system") {
+            return false;
+        }
+
+        std::string content = messages.at(0).at("content");
+
+        for (const auto & marker : {marker1, marker2}) {
+            const auto pos = ifind_string(content, marker);
+            if (pos != std::string::npos) {
+                content.replace(pos, marker.length(), "");
+                // inject the modified content back into the messages
+                messages.at(0).at("content") = content;
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    // The LFM2 model does not natively work with JSON, but it can generally understand the tools structure.
+    //
+    // Example of the PyTorch dialog structure:
+    //     <|startoftext|><|im_start|>system
+    //     List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
+    //     <|im_start|>user
+    //     What is the current status of candidate ID 12345?<|im_end|>
+    //     <|im_start|>assistant
+    //     <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
+    //     <|im_start|>tool
+    //     <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
+    //     <|im_start|>assistant
+    //     The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
+    //
+    // For compatibility with the llama server's JSON tools semantics,
+    // the client can add a "force json schema." line to the system message prompt to force JSON output.
+    //
+    if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
+        // server/utils.hpp prohibits that branch for a custom grammar anyway
+        throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
+    } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
+        LOG_INF("%s: Using tools to build a grammar\n", __func__);
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", function.at("name")},
+                        }},
+                        {"arguments", function.at("parameters")},
+                    }},
+                    {"required", json::array({"name", "arguments", "id"})},
+                });
+            });
+            auto schema = json {
+                {"type", "array"},
+                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+                {"minItems", 1},
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+
+            builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
+        });
+        // the model has no concept of a tool-choice mode;
+        // if the system prompt rendered correctly, it will produce a tool call,
+        // and the grammar applies inside the tool call body
+        data.grammar_lazy = true;
+        data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+        data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
+    } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
+        LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
+        // keep the tool-call marker tokens in the output
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+    } else if (is_json_schema_provided) {
+        LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
+        data.grammar = json_schema_to_grammar(inputs.json_schema);
+    } else if (is_grammar_provided) {
+        LOG_INF("%s: Using provided grammar\n", __func__);
+        data.grammar = inputs.grammar;
+    } else {
+        LOG_INF("%s: Using content relying on the template\n", __func__);
+    }
+
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
+    LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+
+static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
+    static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
+    static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
+
+    // Loop through all tool calls
+    while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
+        builder.move_to(res->groups[0].end);
+
+        // Parse JSON array format: [{"name": "...", "arguments": {...}}]
+        auto tool_calls_data = builder.consume_json();
+
+        // Consume end marker
+        builder.consume_spaces();
+        if (!builder.try_consume_regex(tool_call_end_regex)) {
+            throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
+        }
+
+        // Process each tool call in the array
+        if (tool_calls_data.json.is_array()) {
+            for (const auto & tool_call : tool_calls_data.json) {
+                if (!tool_call.is_object()) {
+                    throw common_chat_msg_partial_exception("Tool call must be an object");
+                }
+
+                if (!tool_call.contains("name")) {
+                    throw common_chat_msg_partial_exception("Tool call missing 'name' field");
+                }
+
+                std::string function_name = tool_call.at("name");
+                std::string arguments = "{}";
+
+                if (tool_call.contains("arguments")) {
+                    if (tool_call.at("arguments").is_object()) {
+                        arguments = tool_call.at("arguments").dump();
+                    } else if (tool_call.at("arguments").is_string()) {
+                        arguments = tool_call.at("arguments");
+                    }
+                }
+
+                if (!builder.add_tool_call(function_name, "", arguments)) {
+                    throw common_chat_msg_partial_exception("Incomplete tool call");
+                }
+            }
+        } else {
+            throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
+        }
+
+        // Consume any trailing whitespace after this tool call
+        builder.consume_spaces();
+    }
+
+    // Consume any remaining content after all tool calls
+    auto remaining = builder.consume_rest();
+    if (!string_strip(remaining).empty()) {
+        builder.add_content(remaining);
+    }
+}
+
 static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     // Parse thinking tags first - this handles the main reasoning content
     builder.try_parse_reasoning("<seed:think>", "</seed:think>");
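
A usage sketch for the parser above, assuming the common_chat_parse entry point and the common_chat_syntax/common_chat_msg types declared in common/chat.h; the exact signature and field names are assumptions, since they are not shown in this diff.

// Sketch only: feed a completed LFM2 generation through the new parser.
// common_chat_parse and the syntax/msg field names are assumed from
// common/chat.h; they are not part of this diff.
#include "chat.h"
#include <cstdio>
#include <string>

int main() {
    common_chat_syntax syntax;
    syntax.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
    syntax.parse_tool_calls = true; // mirrors builder.syntax().parse_tool_calls above

    const std::string output =
        "<|tool_call_start|>[{\"name\": \"get_current_time\", "
        "\"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>";

    common_chat_msg msg = common_chat_parse(output, /* is_partial= */ false, syntax);
    for (const auto & tc : msg.tool_calls) {
        printf("tool: %s, args: %s\n", tc.name.c_str(), tc.arguments.c_str());
    }
    return 0;
}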
@@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_apertus(tmpl, params);
     }
 
+    // LFM2 (w/ tools)
+    if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
+        src.find("]<|tool_list_end|>") != std::string::npos) {
+        return common_chat_params_init_lfm2(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
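
Note that the handler is selected purely by sniffing the Jinja template source for the two tool-list literals above. As a rough illustration, a hypothetical template fragment like the following would match; this is not the actual LFM2 chat template, which is more elaborate.

// Hypothetical Jinja fragment containing the two literals the sniff looks
// for, shown here as a C++ raw string. Not the real LFM2 template.
static const char * lfm2_like_template_fragment = R"TPL(
<|startoftext|><|im_start|>system
List of tools: <|tool_list_start|>[{% for tool in tools %}{{ tool | tojson }}{% if not loop.last %}, {% endif %}{% endfor %}]<|tool_list_end|><|im_end|>
)TPL";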
@@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_APERTUS:
            common_chat_parse_apertus(builder);
            break;
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
+            common_chat_parse_lfm2(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
common/chat.h:

@@ -116,6 +116,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_SEED_OSS,
     COMMON_CHAT_FORMAT_NEMOTRON_V2,
     COMMON_CHAT_FORMAT_APERTUS,
+    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };