mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : minor llama_grammar refactoring
ggml-ci
This commit is contained in:
		| @@ -11,20 +11,15 @@ | |||||||
| static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { | static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { | ||||||
|     const auto cpts = unicode_cpts_from_utf8(input_str); |     const auto cpts = unicode_cpts_from_utf8(input_str); | ||||||
|  |  | ||||||
|     const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar); |     auto & stacks_cur = llama_grammar_get_stacks(grammar); | ||||||
|           llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); |  | ||||||
|           llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); |  | ||||||
|  |  | ||||||
|     size_t pos = 0; |     size_t pos = 0; | ||||||
|     for (const auto & cpt : cpts) { |     for (const auto & cpt : cpts) { | ||||||
|         const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy |         llama_grammar_accept(grammar, cpt); | ||||||
|  |  | ||||||
|         llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); |  | ||||||
|  |  | ||||||
|         if (stacks_cur.empty()) { |         if (stacks_cur.empty()) { | ||||||
|             error_pos = pos; |             error_pos = pos; | ||||||
|             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; |             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; | ||||||
|             stacks_cur = stacks_prev; |  | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|         ++pos; |         ++pos; | ||||||
| @@ -83,7 +78,8 @@ int main(int argc, char** argv) { | |||||||
|  |  | ||||||
|     llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); |     llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); | ||||||
|     if (grammar == nullptr) { |     if (grammar == nullptr) { | ||||||
|         throw std::runtime_error("Failed to initialize llama_grammar"); |         fprintf(stdout, "Failed to initialize llama_grammar\n"); | ||||||
|  |         return 1; | ||||||
|     } |     } | ||||||
|     // Read the input file |     // Read the input file | ||||||
|     std::string input_str; |     std::string input_str; | ||||||
|   | |||||||
| @@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) | |||||||
|     return grammar->stacks; |     return grammar->stacks; | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { | void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { | ||||||
|     return grammar->stacks_cache; |     llama_grammar_stacks stacks_new; | ||||||
| } |     stacks_new.reserve(grammar->stacks.size()); | ||||||
|  |  | ||||||
| void llama_grammar_accept( |     for (const auto & stack : grammar->stacks) { | ||||||
|         const llama_grammar_rules  & rules, |  | ||||||
|         const llama_grammar_stacks & stacks, |  | ||||||
|         const uint32_t               chr, |  | ||||||
|               llama_grammar_stacks & stacks_new, |  | ||||||
|               llama_grammar_stacks_cache & stacks_cache) { |  | ||||||
|     stacks_new.clear(); |  | ||||||
|     stacks_new.reserve(stacks.size()); |  | ||||||
|  |  | ||||||
|     for (const auto & stack : stacks) { |  | ||||||
|         if (stack.empty()) { |         if (stack.empty()) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
| @@ -944,9 +935,11 @@ void llama_grammar_accept( | |||||||
|             if (!llama_grammar_is_end_of_sequence(pos)) { |             if (!llama_grammar_is_end_of_sequence(pos)) { | ||||||
|                 new_stack.push_back(pos); |                 new_stack.push_back(pos); | ||||||
|             } |             } | ||||||
|             llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); |             llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     grammar->stacks = std::move(stacks_new); | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_grammar_candidates llama_grammar_reject_candidates_for_stack( | llama_grammar_candidates llama_grammar_reject_candidates_for_stack( | ||||||
| @@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl( | |||||||
|     // Important: vec_rules has to be moved here, not copied, because stacks contains |     // Important: vec_rules has to be moved here, not copied, because stacks contains | ||||||
|     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar |     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar | ||||||
|     // then the pointers would be invalidated when the local vec_rules goes out of scope. |     // then the pointers would be invalidated when the local vec_rules goes out of scope. | ||||||
|     return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; |     return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; | ||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { | struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { | ||||||
| @@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, | |||||||
|     // Important: vec_rules has to be moved here, not copied, because stacks contains |     // Important: vec_rules has to be moved here, not copied, because stacks contains | ||||||
|     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar |     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar | ||||||
|     // then the pointers would be invalidated when the local vec_rules goes out of scope. |     // then the pointers would be invalidated when the local vec_rules goes out of scope. | ||||||
|     return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; |     return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_grammar_free_impl(struct llama_grammar * grammar) { | void llama_grammar_free_impl(struct llama_grammar * grammar) { | ||||||
| @@ -1153,7 +1146,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { | |||||||
| } | } | ||||||
|  |  | ||||||
| struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { | struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { | ||||||
|     llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; |     llama_grammar * result = new llama_grammar { | ||||||
|  |         grammar.vocab, | ||||||
|  |         grammar.rules, | ||||||
|  |         grammar.stacks, | ||||||
|  |         grammar.stacks_cache, | ||||||
|  |         grammar.partial_utf8, | ||||||
|  |     }; | ||||||
|  |  | ||||||
|     // redirect elements in stacks to point to new rules |     // redirect elements in stacks to point to new rules | ||||||
|     for (size_t is = 0; is < result->stacks.size(); is++) { |     for (size_t is = 0; is < result->stacks.size(); is++) { | ||||||
| @@ -1161,7 +1160,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra | |||||||
|             for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { |             for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { | ||||||
|                 for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { |                 for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { | ||||||
|                     if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { |                     if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { | ||||||
|                          result->stacks[is][ie]  =  &result->rules[ir0][ir1]; |                         result->stacks[is][ie] =  &result->rules[ir0][ir1]; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token | |||||||
|     const auto   decoded     = decode_utf8(piece, grammar.partial_utf8); |     const auto   decoded     = decode_utf8(piece, grammar.partial_utf8); | ||||||
|     const auto & code_points = decoded.first; |     const auto & code_points = decoded.first; | ||||||
|  |  | ||||||
|     llama_grammar_stacks stacks_new; |  | ||||||
|  |  | ||||||
|     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { |     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | ||||||
|         llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); |         llama_grammar_accept(&grammar, *it); | ||||||
|         grammar.stacks = std::move(stacks_new); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     grammar.partial_utf8 = decoded.second; |     grammar.partial_utf8 = decoded.second; | ||||||
|   | |||||||
| @@ -71,20 +71,15 @@ struct VectorPointerHash { | |||||||
|  |  | ||||||
| using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>; | using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>; | ||||||
|  |  | ||||||
|  | // TODO: remove, needed for tests atm | ||||||
| const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar); | const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar); | ||||||
|       llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar); |       llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar); | ||||||
|       llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(      struct llama_grammar * grammar); |  | ||||||
|  |  | ||||||
| // takes a set of possible pushdown stacks on a grammar, which are required to | // takes a set of possible pushdown stacks on a grammar, which are required to | ||||||
| // be positioned at a character range (see `llama_grammar_advance_stack`), and | // be positioned at a character range (see `llama_grammar_advance_stack`), and | ||||||
| // produces the N possible stacks if the given char is accepted at those | // produces the N possible stacks if the given char is accepted at those | ||||||
| // positions | // positions | ||||||
| void llama_grammar_accept( | void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); | ||||||
|         const llama_grammar_rules  & rules, |  | ||||||
|         const llama_grammar_stacks & stacks, |  | ||||||
|                           uint32_t   chr, |  | ||||||
|               llama_grammar_stacks & stacks_new, |  | ||||||
|               llama_grammar_stacks_cache & stacks_cache); |  | ||||||
|  |  | ||||||
| std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( | std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( | ||||||
|         const llama_grammar_rules      & rules, |         const llama_grammar_rules      & rules, | ||||||
| @@ -128,10 +123,11 @@ struct llama_grammar { | |||||||
|     const llama_grammar_rules  rules;  // TODO: shared ptr |     const llama_grammar_rules  rules;  // TODO: shared ptr | ||||||
|           llama_grammar_stacks stacks; |           llama_grammar_stacks stacks; | ||||||
|  |  | ||||||
|     // buffer for partially generated UTF-8 sequence from accepted tokens |  | ||||||
|     llama_partial_utf8 partial_utf8; |  | ||||||
|     // cache N possible stacks from a stack |     // cache N possible stacks from a stack | ||||||
|     llama_grammar_stacks_cache stacks_cache; |     llama_grammar_stacks_cache stacks_cache; | ||||||
|  |  | ||||||
|  |     // buffer for partially generated UTF-8 sequence from accepted tokens | ||||||
|  |     llama_partial_utf8 partial_utf8; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // | // | ||||||
|   | |||||||
| @@ -32,14 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { | |||||||
| static bool match_string(const std::string & input, llama_grammar * grammar) { | static bool match_string(const std::string & input, llama_grammar * grammar) { | ||||||
|     const auto cpts = unicode_cpts_from_utf8(input); |     const auto cpts = unicode_cpts_from_utf8(input); | ||||||
|  |  | ||||||
|     const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar); |     auto & stacks_cur = llama_grammar_get_stacks(grammar); | ||||||
|           llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); |  | ||||||
|           llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); |  | ||||||
|  |  | ||||||
|     for (const auto & cpt : cpts) { |     for (const auto & cpt : cpts) { | ||||||
|         const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy |         llama_grammar_accept(grammar, cpt); | ||||||
|  |  | ||||||
|         llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); |  | ||||||
|  |  | ||||||
|         if (stacks_cur.empty()) { |         if (stacks_cur.empty()) { | ||||||
|             // no stacks means that the grammar failed to match at this point |             // no stacks means that the grammar failed to match at this point | ||||||
| @@ -64,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, | |||||||
|     auto * grammar = build_grammar(grammar_str); |     auto * grammar = build_grammar(grammar_str); | ||||||
|  |  | ||||||
|     // Save the original grammar stacks so that we can reset after every new string we want to test |     // Save the original grammar stacks so that we can reset after every new string we want to test | ||||||
|     const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); |     const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy | ||||||
|  |  | ||||||
|     llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); |     llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -113,12 +113,10 @@ int main() | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_grammar * grammar = NULL; |  | ||||||
|     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); |     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); | ||||||
|  |  | ||||||
|     grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); |     llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); | ||||||
|     if (grammar == nullptr) |     if (grammar == nullptr) { | ||||||
|     { |  | ||||||
|         throw std::runtime_error("Failed to initialize llama_grammar"); |         throw std::runtime_error("Failed to initialize llama_grammar"); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov