mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	* Added integration tests for the GBNF parser to validate both parsing correctness and string-matching correctness. They are intended to pin behavior while working on performance improvements. * Fixed whitespace errors and reworded an error message to be clearer. * Removed the hacky include of llama.cpp from the grammar integration test now that the needed functions are available via the internal API. * Comment cleanup. * Reorganized tests for readability. * Cleaned up a debug message so it makes more sense.
		
			
				
	
	
		
			244 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			244 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#ifdef NDEBUG
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL

#include "ggml.h"
#include "llama.h"
#include "grammar-parser.h"
#include "unicode.h"

#include <cassert>
#include <cstdio>
#include <string>
#include <vector>
 | 
						|
static void test_simple_grammar() {
 | 
						|
    // Test case for a simple grammar
 | 
						|
    const std::string grammar_str = R"""(root ::= expr
 | 
						|
expr ::= term ("+" term)*
 | 
						|
term ::= number
 | 
						|
number ::= [0-9]+)""";
 | 
						|
 | 
						|
    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
 | 
						|
 | 
						|
    // Ensure we parsed correctly
 | 
						|
    assert(!parsed_grammar.rules.empty());
 | 
						|
 | 
						|
    // Ensure we have a root node
 | 
						|
    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
 | 
						|
 | 
						|
    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
 | 
						|
    llama_grammar* grammar = llama_grammar_init(
 | 
						|
        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
 | 
						|
 | 
						|
    std::string input = "123+456";
 | 
						|
 | 
						|
    auto decoded = decode_utf8(input, {});
 | 
						|
 | 
						|
    const auto & code_points = decoded.first;
 | 
						|
 | 
						|
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 | 
						|
        auto prev_stacks = grammar->stacks;
 | 
						|
        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
 | 
						|
        assert(!grammar->stacks.empty());
 | 
						|
    }
 | 
						|
 | 
						|
    bool completed_grammar = false;
 | 
						|
 | 
						|
    for (const auto & stack : grammar->stacks) {
 | 
						|
        if (stack.empty()) {
 | 
						|
            completed_grammar = true;
 | 
						|
            break;
 | 
						|
        }
 | 
						|
    }
 | 
						|
 | 
						|
    assert(completed_grammar);
 | 
						|
 | 
						|
    // Clean up allocated memory
 | 
						|
    llama_grammar_free(grammar);
 | 
						|
}
 | 
						|
 | 
						|
