mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
vendor : sync minja (#16500)
* sync minja.hpp
  Adds Call/EndCall support, used in MiniCPM3 and MiniCPM4-MCP.
* remove spurious semicolon
* sync from ochafik/minja
This commit is contained in:
111
vendor/minja/minja.hpp
vendored
111
vendor/minja/minja.hpp
vendored
@@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) {
|
||||
}
|
||||
|
||||
/* Values that behave roughly like in Python. */
|
||||
class Value : public std::enable_shared_from_this<Value> {
|
||||
class Value {
|
||||
public:
|
||||
using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
|
||||
using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
|
||||
@@ -158,12 +158,14 @@ public:
|
||||
Value(const json & v) {
|
||||
if (v.is_object()) {
|
||||
auto object = std::make_shared<ObjectType>();
|
||||
object->reserve(v.size());
|
||||
for (auto it = v.begin(); it != v.end(); ++it) {
|
||||
(*object)[it.key()] = it.value();
|
||||
object->emplace_back(it.key(), Value(it.value()));
|
||||
}
|
||||
object_ = std::move(object);
|
||||
} else if (v.is_array()) {
|
||||
auto array = std::make_shared<ArrayType>();
|
||||
array->reserve(v.size());
|
||||
for (const auto& item : v) {
|
||||
array->push_back(Value(item));
|
||||
}
|
||||
@@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
|
||||
return out.str();
|
||||
}
|
||||
|
||||
class Context : public std::enable_shared_from_this<Context> {
|
||||
class Context {
|
||||
protected:
|
||||
Value values_;
|
||||
std::shared_ptr<Context> parent_;
|
||||
@@ -706,7 +708,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
|
||||
|
||||
class TemplateToken {
|
||||
public:
|
||||
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
|
||||
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };
|
||||
|
||||
static std::string typeToString(Type t) {
|
||||
switch (t) {
|
||||
@@ -729,6 +731,8 @@ public:
|
||||
case Type::EndGeneration: return "endgeneration";
|
||||
case Type::Break: return "break";
|
||||
case Type::Continue: return "continue";
|
||||
case Type::Call: return "call";
|
||||
case Type::EndCall: return "endcall";
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
@@ -846,6 +850,17 @@ struct LoopControlTemplateToken : public TemplateToken {
|
||||
LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
|
||||
};
|
||||
|
||||
struct CallTemplateToken : public TemplateToken {
|
||||
std::shared_ptr<Expression> expr;
|
||||
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
|
||||
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
|
||||
};
|
||||
|
||||
struct EndCallTemplateToken : public TemplateToken {
|
||||
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
|
||||
: TemplateToken(Type::EndCall, loc, pre, post) {}
|
||||
};
|
||||
|
||||
class TemplateNode {
|
||||
Location location_;
|
||||
protected:
|
||||
@@ -1047,36 +1062,48 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
|
||||
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
|
||||
if (!name) throw std::runtime_error("MacroNode.name is null");
|
||||
if (!body) throw std::runtime_error("MacroNode.body is null");
|
||||
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
auto call_context = macro_context;
|
||||
|
||||
// Use init-capture to avoid dangling 'this' pointer and circular references
|
||||
auto callable = Value::callable([weak_context = std::weak_ptr<Context>(context),
|
||||
name = name, params = params, body = body,
|
||||
named_param_positions = named_param_positions]
|
||||
(const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
|
||||
auto context_locked = weak_context.lock();
|
||||
if (!context_locked) throw std::runtime_error("Macro context no longer valid");
|
||||
auto execution_context = Context::make(Value::object(), context_locked);
|
||||
|
||||
if (call_context->contains("caller")) {
|
||||
execution_context->set("caller", call_context->get("caller"));
|
||||
}
|
||||
|
||||
std::vector<bool> param_set(params.size(), false);
|
||||
for (size_t i = 0, n = args.args.size(); i < n; i++) {
|
||||
auto & arg = args.args[i];
|
||||
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
|
||||
param_set[i] = true;
|
||||
auto & param_name = params[i].first;
|
||||
call_context->set(param_name, arg);
|
||||
const auto & param_name = params[i].first;
|
||||
execution_context->set(param_name, arg);
|
||||
}
|
||||
for (auto & [arg_name, value] : args.kwargs) {
|
||||
auto it = named_param_positions.find(arg_name);
|
||||
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
|
||||
|
||||
call_context->set(arg_name, value);
|
||||
execution_context->set(arg_name, value);
|
||||
param_set[it->second] = true;
|
||||
}
|
||||
// Set default values for parameters that were not passed
|
||||
for (size_t i = 0, n = params.size(); i < n; i++) {
|
||||
if (!param_set[i] && params[i].second != nullptr) {
|
||||
auto val = params[i].second->evaluate(context);
|
||||
call_context->set(params[i].first, val);
|
||||
auto val = params[i].second->evaluate(call_context);
|
||||
execution_context->set(params[i].first, val);
|
||||
}
|
||||
}
|
||||
return body->render(call_context);
|
||||
return body->render(execution_context);
|
||||
});
|
||||
macro_context->set(name->get_name(), callable);
|
||||
context->set(name->get_name(), callable);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1611,6 +1638,44 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class CallNode : public TemplateNode {
|
||||
std::shared_ptr<Expression> expr;
|
||||
std::shared_ptr<TemplateNode> body;
|
||||
|
||||
public:
|
||||
CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
|
||||
: TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}
|
||||
|
||||
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
|
||||
if (!expr) throw std::runtime_error("CallNode.expr is null");
|
||||
if (!body) throw std::runtime_error("CallNode.body is null");
|
||||
|
||||
// Use init-capture to avoid dangling 'this' pointer and circular references
|
||||
auto caller = Value::callable([weak_context = std::weak_ptr<Context>(context), body=body]
|
||||
(const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
|
||||
auto context_locked = weak_context.lock();
|
||||
if (!context_locked) throw std::runtime_error("Caller context no longer valid");
|
||||
return Value(body->render(context_locked));
|
||||
});
|
||||
|
||||
context->set("caller", caller);
|
||||
|
||||
auto call_expr = dynamic_cast<CallExpr*>(expr.get());
|
||||
if (!call_expr) {
|
||||
throw std::runtime_error("Invalid call block syntax - expected function call");
|
||||
}
|
||||
|
||||
Value function = call_expr->object->evaluate(context);
|
||||
if (!function.is_callable()) {
|
||||
throw std::runtime_error("Call target must be callable: " + function.dump());
|
||||
}
|
||||
ArgumentsValue args = call_expr->args.evaluate(context);
|
||||
|
||||
Value result = function.call(context, args);
|
||||
out << result.to_str();
|
||||
}
|
||||
};
|
||||
|
||||
class FilterExpr : public Expression {
|
||||
std::vector<std::shared_ptr<Expression>> parts;
|
||||
public:
|
||||
@@ -2320,7 +2385,7 @@ private:
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
|
||||
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
|
||||
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
||||
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"(\s*([-~])?%\})");
|
||||
@@ -2443,6 +2508,15 @@ private:
|
||||
} else if (keyword == "endmacro") {
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
|
||||
} else if (keyword == "call") {
|
||||
auto expr = parseExpression();
|
||||
if (!expr) throw std::runtime_error("Expected expression in call block");
|
||||
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
|
||||
} else if (keyword == "endcall") {
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
|
||||
} else if (keyword == "filter") {
|
||||
auto filter = parseExpression();
|
||||
if (!filter) throw std::runtime_error("Expected expression in filter block");
|
||||
@@ -2575,6 +2649,12 @@ private:
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
|
||||
} else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
|
||||
auto body = parseTemplate(begin, it, end);
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
|
||||
} else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
|
||||
auto body = parseTemplate(begin, it, end);
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
|
||||
@@ -2588,6 +2668,7 @@ private:
|
||||
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndCallTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
|
||||
|| dynamic_cast<EndIfTemplateToken*>(token.get())
|
||||
|| dynamic_cast<ElseTemplateToken*>(token.get())
|
||||
|
||||
Reference in New Issue
Block a user