Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-31 08:51:55 +00:00.
			
		
		
		
	tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)
				
					
				
* `tool-call`: support Command R7B (w/ tool_plan return)
* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override
* `tool-call`: test cleanup / handle lazy grammar triggers
This commit is contained in:
		| @@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) { | ||||
|         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; | ||||
|         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; | ||||
|         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; | ||||
|         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; | ||||
|         default: | ||||
|             throw std::runtime_error("Unknown chat format"); | ||||
|     } | ||||
| @@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) | ||||
|     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); | ||||
| } | ||||
|  | ||||
| static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { | ||||
|     common_chat_params data; | ||||
|     data.grammar_lazy = inputs.tool_choice != "required"; | ||||
|     data.grammar = build_grammar([&](const common_grammar_builder & builder) { | ||||
|         auto schemas = json::array(); | ||||
|         foreach_function(inputs.tools, [&](const json & tool) { | ||||
|             const auto & function = tool["function"]; | ||||
|             schemas.push_back({ | ||||
|                 {"type", "object"}, | ||||
|                 {"properties", { | ||||
|                     {"tool_call_id", { | ||||
|                         {"type", "string"}, | ||||
|                         // Command-R's template expects an integer string. | ||||
|                         {"pattern", "^[0-9]{1,10}$"}, | ||||
|                     }}, | ||||
|                     {"tool_name", { | ||||
|                         {"type", "string"}, | ||||
|                         {"const", function["name"]}, | ||||
|                     }}, | ||||
|                     {"parameters", function["parameters"]}, | ||||
|                 }}, | ||||
|                 {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, | ||||
|             }); | ||||
|         }); | ||||
|         auto schema = json { | ||||
|             {"type", "array"}, | ||||
|             {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, | ||||
|             {"minItems", 1}, | ||||
|         }; | ||||
|         if (!inputs.parallel_tool_calls) { | ||||
|             schema["maxItems"] = 1; | ||||
|         } | ||||
|         builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); | ||||
|     }, grammar_options); | ||||
|     data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false}); | ||||
|     data.preserved_tokens = { | ||||
|         "<|START_RESPONSE|>", | ||||
|         "<|END_RESPONSE|>", | ||||
|         "<|START_THINKING|>", | ||||
|         "<|END_THINKING|>", | ||||
|         "<|END_ACTION|>", | ||||
|     }; | ||||
|     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); | ||||
|     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; | ||||
|     return data; | ||||
| } | ||||
| static common_chat_msg common_chat_parse_command_r7b(const std::string & input) { | ||||
|     static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>"); | ||||
|     static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); | ||||
|     std::smatch match; | ||||
|  | ||||
|     common_chat_msg result; | ||||
|     result.role = "assistant"; | ||||
|     if (std::regex_match(input, match, response_regex)) { | ||||
|         result.content = match[1].str(); | ||||
|     } else if (std::regex_match(input, match, thought_action_regex)) { | ||||
|         result.tool_plan = match[1].str(); | ||||
|         auto actions_str = match[2].str(); | ||||
|         auto actions = json::parse(actions_str); | ||||
|         for (const auto & action : actions) { | ||||
|             result.tool_calls.push_back({ | ||||
|                 /* .name = */      action["tool_name"], | ||||
|                 /* .arguments = */ action["parameters"].dump(), | ||||
|                 /* .id = */        action["tool_call_id"], | ||||
|             }); | ||||
|         } | ||||
|     } else { | ||||
|         LOG_ERR("Failed to parse command_r output"); | ||||
|         result.content = input; | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) { | ||||
|     if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { | ||||
|         throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); | ||||
| @@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ | ||||
|                 "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); | ||||
|         }); | ||||
|         data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); | ||||
|         data.preserved_tokens = { | ||||
|             "<|tool▁sep|>", | ||||
|             "<|tool▁call▁end|>", | ||||
|         }; | ||||
|         builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); | ||||
|     }, grammar_options); | ||||
|     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); | ||||
| @@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat | ||||
|         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space"; | ||||
|         builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); | ||||
|         data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false}); | ||||
|         // Not really a trigger but need to print this special token to get a successful parse. | ||||
|         data.grammar_triggers.push_back({"</tool_call>", /* .at_start = */ false}); | ||||
|         data.preserved_tokens = { "</tool_call>" }; | ||||
|     }, grammar_options); | ||||
|  | ||||
|     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); | ||||
| @@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co | ||||
|     if (src.find("[TOOL_CALLS]") != std::string::npos) { | ||||
|         return common_chat_params_init_mistral_nemo(tmpl, inputs); | ||||
|     } | ||||
|     if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) { | ||||
|         return common_chat_params_init_command_r7b(tmpl, inputs); | ||||
|     } | ||||
|     return common_chat_params_init_generic(tmpl, inputs); | ||||
| } | ||||
|  | ||||
| @@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format | ||||
|             return common_chat_parse_hermes_2_pro(input); | ||||
|         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: | ||||
|             return common_chat_parse_firefunction_v2(input); | ||||
|         case COMMON_CHAT_FORMAT_COMMAND_R7B: | ||||
|             return common_chat_parse_command_r7b(input); | ||||
|         default: | ||||
|             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Olivier Chafik
					Olivier Chafik