mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
common : handle unicode during partial json parsing (#16526)
* common : handle unicode during partial json parsing * common : set missing `ensure_ascii = true` during json dump
This commit is contained in:
@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -168,6 +169,47 @@ bool common_json_parse(
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
@@ -186,6 +228,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
@@ -205,6 +250,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
@@ -230,6 +278,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
|
||||
@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
|
||||
// Unicode tests
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud8)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud80)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
|
||||
@@ -58,7 +58,7 @@ static void test_json_healing() {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump());
|
||||
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
@@ -228,6 +228,56 @@ static void test_json_healing() {
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
// Test unicode escape sequences
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(0000$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u00)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud300)",
|
||||
},
|
||||
R"({"a":"\ud300$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(\udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\u)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(dc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\udc00)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Reference in New Issue
Block a user