mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
json : support enum values within allOf (#15830)
This commit is contained in:
@@ -843,9 +843,10 @@ public:
|
|||||||
_build_object_rule(
|
_build_object_rule(
|
||||||
properties, required, name,
|
properties, required, name,
|
||||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||||
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
|
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
|
||||||
std::unordered_set<std::string> required;
|
std::unordered_set<std::string> required;
|
||||||
std::vector<std::pair<std::string, json>> properties;
|
std::vector<std::pair<std::string, json>> properties;
|
||||||
|
std::map<std::string, size_t> enum_values;
|
||||||
std::string hybrid_name = name;
|
std::string hybrid_name = name;
|
||||||
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
||||||
if (comp_schema.contains("$ref")) {
|
if (comp_schema.contains("$ref")) {
|
||||||
@@ -857,6 +858,14 @@ public:
|
|||||||
required.insert(prop.key());
|
required.insert(prop.key());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (comp_schema.contains("enum")) {
|
||||||
|
for (const auto & v : comp_schema["enum"]) {
|
||||||
|
const auto rule = _generate_constant_rule(v);
|
||||||
|
if (enum_values.find(rule) == enum_values.end()) {
|
||||||
|
enum_values[rule] = 0;
|
||||||
|
}
|
||||||
|
enum_values[rule] += 1;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// todo warning
|
// todo warning
|
||||||
}
|
}
|
||||||
@@ -870,6 +879,17 @@ public:
|
|||||||
add_component(t, true);
|
add_component(t, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (!enum_values.empty()) {
|
||||||
|
std::vector<std::string> enum_intersection;
|
||||||
|
for (const auto & p : enum_values) {
|
||||||
|
if (p.second == schema["allOf"].size()) {
|
||||||
|
enum_intersection.push_back(p.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!enum_intersection.empty()) {
|
||||||
|
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||||
|
}
|
||||||
|
}
|
||||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||||
|
|||||||
@@ -586,9 +586,10 @@ class SchemaConverter:
|
|||||||
properties = list(schema.get('properties', {}).items())
|
properties = list(schema.get('properties', {}).items())
|
||||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||||
|
|
||||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
|
||||||
required = set()
|
required = set()
|
||||||
properties = []
|
properties = []
|
||||||
|
enum_sets = []
|
||||||
hybrid_name = name
|
hybrid_name = name
|
||||||
def add_component(comp_schema, is_required):
|
def add_component(comp_schema, is_required):
|
||||||
if (ref := comp_schema.get('$ref')) is not None:
|
if (ref := comp_schema.get('$ref')) is not None:
|
||||||
@@ -600,6 +601,9 @@ class SchemaConverter:
|
|||||||
if is_required:
|
if is_required:
|
||||||
required.add(prop_name)
|
required.add(prop_name)
|
||||||
|
|
||||||
|
if 'enum' in comp_schema:
|
||||||
|
enum_sets.append(set(comp_schema['enum']))
|
||||||
|
|
||||||
for t in schema['allOf']:
|
for t in schema['allOf']:
|
||||||
if 'anyOf' in t:
|
if 'anyOf' in t:
|
||||||
for tt in t['anyOf']:
|
for tt in t['anyOf']:
|
||||||
@@ -607,6 +611,15 @@ class SchemaConverter:
|
|||||||
else:
|
else:
|
||||||
add_component(t, is_required=True)
|
add_component(t, is_required=True)
|
||||||
|
|
||||||
|
if enum_sets:
|
||||||
|
enum_intersection = enum_sets[0]
|
||||||
|
for s in enum_sets[1:]:
|
||||||
|
enum_intersection &= s
|
||||||
|
|
||||||
|
if enum_intersection:
|
||||||
|
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||||
|
|
||||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||||
|
|||||||
@@ -1209,6 +1209,51 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
)"""
|
)"""
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test({
|
||||||
|
SUCCESS,
|
||||||
|
"allOf with enum schema",
|
||||||
|
R"""({
|
||||||
|
"allOf": [
|
||||||
|
{"$ref": "#/definitions/foo"}
|
||||||
|
],
|
||||||
|
"definitions": {
|
||||||
|
"foo": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["a", "b"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})""",
|
||||||
|
R"""(
|
||||||
|
root ::= ("\"a\"" | "\"b\"") space
|
||||||
|
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||||
|
)"""
|
||||||
|
});
|
||||||
|
|
||||||
|
test({
|
||||||
|
SUCCESS,
|
||||||
|
"allOf with multiple enum schemas",
|
||||||
|
R"""({
|
||||||
|
"allOf": [
|
||||||
|
{"$ref": "#/definitions/foo"},
|
||||||
|
{"$ref": "#/definitions/bar"}
|
||||||
|
],
|
||||||
|
"definitions": {
|
||||||
|
"foo": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["a", "b", "c"]
|
||||||
|
},
|
||||||
|
"bar": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["b", "c", "d"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})""",
|
||||||
|
R"""(
|
||||||
|
root ::= ("\"b\"" | "\"c\"") space
|
||||||
|
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||||
|
)"""
|
||||||
|
});
|
||||||
|
|
||||||
test({
|
test({
|
||||||
SUCCESS,
|
SUCCESS,
|
||||||
"conflicting names",
|
"conflicting names",
|
||||||
|
|||||||
@@ -631,9 +631,10 @@ export class SchemaConverter {
|
|||||||
const required = new Set(schema.required || []);
|
const required = new Set(schema.required || []);
|
||||||
const properties = Object.entries(schema.properties ?? {});
|
const properties = Object.entries(schema.properties ?? {});
|
||||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
|
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
|
||||||
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
|
} else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
|
||||||
const required = new Set();
|
const required = new Set();
|
||||||
const properties = [];
|
const properties = [];
|
||||||
|
const enumSets = [];
|
||||||
const addComponent = (compSchema, isRequired) => {
|
const addComponent = (compSchema, isRequired) => {
|
||||||
const ref = compSchema.$ref;
|
const ref = compSchema.$ref;
|
||||||
if (ref !== undefined) {
|
if (ref !== undefined) {
|
||||||
@@ -648,6 +649,10 @@ export class SchemaConverter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ('enum' in compSchema) {
|
||||||
|
enumSets.push(new Set(compSchema.enum || []));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const t of schema.allOf) {
|
for (const t of schema.allOf) {
|
||||||
@@ -660,6 +665,14 @@ export class SchemaConverter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enumSets.length > 0) {
|
||||||
|
const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
|
||||||
|
if (enumIntersection.size > 0) {
|
||||||
|
const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
|
||||||
|
const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
|
||||||
|
return this._addRule(ruleName, rule);
|
||||||
|
}
|
||||||
|
}
|
||||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
|
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
|
||||||
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
|
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
|
||||||
const items = schema.items ?? schema.prefixItems;
|
const items = schema.items ?? schema.prefixItems;
|
||||||
|
|||||||
Reference in New Issue
Block a user