mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-28 08:31:25 +00:00 
			
		
		
		
	* Fix unicode in grammars (fixes #2501)
	* add more comments
	* fix test-llama-grammar
This commit is contained in:
		
							
								
								
									
										161
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										161
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -2077,37 +2077,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co | |||||||
| // grammar - internal | // grammar - internal | ||||||
| // | // | ||||||
|  |  | ||||||
|  | struct llama_partial_utf8 { | ||||||
|  |     uint32_t value;    // bit value so far (unshifted) | ||||||
|  |     int      n_remain; // num bytes remaining; -1 indicates invalid sequence | ||||||
|  | }; | ||||||
|  |  | ||||||
| struct llama_grammar { | struct llama_grammar { | ||||||
|     const std::vector<std::vector<llama_grammar_element>>   rules; |     const std::vector<std::vector<llama_grammar_element>>   rules; | ||||||
|     std::vector<std::vector<const llama_grammar_element *>> stacks; |     std::vector<std::vector<const llama_grammar_element *>> stacks; | ||||||
|  |  | ||||||
|  |     // buffer for partially generated UTF-8 sequence from accepted tokens | ||||||
|  |     llama_partial_utf8                                      partial_utf8; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct llama_grammar_candidate { | struct llama_grammar_candidate { | ||||||
|     size_t           index; |     size_t               index; | ||||||
|     const uint32_t * code_points; |     const uint32_t     * code_points; | ||||||
|  |     llama_partial_utf8   partial_utf8; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // NOTE: assumes valid utf8 (but checks for overrun) | // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as | ||||||
| // adds a terminating 0 for use as pointer | // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. | ||||||
| std::vector<uint32_t> decode_utf8(const char * src) { | std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8( | ||||||
|     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; |         const char         * src, | ||||||
|  |         llama_partial_utf8   partial_start) { | ||||||
|  |     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; | ||||||
|     const char          * pos      = src; |     const char          * pos      = src; | ||||||
|     std::vector<uint32_t> code_points; |     std::vector<uint32_t> code_points; | ||||||
|  |     uint32_t              value    = partial_start.value; | ||||||
|  |     int                   n_remain = partial_start.n_remain; | ||||||
|  |  | ||||||
|  |     // continue previous decode, if applicable | ||||||
|  |     while (*pos != 0 && n_remain > 0) { | ||||||
|  |         uint8_t next_byte = static_cast<uint8_t>(*pos); | ||||||
|  |         if ((next_byte >> 6) != 2) { | ||||||
|  |             // invalid sequence, abort | ||||||
|  |             code_points.push_back(0); | ||||||
|  |             return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); | ||||||
|  |         } | ||||||
|  |         value = (value << 6) + (next_byte & 0x3F); | ||||||
|  |         ++pos; | ||||||
|  |         --n_remain; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (partial_start.n_remain > 0 && n_remain == 0) { | ||||||
|  |         code_points.push_back(value); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // decode any subsequent utf-8 sequences, which may end in an incomplete one | ||||||
|     while (*pos != 0) { |     while (*pos != 0) { | ||||||
|         uint8_t  first_byte = static_cast<uint8_t>(*pos); |         uint8_t  first_byte = static_cast<uint8_t>(*pos); | ||||||
|         uint8_t  highbits   = first_byte >> 4; |         uint8_t  highbits   = first_byte >> 4; | ||||||
|         int      len        = lookup[highbits]; |                  n_remain   = lookup[highbits] - 1; | ||||||
|         uint8_t  mask       = (1 << (8 - len)) - 1; |  | ||||||
|         uint32_t value      = first_byte & mask; |         if (n_remain < 0) { | ||||||
|         const char * end    = pos + len; // may overrun! |             // invalid sequence, abort | ||||||
|         ++pos; |             code_points.clear(); | ||||||
|         for ( ; pos < end && *pos != 0; ++pos) { |             code_points.push_back(0); | ||||||
|             value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); |             return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         uint8_t  mask       = (1 << (7 - n_remain)) - 1; | ||||||
|  |                  value      = first_byte & mask; | ||||||
|  |         ++pos; | ||||||
|  |         while (*pos != 0 && n_remain > 0) { | ||||||
|  |             value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); | ||||||
|  |             ++pos; | ||||||
|  |             --n_remain; | ||||||
|  |         } | ||||||
|  |         if (n_remain == 0) { | ||||||
|  |             code_points.push_back(value); | ||||||
|         } |         } | ||||||
|         code_points.push_back(value); |  | ||||||
|     } |     } | ||||||
|     code_points.push_back(0); |     code_points.push_back(0); | ||||||
|     return code_points; |  | ||||||
|  |     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); | ||||||
| } | } | ||||||
|  |  | ||||||
| // returns true iff pos points to the end of one of the definitions of a rule | // returns true iff pos points to the end of one of the definitions of a rule | ||||||
| @@ -2144,6 +2188,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char( | |||||||
|     return std::make_pair(found == is_positive_char, pos); |     return std::make_pair(found == is_positive_char, pos); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char | ||||||
|  | // range at pos (regular or inverse range) | ||||||
|  | // asserts that pos is pointing to a char range element | ||||||
|  | static bool llama_grammar_match_partial_char( | ||||||
|  |         const llama_grammar_element * pos, | ||||||
|  |         const llama_partial_utf8      partial_utf8) { | ||||||
|  |  | ||||||
|  |     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; | ||||||
|  |     LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); | ||||||
|  |  | ||||||
|  |     uint32_t partial_value = partial_utf8.value; | ||||||
|  |     int      n_remain      = partial_utf8.n_remain; | ||||||
|  |  | ||||||
|  |     // invalid sequence or 7-bit char split across 2 bytes (overlong) | ||||||
|  |     if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // range of possible code points this partial UTF-8 sequence could complete to | ||||||
|  |     uint32_t low  = partial_value << (n_remain * 6); | ||||||
|  |     uint32_t high = low | ((1 << (n_remain * 6)) - 1); | ||||||
|  |  | ||||||
|  |     if (low == 0) { | ||||||
|  |         if (n_remain == 2) { | ||||||
|  |             low = 1 << 11; | ||||||
|  |         } else if (n_remain == 3) { | ||||||
|  |             low = 1 << 16; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     do { | ||||||
|  |         if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { | ||||||
|  |             // inclusive range, e.g. [a-z] | ||||||
|  |             if (pos->value <= high && low <= pos[1].value) { | ||||||
|  |                 return is_positive_char; | ||||||
|  |             } | ||||||
|  |             pos += 2; | ||||||
|  |         } else { | ||||||
|  |             // exact char match, e.g. [a] or "a" | ||||||
|  |             if (low <= pos->value && pos->value <= high) { | ||||||
|  |                 return is_positive_char; | ||||||
|  |             } | ||||||
|  |             pos += 1; | ||||||
|  |         } | ||||||
|  |     } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); | ||||||
|  |  | ||||||
|  |     return !is_positive_char; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| // transforms a grammar pushdown stack into N possible stacks, all ending | // transforms a grammar pushdown stack into N possible stacks, all ending | ||||||
| // at a character range (terminal element) | // at a character range (terminal element) | ||||||
| static void llama_grammar_advance_stack( | static void llama_grammar_advance_stack( | ||||||
| @@ -2244,8 +2338,11 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_ | |||||||
|     std::vector<llama_grammar_candidate> rejects; |     std::vector<llama_grammar_candidate> rejects; | ||||||
|  |  | ||||||
|     if (stack.empty()) { |     if (stack.empty()) { | ||||||
|         // accept nothing; EOS is handled elsewhere |         for (auto tok : candidates) { | ||||||
|         rejects.insert(rejects.end(), candidates.begin(), candidates.end()); |             if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { | ||||||
|  |                 rejects.push_back(tok); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|         return rejects; |         return rejects; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -2253,10 +2350,15 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_ | |||||||
|  |  | ||||||
|     std::vector<llama_grammar_candidate> next_candidates; |     std::vector<llama_grammar_candidate> next_candidates; | ||||||
|     for (auto tok : candidates) { |     for (auto tok : candidates) { | ||||||
|         if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { |         if (*tok.code_points == 0) { | ||||||
|             if (tok.code_points[1] != 0) { |             // reached end of full codepoints in token, reject iff it ended in a partial sequence | ||||||
|                 next_candidates.push_back({ tok.index, tok.code_points + 1 }); |             // that cannot satisfy this position in grammar | ||||||
|  |             if (tok.partial_utf8.n_remain != 0 && | ||||||
|  |                     !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { | ||||||
|  |                 rejects.push_back(tok); | ||||||
|             } |             } | ||||||
|  |         } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { | ||||||
|  |             next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); | ||||||
|         } else { |         } else { | ||||||
|             rejects.push_back(tok); |             rejects.push_back(tok); | ||||||
|         } |         } | ||||||
| @@ -2274,7 +2376,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_ | |||||||
|  |  | ||||||
|     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); |     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); | ||||||
|     for (auto tok : next_rejects) { |     for (auto tok : next_rejects) { | ||||||
|         rejects.push_back({ tok.index, tok.code_points - 1 }); |         rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     return rejects; |     return rejects; | ||||||
| @@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init( | |||||||
|         } |         } | ||||||
|     } while (true); |     } while (true); | ||||||
|  |  | ||||||
|     return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; |     return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_grammar_free(struct llama_grammar * grammar) { | void llama_grammar_free(struct llama_grammar * grammar) { | ||||||
| @@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c | |||||||
|  |  | ||||||
|     const llama_token eos = llama_token_eos(); |     const llama_token eos = llama_token_eos(); | ||||||
|  |  | ||||||
|     std::vector<std::vector<uint32_t>>   candidates_decoded; |     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; | ||||||
|     std::vector<llama_grammar_candidate> candidates_grammar; |     std::vector<llama_grammar_candidate>                              candidates_grammar; | ||||||
|  |  | ||||||
|     for (size_t i = 0; i < candidates->size; ++i) { |     for (size_t i = 0; i < candidates->size; ++i) { | ||||||
|         const llama_token id  = candidates->data[i].id; |         const llama_token id  = candidates->data[i].id; | ||||||
| @@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c | |||||||
|         } else if (*str == 0) { |         } else if (*str == 0) { | ||||||
|             candidates->data[i].logit = -INFINITY; |             candidates->data[i].logit = -INFINITY; | ||||||
|         } else { |         } else { | ||||||
|             candidates_decoded.push_back(decode_utf8(str)); |             candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8)); | ||||||
|             candidates_grammar.push_back({ i, candidates_decoded.back().data() }); |             candidates_grammar.push_back({ | ||||||
|  |                 i, candidates_decoded.back().first.data(), candidates_decoded.back().second | ||||||
|  |             }); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     const char * str = llama_token_to_str(ctx, token); |     const char * str = llama_token_to_str(ctx, token); | ||||||
|  |  | ||||||
|     // Note terminating 0 in decoded string |     // Note terminating 0 in decoded string | ||||||
|     auto code_points = decode_utf8(str); |     const auto   decoded     = decode_utf8(str, grammar->partial_utf8); | ||||||
|  |     const auto & code_points = decoded.first; | ||||||
|     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { |     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | ||||||
|         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); |         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); | ||||||
|     } |     } | ||||||
|  |     grammar->partial_utf8 = decoded.second; | ||||||
|     LLAMA_ASSERT(!grammar->stacks.empty()); |     LLAMA_ASSERT(!grammar->stacks.empty()); | ||||||
|  |  | ||||||
|     ctx->t_sample_us += ggml_time_us() - t_start_sample_us; |     ctx->t_sample_us += ggml_time_us() - t_start_sample_us; | ||||||
|   | |||||||
| @@ -199,7 +199,7 @@ int main() | |||||||
|         uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point |         uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point | ||||||
|         cp[0] = 37 + i; |         cp[0] = 37 + i; | ||||||
|         cp[1] = 0; |         cp[1] = 0; | ||||||
|         next_candidates[i] = {i, cp}; |         next_candidates[i] = {i, cp, {}}; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = { |     std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Evan Jones
					Evan Jones