static void test_complex_grammar() {
 | 
						|
    // Test case for a more complex grammar, with both failure strings and success strings
 | 
						|
    const std::string grammar_str = R"""(root ::= expression
 | 
						|
expression ::= term ws (("+"|"-") ws term)*
 | 
						|
term ::= factor ws (("*"|"/") ws factor)*
 | 
						|
factor ::= number | variable | "(" expression ")" | function-call
 | 
						|
number ::= [0-9]+
 | 
						|
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
 | 
						|
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
 | 
						|
ws ::= [ \t\n\r]?)""";
 | 
						|
 | 
						|
    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
 | 
						|
 | 
						|
    // Ensure we parsed correctly
 | 
						|
    assert(!parsed_grammar.rules.empty());
 | 
						|
 | 
						|
    // Ensure we have a root node
 | 
						|
    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
 | 
						|
 | 
						|
    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
 | 
						|
    llama_grammar* grammar = llama_grammar_init(
 | 
						|
        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
 | 
						|
 | 
						|
    // Save the original grammar stacks so that we can reset after every new string we want to test
 | 
						|
    auto original_stacks = grammar->stacks;
 | 
						|
 | 
						|
    // Test a few strings
 | 
						|
    std::vector<std::string> test_strings_pass = {
 | 
						|
        "42",
 | 
						|
        "1*2*3*4*5",
 | 
						|
        "x",
 | 
						|
        "x+10",
 | 
						|
        "x1+y2",
 | 
						|
        "(a+b)*(c-d)",
 | 
						|
        "func()",
 | 
						|
        "func(x,y+2)",
 | 
						|
        "a*(b+c)-d/e",
 | 
						|
        "f(g(x),h(y,z))",
 | 
						|
        "x + 10",
 | 
						|
        "x1 + y2",
 | 
						|
        "(a + b) * (c - d)",
 | 
						|
        "func()",
 | 
						|
        "func(x, y + 2)",
 | 
						|
        "a * (b + c) - d / e",
 | 
						|
        "f(g(x), h(y, z))",
 | 
						|
        "123+456",
 | 
						|
        "123*456*789-123/456+789*123",
 | 
						|
        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
 | 
						|
    };
 | 
						|
 | 
						|
    std::vector<std::string> test_strings_fail = {
 | 
						|
        "+",
 | 
						|
        "/ 3x",
 | 
						|
        "x + + y",
 | 
						|
        "a * / b",
 | 
						|
        "func(,)",
 | 
						|
        "func(x y)",
 | 
						|
        "(a + b",
 | 
						|
        "x + y)",
 | 
						|
        "a + b * (c - d",
 | 
						|
        "42 +",
 | 
						|
        "x +",
 | 
						|
        "x + 10 +",
 | 
						|
        "(a + b) * (c - d",
 | 
						|
        "func(",
 | 
						|
        "func(x, y + 2",
 | 
						|
        "a * (b + c) - d /",
 | 
						|
        "f(g(x), h(y, z)",
 | 
						|
        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
 | 
						|
    };
 | 
						|
 | 
						|
    // Passing strings
 | 
						|
    for (const auto & test_string : test_strings_pass) {
 | 
						|
        auto decoded = decode_utf8(test_string, {});
 | 
						|
 | 
						|
        const auto & code_points = decoded.first;
 | 
						|
 | 
						|
        int pos = 0;
 | 
						|
        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 | 
						|
            ++pos;
 | 
						|
            auto prev_stacks = grammar->stacks;
 | 
						|
            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
 | 
						|
 | 
						|
            // Expect that each code point will not cause the grammar to fail
 | 
						|
            if (grammar->stacks.empty()) {
 | 
						|
                fprintf(stdout, "Error at position %d\n", pos);
 | 
						|
                fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
 | 
						|
                fprintf(stderr, "Input string is %s:\n", test_string.c_str());
 | 
						|
            }
 | 
						|
            assert(!grammar->stacks.empty());
 | 
						|
        }
 | 
						|
 | 
						|
        bool completed_grammar = false;
 | 
						|
 | 
						|
        for (const auto & stack : grammar->stacks) {
 | 
						|
            if (stack.empty()) {
 | 
						|
                completed_grammar = true;
 | 
						|
                break;
 | 
						|
            }
 | 
						|
        }
 | 
						|
 | 
						|
        assert(completed_grammar);
 | 
						|
 | 
						|
        // Reset the grammar stacks
 | 
						|
        grammar->stacks = original_stacks;
 | 
						|
    }
 | 
						|
 | 
						|
    // Failing strings
 | 
						|
    for (const auto & test_string : test_strings_fail) {
 | 
						|
        auto decoded = decode_utf8(test_string, {});
 | 
						|
 | 
						|
        const auto & code_points = decoded.first;
 | 
						|
        bool parse_failed = false;
 | 
						|
 | 
						|
        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 | 
						|
            auto prev_stacks = grammar->stacks;
 | 
						|
            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
 | 
						|
            if (grammar->stacks.empty()) {
 | 
						|
                parse_failed = true;
 | 
						|
                break;
 | 
						|
            }
 | 
						|
            assert(!grammar->stacks.empty());
 | 
						|
        }
 | 
						|
 | 
						|
        bool completed_grammar = false;
 | 
						|
 | 
						|
        for (const auto & stack : grammar->stacks) {
 | 
						|
            if (stack.empty()) {
 | 
						|
                completed_grammar = true;
 | 
						|
                break;
 | 
						|
            }
 | 
						|
        }
 | 
						|
 | 
						|
        // Ensure that the grammar is not completed, or that each string failed to match as-expected
 | 
						|
        assert((!completed_grammar) || parse_failed);
 | 
						|
 | 
						|
        // Reset the grammar stacks
 | 
						|
        grammar->stacks = original_stacks;
 | 
						|
    }
 | 
						|
 | 
						|
    // Clean up allocated memory
 | 
						|
    llama_grammar_free(grammar);
 | 
						|
}
 | 
						|
 | 
						|
// Checks a grammar whose top-level rule is misspelled ("rot" instead of
// "root"). The rules themselves are well-formed, so parsing must produce
// rules, but the resulting symbol table must not contain a "root" entry.
static void test_failure_missing_root() {
    const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

    const auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
    const auto & symbols = parsed_grammar.symbol_ids;

    // Parsing itself succeeds: the rule set is non-empty...
    assert(!parsed_grammar.rules.empty());

    // ...yet no rule is registered under the name "root".
    assert(symbols.count("root") == 0);
}
 | 
						|
 | 
						|
// Checks that a grammar referencing an undefined rule is rejected. Here "term"
// refers to "numero", but only "number" is defined. The test expects parse()
// to return a state with no rules in that case.
static void test_failure_missing_reference() {
    const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";

    // The parser is expected to print its own diagnostic during parse().
    // Bracket that output so the log makes clear the message is intentional.
    fprintf(stderr, "Expected error:  ");

    const auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());

    // An empty rule set is how a failed parse is detected here.
    const bool rejected = parsed_grammar.rules.empty();
    assert(rejected);

    fprintf(stderr, "End of expected error. Test successful.\n");
}
 | 
						|
 | 
						|
// Runs every grammar test in declaration order. Each test signals failure
// through assert(), which aborts the process, so reaching the end means all
// tests passed.
int main() {
    typedef void (*test_fn)();

    const test_fn tests[] = {
        test_simple_grammar,
        test_complex_grammar,
        test_failure_missing_root,
        test_failure_missing_reference,
    };

    for (test_fn run_test : tests) {
        run_test();
    }

    return 0;
}
 |