common : handle unicode during partial json parsing (#16526)

* common : handle unicode during partial json parsing

* common : set missing `ensure_ascii = true` during json dump
This commit is contained in:
Aldehir Rojas
2025-10-12 08:18:47 -05:00
committed by GitHub
parent 4b2dae383d
commit 2c301e91ab
4 changed files with 162 additions and 3 deletions

View File

@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
/* .is_partial = */ false,
};
}
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {

View File

@@ -5,6 +5,7 @@
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
@@ -168,6 +169,47 @@ bool common_json_parse(
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
@@ -186,6 +228,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
@@ -205,6 +250,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
@@ -230,6 +278,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {

View File

@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
// Unicode tests
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u0)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u00)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u000)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u0000)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud8)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud80)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
);
}
static void test_positions() {

View File

@@ -58,7 +58,7 @@ static void test_json_healing() {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump());
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
@@ -228,6 +228,56 @@ static void test_json_healing() {
R"({"key":"$foo"})",
R"(:"$foo)"
);
// Test unicode escape sequences
test(
{
R"({"a":"\u)",
},
R"({"a":"\u0000$foo"})",
R"(0000$foo)"
);
test(
{
R"({"a":"\u00)",
},
R"({"a":"\u0000$foo"})",
R"(00$foo)"
);
test(
{
R"({"a":"\ud300)",
},
R"({"a":"\ud300$foo"})",
R"($foo)"
);
test(
{
R"({"a":"\ud800)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(\udc00$foo)"
);
test(
{
R"({"a":"\ud800\)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(udc00$foo)"
);
test(
{
R"({"a":"\ud800\u)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(dc00$foo)"
);
test(
{
R"({"a":"\ud800\udc00)",
},
R"({"a":"\ud800\udc00$foo"})",
R"($foo)"
);
}
int main() {