move cache stack to advance stack

This commit is contained in:
Clarissa Miranda
2024-10-14 17:13:40 +11:00
parent cb1632b593
commit 901a3479b1
4 changed files with 39 additions and 36 deletions

View File

@@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
// additionally memorizes the stack to its possible stacks by mapping
// < llama_grammar_stack, llama_grammar_stacks >
struct VectorPointerHash {
size_t operator()(const llama_grammar_stack & v) const {
size_t seed = v.size();
for (const auto* ptr : v) {
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
static std::unordered_map<
llama_grammar_stack,
llama_grammar_stacks,
VectorPointerHash>
llama_grammar_stacks_cache = {};
static void llama_grammar_advance_stack_memo(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
llama_grammar_stacks & new_stacks);
llama_grammar_stacks & new_stacks,
llama_grammar_stacks_cache & stacks_cache);
static void llama_grammar_advance_stack_memo_impl(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
llama_grammar_stacks & new_stacks) {
llama_grammar_stacks & new_stacks,
llama_grammar_stacks_cache & stacks_cache) {
if (stack.empty()) {
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
new_stacks.emplace_back(stack);
@@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
// if alternate is nonempty, add to stack
new_stack.push_back(subpos);
}
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
while (!llama_grammar_is_end_of_sequence(subpos)) {
// scan to end of alternate def
subpos++;
@@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
static void llama_grammar_advance_stack_memo(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
llama_grammar_stacks & new_stacks) {
llama_grammar_stacks & new_stacks,
llama_grammar_stacks_cache & stacks_cache) {
llama_grammar_stacks advanced_stacks;
// Look if stack is already in memory
auto it = llama_grammar_stacks_cache.find(stack);
if (it != llama_grammar_stacks_cache.end()) {
auto it = stacks_cache.find(stack);
if (it != stacks_cache.end()) {
advanced_stacks = it->second;
} else {
// Advance stacks with memorization
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
stacks_cache.insert(make_pair(stack, advanced_stacks));
}
// Add the advanced stacks to new_stacks avoiding duplicates
for (const auto & new_stack : advanced_stacks) {
@@ -934,7 +921,8 @@ void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & stacks_new) {
llama_grammar_stacks & stacks_new,
llama_grammar_stacks_cache & stacks_cache) {
stacks_new.clear();
stacks_new.reserve(stacks.size());
@@ -952,7 +940,7 @@ void llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
}
}
}
@@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
// Clear stacks cache
llama_grammar_stacks_cache.clear();
const llama_grammar_element * pos;
// copy rule definitions into vectors
@@ -1048,6 +1034,7 @@ struct llama_grammar * llama_grammar_init_impl(
// loop over alternates of start rule to build initial stacks
llama_grammar_stacks stacks;
llama_grammar_stacks_cache stacks_cache;
pos = vec_rules[start_rule_index].data();
do {
llama_grammar_stack stack;
@@ -1055,7 +1042,7 @@ struct llama_grammar * llama_grammar_init_impl(
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
while (!llama_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
@@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
}
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
// Clear stacks cache
llama_grammar_stacks_cache.clear();
llama_grammar_parser parser;
// if there is a grammar, parse it
@@ -1128,6 +1113,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
// loop over alternates of start rule to build initial stacks
llama_grammar_stacks stacks;
llama_grammar_stacks_cache stacks_cache;
pos = vec_rules[start_rule_index].data();
do {
llama_grammar_stack stack;
@@ -1135,7 +1121,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
while (!llama_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
@@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new;
llama_grammar_stacks_cache stacks_cache;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
grammar.stacks = std::move(stacks_new);
}