mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
json : support enum values within allOf (#15830)
This commit is contained in:
@@ -843,9 +843,10 @@ public:
|
||||
_build_object_rule(
|
||||
properties, required, name,
|
||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
|
||||
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
|
||||
std::unordered_set<std::string> required;
|
||||
std::vector<std::pair<std::string, json>> properties;
|
||||
std::map<std::string, size_t> enum_values;
|
||||
std::string hybrid_name = name;
|
||||
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
||||
if (comp_schema.contains("$ref")) {
|
||||
@@ -857,6 +858,14 @@ public:
|
||||
required.insert(prop.key());
|
||||
}
|
||||
}
|
||||
} else if (comp_schema.contains("enum")) {
|
||||
for (const auto & v : comp_schema["enum"]) {
|
||||
const auto rule = _generate_constant_rule(v);
|
||||
if (enum_values.find(rule) == enum_values.end()) {
|
||||
enum_values[rule] = 0;
|
||||
}
|
||||
enum_values[rule] += 1;
|
||||
}
|
||||
} else {
|
||||
// todo warning
|
||||
}
|
||||
@@ -870,6 +879,17 @@ public:
|
||||
add_component(t, true);
|
||||
}
|
||||
}
|
||||
if (!enum_values.empty()) {
|
||||
std::vector<std::string> enum_intersection;
|
||||
for (const auto & p : enum_values) {
|
||||
if (p.second == schema["allOf"].size()) {
|
||||
enum_intersection.push_back(p.first);
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||
|
||||
@@ -586,9 +586,10 @@ class SchemaConverter:
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
enum_sets = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
@@ -600,6 +601,9 @@ class SchemaConverter:
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
if 'enum' in comp_schema:
|
||||
enum_sets.append(set(comp_schema['enum']))
|
||||
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
@@ -607,6 +611,15 @@ class SchemaConverter:
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
if enum_sets:
|
||||
enum_intersection = enum_sets[0]
|
||||
for s in enum_sets[1:]:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
|
||||
@@ -1209,6 +1209,51 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"allOf with enum schema",
|
||||
R"""({
|
||||
"allOf": [
|
||||
{"$ref": "#/definitions/foo"}
|
||||
],
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"]
|
||||
}
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"a\"" | "\"b\"") space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"allOf with multiple enum schemas",
|
||||
R"""({
|
||||
"allOf": [
|
||||
{"$ref": "#/definitions/foo"},
|
||||
{"$ref": "#/definitions/bar"}
|
||||
],
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b", "c"]
|
||||
},
|
||||
"bar": {
|
||||
"type": "string",
|
||||
"enum": ["b", "c", "d"]
|
||||
}
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"b\"" | "\"c\"") space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"conflicting names",
|
||||
|
||||
@@ -631,9 +631,10 @@ export class SchemaConverter {
|
||||
const required = new Set(schema.required || []);
|
||||
const properties = Object.entries(schema.properties ?? {});
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
|
||||
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
|
||||
} else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
|
||||
const required = new Set();
|
||||
const properties = [];
|
||||
const enumSets = [];
|
||||
const addComponent = (compSchema, isRequired) => {
|
||||
const ref = compSchema.$ref;
|
||||
if (ref !== undefined) {
|
||||
@@ -648,6 +649,10 @@ export class SchemaConverter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ('enum' in compSchema) {
|
||||
enumSets.push(new Set(compSchema.enum || []));
|
||||
}
|
||||
};
|
||||
|
||||
for (const t of schema.allOf) {
|
||||
@@ -660,6 +665,14 @@ export class SchemaConverter {
|
||||
}
|
||||
}
|
||||
|
||||
if (enumSets.length > 0) {
|
||||
const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
|
||||
if (enumIntersection.size > 0) {
|
||||
const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
|
||||
const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
|
||||
return this._addRule(ruleName, rule);
|
||||
}
|
||||
}
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
|
||||
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
|
||||
const items = schema.items ?? schema.prefixItems;
|
||||
|
||||
Reference in New Issue
Block a user