mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : adds llama-grammar memorization stacks (#4218)
This commit is contained in:
		| @@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char( | ||||
|     return !is_positive_char; | ||||
| } | ||||
|  | ||||
| // transforms a grammar pushdown stack into N possible stacks, all ending | ||||
| // at a character range (terminal element) | ||||
| // additionally memorizes the stack to its possible stacks by mapping | ||||
| // < llama_grammar_stack, llama_grammar_stacks > | ||||
|  | ||||
| struct VectorPointerHash { | ||||
|     size_t operator()(const llama_grammar_stack & v) const { | ||||
|         size_t seed = v.size(); | ||||
|         for (const auto* ptr : v) { | ||||
|             seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); | ||||
|         } | ||||
|         return seed; | ||||
|     } | ||||
| }; | ||||
|  | ||||
// Memoization cache mapping a grammar stack to the set of advanced
// (terminal-topped) stacks it expands to, so repeated expansions of the
// same stack are computed only once.
// NOTE(review): this is a file-scope mutable global shared by every
// llama_grammar instance and cleared in llama_grammar_init_impl — it is
// presumably not thread-safe if grammars are built/used concurrently;
// confirm, or consider making the cache per-grammar.
static std::unordered_map<
    llama_grammar_stack,
    llama_grammar_stacks,
    VectorPointerHash>
    llama_grammar_stacks_cache = {};

// Forward declaration: the memoizing entry point, mutually recursive with
// llama_grammar_advance_stack_memo_impl below.
static void llama_grammar_advance_stack_memo(
        const llama_grammar_rules  & rules,
        const llama_grammar_stack  & stack,
              llama_grammar_stacks & new_stacks);
|  | ||||
// Recursive worker for llama_grammar_advance_stack_memo: expands `stack`
// until its top element is a terminal (a character range), appending every
// distinct resulting stack to `new_stacks`. Rule references on top of the
// stack are replaced by each of their alternates in turn.
static void llama_grammar_advance_stack_memo_impl(
        const llama_grammar_rules  & rules,
        const llama_grammar_stack  & stack,
              llama_grammar_stacks & new_stacks) {
    if (stack.empty()) {
        // an empty stack is already fully advanced; record it once
        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
            new_stacks.emplace_back(stack);
        }
        return;
    }

    const llama_grammar_element * pos = stack.back();

    switch (pos->type) {
        case LLAMA_GRETYPE_RULE_REF: {
            // expand the referenced rule: one new stack per alternate
            const size_t                  rule_id = static_cast<size_t>(pos->value);
            const llama_grammar_element * subpos  = rules[rule_id].data();
            do {
                // init new stack without the top (pos)
                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
                    // if this rule ref is followed by another element, add that to stack
                    new_stack.push_back(pos + 1);
                }
                if (!llama_grammar_is_end_of_sequence(subpos)) {
                    // if alternate is nonempty, add to stack
                    new_stack.push_back(subpos);
                }
                // recurse via the memoizing entry point so intermediate
                // stacks also hit (and populate) the cache
                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
                while (!llama_grammar_is_end_of_sequence(subpos)) {
                    // scan to end of alternate def
                    subpos++;
                }
                if (subpos->type == LLAMA_GRETYPE_ALT) {
                    // there's another alternate def of this rule to process
                    subpos++;
                } else {
                    break;
                }
            } while (true);
            break;
        }
        case LLAMA_GRETYPE_CHAR:
        case LLAMA_GRETYPE_CHAR_NOT:
        case LLAMA_GRETYPE_CHAR_ANY:
            // top of stack is a terminal: this stack is fully advanced
            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                // only add the stack if it's not a duplicate of one we already have
                new_stacks.emplace_back(stack);
            }
            break;
        default:
            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
            // those
            GGML_ABORT("fatal error");
    }
}
|  | ||||
| static void llama_grammar_advance_stack_memo( | ||||
|         const llama_grammar_rules  & rules, | ||||
|         const llama_grammar_stack  & stack, | ||||
|               llama_grammar_stacks & new_stacks) { | ||||
|  | ||||
|     llama_grammar_stacks advanced_stacks; | ||||
|     // Look if stack is already in memory | ||||
|     auto it = llama_grammar_stacks_cache.find(stack); | ||||
|     if (it != llama_grammar_stacks_cache.end()) { | ||||
|            advanced_stacks = it->second; | ||||
|     } else { | ||||
|         // Advance stacks with memorization  | ||||
|         llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks); | ||||
|         llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks)); | ||||
|     } | ||||
|     // Add the advanced stacks to new_stacks avoiding duplicates | ||||
|     for (const auto & new_stack : advanced_stacks) { | ||||
|         if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) { | ||||
|             new_stacks.emplace_back(new_stack); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| } | ||||
|  | ||||
| // transforms a grammar pushdown stack into N possible stacks, all ending | ||||
| // at a character range (terminal element) | ||||
| static void llama_grammar_advance_stack( | ||||
| @@ -844,7 +952,7 @@ void llama_grammar_accept( | ||||
|             if (!llama_grammar_is_end_of_sequence(pos)) { | ||||
|                 new_stack.push_back(pos); | ||||
|             } | ||||
|             llama_grammar_advance_stack(rules, new_stack, stacks_new); | ||||
|             llama_grammar_advance_stack_memo(rules, new_stack, stacks_new); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl( | ||||
|         const llama_grammar_element ** rules, | ||||
|         size_t n_rules, | ||||
|         size_t start_rule_index) { | ||||
|     // Clear stacks cache | ||||
|     llama_grammar_stacks_cache.clear(); | ||||
|     const llama_grammar_element * pos; | ||||
|  | ||||
|     // copy rule definitions into vectors | ||||
| @@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl( | ||||
|             // if alternate is nonempty, add to stack | ||||
|             stack.push_back(pos); | ||||
|         } | ||||
|         llama_grammar_advance_stack(vec_rules, stack, stacks); | ||||
|         llama_grammar_advance_stack_memo(vec_rules, stack, stacks); | ||||
|         while (!llama_grammar_is_end_of_sequence(pos)) { | ||||
|             // scan to end of alternate def | ||||
|             pos++; | ||||
| @@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl( | ||||
| } | ||||
|  | ||||
| struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { | ||||
|     // Clear stacks cache | ||||
|     llama_grammar_stacks_cache.clear(); | ||||
|     llama_grammar_parser parser; | ||||
|  | ||||
|     // if there is a grammar, parse it | ||||
| @@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, | ||||
|             // if alternate is nonempty, add to stack | ||||
|             stack.push_back(pos); | ||||
|         } | ||||
|         llama_grammar_advance_stack(vec_rules, stack, stacks); | ||||
|         llama_grammar_advance_stack_memo(vec_rules, stack, stacks); | ||||
|         while (!llama_grammar_is_end_of_sequence(pos)) { | ||||
|             // scan to end of alternate def | ||||
|             pos++; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Clarissa Miranda
					Clarissa Miranda