mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	move cache stack to advance stack
This commit is contained in:
		@@ -15,10 +15,11 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
 | 
				
			|||||||
          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 | 
					          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    size_t pos = 0;
 | 
					    size_t pos = 0;
 | 
				
			||||||
 | 
					    llama_grammar_stacks_cache stacks_cache;
 | 
				
			||||||
    for (const auto & cpt : cpts) {
 | 
					    for (const auto & cpt : cpts) {
 | 
				
			||||||
        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 | 
					        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
 | 
					        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (stacks_cur.empty()) {
 | 
					        if (stacks_cur.empty()) {
 | 
				
			||||||
            error_pos = pos;
 | 
					            error_pos = pos;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
 | 
				
			|||||||
// additionally memorizes the stack to its possible stacks by mapping
 | 
					// additionally memorizes the stack to its possible stacks by mapping
 | 
				
			||||||
// < llama_grammar_stack, llama_grammar_stacks >
 | 
					// < llama_grammar_stack, llama_grammar_stacks >
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct VectorPointerHash {
 | 
					 | 
				
			||||||
    size_t operator()(const llama_grammar_stack & v) const {
 | 
					 | 
				
			||||||
        size_t seed = v.size();
 | 
					 | 
				
			||||||
        for (const auto* ptr : v) {
 | 
					 | 
				
			||||||
            seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return seed;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static std::unordered_map<
 | 
					 | 
				
			||||||
    llama_grammar_stack,
 | 
					 | 
				
			||||||
    llama_grammar_stacks,
 | 
					 | 
				
			||||||
    VectorPointerHash>
 | 
					 | 
				
			||||||
    llama_grammar_stacks_cache = {};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static void llama_grammar_advance_stack_memo(
 | 
					static void llama_grammar_advance_stack_memo(
 | 
				
			||||||
        const llama_grammar_rules  & rules,
 | 
					        const llama_grammar_rules  & rules,
 | 
				
			||||||
        const llama_grammar_stack  & stack,
 | 
					        const llama_grammar_stack  & stack,
 | 
				
			||||||
              llama_grammar_stacks & new_stacks);
 | 
					              llama_grammar_stacks & new_stacks,
 | 
				
			||||||
 | 
					              llama_grammar_stacks_cache & stacks_cache);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static void llama_grammar_advance_stack_memo_impl(
 | 
					static void llama_grammar_advance_stack_memo_impl(
 | 
				
			||||||
        const llama_grammar_rules  & rules,
 | 
					        const llama_grammar_rules  & rules,
 | 
				
			||||||
        const llama_grammar_stack  & stack,
 | 
					        const llama_grammar_stack  & stack,
 | 
				
			||||||
              llama_grammar_stacks & new_stacks) {
 | 
					              llama_grammar_stacks & new_stacks,
 | 
				
			||||||
 | 
					              llama_grammar_stacks_cache & stacks_cache) {
 | 
				
			||||||
    if (stack.empty()) {
 | 
					    if (stack.empty()) {
 | 
				
			||||||
        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
 | 
					        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
 | 
				
			||||||
            new_stacks.emplace_back(stack);
 | 
					            new_stacks.emplace_back(stack);
 | 
				
			||||||
@@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
 | 
				
			|||||||
                    // if alternate is nonempty, add to stack
 | 
					                    // if alternate is nonempty, add to stack
 | 
				
			||||||
                    new_stack.push_back(subpos);
 | 
					                    new_stack.push_back(subpos);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
 | 
					                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
 | 
				
			||||||
                while (!llama_grammar_is_end_of_sequence(subpos)) {
 | 
					                while (!llama_grammar_is_end_of_sequence(subpos)) {
 | 
				
			||||||
                    // scan to end of alternate def
 | 
					                    // scan to end of alternate def
 | 
				
			||||||
                    subpos++;
 | 
					                    subpos++;
 | 
				
			||||||
@@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
 | 
				
			|||||||
static void llama_grammar_advance_stack_memo(
 | 
					static void llama_grammar_advance_stack_memo(
 | 
				
			||||||
        const llama_grammar_rules  & rules,
 | 
					        const llama_grammar_rules  & rules,
 | 
				
			||||||
        const llama_grammar_stack  & stack,
 | 
					        const llama_grammar_stack  & stack,
 | 
				
			||||||
              llama_grammar_stacks & new_stacks) {
 | 
					              llama_grammar_stacks & new_stacks,
 | 
				
			||||||
 | 
					              llama_grammar_stacks_cache & stacks_cache) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_grammar_stacks advanced_stacks;
 | 
					    llama_grammar_stacks advanced_stacks;
 | 
				
			||||||
    // Look if stack is already in memory
 | 
					    // Look if stack is already in memory
 | 
				
			||||||
    auto it = llama_grammar_stacks_cache.find(stack);
 | 
					    auto it = stacks_cache.find(stack);
 | 
				
			||||||
    if (it != llama_grammar_stacks_cache.end()) {
 | 
					    if (it != stacks_cache.end()) {
 | 
				
			||||||
           advanced_stacks = it->second;
 | 
					           advanced_stacks = it->second;
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
        // Advance stacks with memorization 
 | 
					        // Advance stacks with memorization 
 | 
				
			||||||
        llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
 | 
					        llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
 | 
				
			||||||
        llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
 | 
					        stacks_cache.insert(make_pair(stack, advanced_stacks));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    // Add the advanced stacks to new_stacks avoiding duplicates
 | 
					    // Add the advanced stacks to new_stacks avoiding duplicates
 | 
				
			||||||
    for (const auto & new_stack : advanced_stacks) {
 | 
					    for (const auto & new_stack : advanced_stacks) {
 | 
				
			||||||
@@ -934,7 +921,8 @@ void llama_grammar_accept(
 | 
				
			|||||||
        const llama_grammar_rules  & rules,
 | 
					        const llama_grammar_rules  & rules,
 | 
				
			||||||
        const llama_grammar_stacks & stacks,
 | 
					        const llama_grammar_stacks & stacks,
 | 
				
			||||||
        const uint32_t               chr,
 | 
					        const uint32_t               chr,
 | 
				
			||||||
              llama_grammar_stacks & stacks_new) {
 | 
					              llama_grammar_stacks & stacks_new,
 | 
				
			||||||
 | 
					              llama_grammar_stacks_cache & stacks_cache) {
 | 
				
			||||||
    stacks_new.clear();
 | 
					    stacks_new.clear();
 | 
				
			||||||
    stacks_new.reserve(stacks.size());
 | 
					    stacks_new.reserve(stacks.size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -952,7 +940,7 @@ void llama_grammar_accept(
 | 
				
			|||||||
            if (!llama_grammar_is_end_of_sequence(pos)) {
 | 
					            if (!llama_grammar_is_end_of_sequence(pos)) {
 | 
				
			||||||
                new_stack.push_back(pos);
 | 
					                new_stack.push_back(pos);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
 | 
					            llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
 | 
				
			|||||||
        const llama_grammar_element ** rules,
 | 
					        const llama_grammar_element ** rules,
 | 
				
			||||||
        size_t n_rules,
 | 
					        size_t n_rules,
 | 
				
			||||||
        size_t start_rule_index) {
 | 
					        size_t start_rule_index) {
 | 
				
			||||||
    // Clear stacks cache
 | 
					 | 
				
			||||||
    llama_grammar_stacks_cache.clear();
 | 
					 | 
				
			||||||
    const llama_grammar_element * pos;
 | 
					    const llama_grammar_element * pos;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // copy rule definitions into vectors
 | 
					    // copy rule definitions into vectors
 | 
				
			||||||
@@ -1048,6 +1034,7 @@ struct llama_grammar * llama_grammar_init_impl(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    // loop over alternates of start rule to build initial stacks
 | 
					    // loop over alternates of start rule to build initial stacks
 | 
				
			||||||
    llama_grammar_stacks stacks;
 | 
					    llama_grammar_stacks stacks;
 | 
				
			||||||
 | 
					    llama_grammar_stacks_cache stacks_cache;
 | 
				
			||||||
    pos = vec_rules[start_rule_index].data();
 | 
					    pos = vec_rules[start_rule_index].data();
 | 
				
			||||||
    do {
 | 
					    do {
 | 
				
			||||||
        llama_grammar_stack stack;
 | 
					        llama_grammar_stack stack;
 | 
				
			||||||
@@ -1055,7 +1042,7 @@ struct llama_grammar * llama_grammar_init_impl(
 | 
				
			|||||||
            // if alternate is nonempty, add to stack
 | 
					            // if alternate is nonempty, add to stack
 | 
				
			||||||
            stack.push_back(pos);
 | 
					            stack.push_back(pos);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
 | 
					        llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
 | 
				
			||||||
        while (!llama_grammar_is_end_of_sequence(pos)) {
 | 
					        while (!llama_grammar_is_end_of_sequence(pos)) {
 | 
				
			||||||
            // scan to end of alternate def
 | 
					            // scan to end of alternate def
 | 
				
			||||||
            pos++;
 | 
					            pos++;
 | 
				
			||||||
@@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
 | 
					struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
 | 
				
			||||||
    // Clear stacks cache
 | 
					 | 
				
			||||||
    llama_grammar_stacks_cache.clear();
 | 
					 | 
				
			||||||
    llama_grammar_parser parser;
 | 
					    llama_grammar_parser parser;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // if there is a grammar, parse it
 | 
					    // if there is a grammar, parse it
 | 
				
			||||||
@@ -1128,6 +1113,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    // loop over alternates of start rule to build initial stacks
 | 
					    // loop over alternates of start rule to build initial stacks
 | 
				
			||||||
    llama_grammar_stacks stacks;
 | 
					    llama_grammar_stacks stacks;
 | 
				
			||||||
 | 
					    llama_grammar_stacks_cache stacks_cache;
 | 
				
			||||||
    pos = vec_rules[start_rule_index].data();
 | 
					    pos = vec_rules[start_rule_index].data();
 | 
				
			||||||
    do {
 | 
					    do {
 | 
				
			||||||
        llama_grammar_stack stack;
 | 
					        llama_grammar_stack stack;
 | 
				
			||||||
@@ -1135,7 +1121,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
 | 
				
			|||||||
            // if alternate is nonempty, add to stack
 | 
					            // if alternate is nonempty, add to stack
 | 
				
			||||||
            stack.push_back(pos);
 | 
					            stack.push_back(pos);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
 | 
					        llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
 | 
				
			||||||
        while (!llama_grammar_is_end_of_sequence(pos)) {
 | 
					        while (!llama_grammar_is_end_of_sequence(pos)) {
 | 
				
			||||||
            // scan to end of alternate def
 | 
					            // scan to end of alternate def
 | 
				
			||||||
            pos++;
 | 
					            pos++;
 | 
				
			||||||
@@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
 | 
				
			|||||||
    const auto & code_points = decoded.first;
 | 
					    const auto & code_points = decoded.first;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_grammar_stacks stacks_new;
 | 
					    llama_grammar_stacks stacks_new;
 | 
				
			||||||
 | 
					    llama_grammar_stacks_cache stacks_cache;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 | 
					    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 | 
				
			||||||
        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
 | 
					        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
 | 
				
			||||||
        grammar.stacks = std::move(stacks_new);
 | 
					        grammar.stacks = std::move(stacks_new);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@
 | 
				
			|||||||
#include "llama-impl.h"
 | 
					#include "llama-impl.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <map>
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_vocab;
 | 
					struct llama_vocab;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -61,6 +62,18 @@ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
 | 
				
			|||||||
const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
 | 
					const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
 | 
				
			||||||
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 | 
					      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct VectorPointerHash {
 | 
				
			||||||
 | 
					    size_t operator()(const llama_grammar_stack & v) const {
 | 
				
			||||||
 | 
					        size_t seed = v.size();
 | 
				
			||||||
 | 
					        for (const auto* ptr : v) {
 | 
				
			||||||
 | 
					            seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return seed;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// takes a set of possible pushdown stacks on a grammar, which are required to
 | 
					// takes a set of possible pushdown stacks on a grammar, which are required to
 | 
				
			||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
 | 
					// be positioned at a character range (see `llama_grammar_advance_stack`), and
 | 
				
			||||||
// produces the N possible stacks if the given char is accepted at those
 | 
					// produces the N possible stacks if the given char is accepted at those
 | 
				
			||||||
@@ -69,7 +82,8 @@ void llama_grammar_accept(
 | 
				
			|||||||
        const llama_grammar_rules  & rules,
 | 
					        const llama_grammar_rules  & rules,
 | 
				
			||||||
        const llama_grammar_stacks & stacks,
 | 
					        const llama_grammar_stacks & stacks,
 | 
				
			||||||
                          uint32_t   chr,
 | 
					                          uint32_t   chr,
 | 
				
			||||||
              llama_grammar_stacks & stacks_new);
 | 
					              llama_grammar_stacks & stacks_new,
 | 
				
			||||||
 | 
					              llama_grammar_stacks_cache & stacks_cache);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
 | 
					std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
 | 
				
			||||||
        const llama_grammar_rules      & rules,
 | 
					        const llama_grammar_rules      & rules,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -35,10 +35,11 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
 | 
				
			|||||||
    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
 | 
					    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
 | 
				
			||||||
          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 | 
					          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llama_grammar_stacks_cache stacks_cache;
 | 
				
			||||||
    for (const auto & cpt : cpts) {
 | 
					    for (const auto & cpt : cpts) {
 | 
				
			||||||
        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 | 
					        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
 | 
					        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (stacks_cur.empty()) {
 | 
					        if (stacks_cur.empty()) {
 | 
				
			||||||
            // no stacks means that the grammar failed to match at this point
 | 
					            // no stacks means that the grammar failed to match at this point
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user