	tokenizer : special token handling (#3538)
* Rewrite special token handling from #1931
* shorten param name, add st verification by type
* use offsets instead of copy by substr
* formatting, remove copying iterator on delete
* llama : normalize code-style
* swift fix
* print pfx/sfx if verb, main: split pfx input sfx
* don't add space when using special tokens
* minor : comment + spacing

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
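At the API level, the change adds a 'special' flag to llama_tokenize (visible in the final hunk below): when it is true, substrings of the input that exactly match a special token are emitted as that single token id. A minimal usage sketch, assuming a loaded 'model' whose vocab defines "<s>" as a special token (the helper function name is hypothetical):

#include <string>
#include <vector>
#include "llama.h"

// Sketch only: assumes 'model' is already loaded and its vocab defines
// "<s>" as a special token.
void tokenize_with_specials(const struct llama_model * model) {
    const std::string prompt = "<s>Hello world";

    std::vector<llama_token> tokens(prompt.size() + 8); // generous upper bound
    const int n = llama_tokenize(model, prompt.c_str(), (int) prompt.size(),
                                 tokens.data(), (int) tokens.size(),
                                 /*add_bos=*/ false,
                                 /*special=*/ true); // "<s>" becomes one token id
    if (n >= 0) {
        tokens.resize(n);
    }
    // With special == false the same text is tokenized literally, i.e. "<s>"
    // is split into ordinary text tokens and no special ids are injected.
}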
1 changed file: llama.cpp (290 lines changed)
@@ -75,6 +75,7 @@
#include <thread>
#include <unordered_map>
#include <set>
#include <forward_list>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -1183,6 +1184,8 @@ struct llama_vocab {
    std::unordered_map<token, id> token_to_id;
    std::vector<token_data>       id_to_token;

    std::unordered_map<token, id> special_tokens_cache;

    std::map<std::pair<std::string, std::string>, int> bpe_ranks;

    // default LLaMA special tokens
@@ -2125,7 +2128,7 @@ static void llm_load_hparams(
}

// TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);

static void llm_load_vocab(
@@ -2241,6 +2244,101 @@ static void llm_load_vocab(
    GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
    GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
    GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));

    // build special tokens cache
    {
        // TODO: It is unclear (to me) at this point whether special tokens are guaranteed to be of a deterministic type,
        //  and will always be correctly labeled in 'added_tokens.json' etc.
        // The assumption is that since special tokens aren't meant to be exposed to the end user, they are designed
        //  to be unmatchable by the tokenizer; therefore tokens from the vocab that are unmatchable by the tokenizer
        //  are special tokens.
        // From testing, this appears to correlate 1:1 with special tokens.
        //

        // Counting special tokens and verifying in only one direction
        //  is sufficient to detect a difference between those two sets.
        //
        uint32_t special_tokens_count_by_type = 0;
        uint32_t special_tokens_count_from_verification = 0;

        bool special_tokens_definition_mismatch = false;

        for (const auto & t : vocab.token_to_id) {
            const auto & token = t.first;
            const auto & id    = t.second;

            // Count all non-normal tokens in the vocab while iterating
            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
                special_tokens_count_by_type++;
            }

            // Skip single-character tokens
            if (token.length() > 1) {
                bool is_tokenizable = false;

                // Split token string representation in two, in all possible ways
                //  and check if both halves can be matched to a valid token
                for (unsigned i = 1; i < token.length();) {
                    const auto left  = token.substr(0, i);
                    const auto right = token.substr(i);

                    // check that we didn't partition in the middle of a UTF-8 sequence
                    auto utf = utf8_len(left.at(left.length() - 1));

                    if (utf == 1) {
                        if (vocab.token_to_id.find(left)  != vocab.token_to_id.end() &&
                            vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
                            is_tokenizable = true;
                            break;
                        }
                        i++;
                    } else {
                        // skip over the rest of the multibyte UTF-8 sequence
                        i += utf - 1;
                    }
                }

                if (!is_tokenizable) {
                    // Some tokens are multibyte, but their UTF-8 text length is 1;
                    //  it's faster to re-filter them here, since there are far fewer candidates now

                    // Calculate the total UTF-8 length of the token's string representation
                    size_t utf8_str_len = 0;
                    for (unsigned i = 0; i < token.length();) {
                        utf8_str_len++;
                        i += utf8_len(token.at(i));
                    }

                    // And skip the ones which are one character
                    if (utf8_str_len > 1) {
                        // At this point what we have left are special tokens only
                        vocab.special_tokens_cache[token] = id;

                        // Count manually found special tokens
                        special_tokens_count_from_verification++;

                        // If this manually found special token is not marked as such, flag a mismatch
                        if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
                            special_tokens_definition_mismatch = true;
                        }
                    }
                }
            }
        }

        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
            fprintf(stderr, "%s: warning: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
                __func__,
                special_tokens_count_from_verification, vocab.id_to_token.size(),
                special_tokens_count_by_type, vocab.id_to_token.size()
            );
        } else {
            fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
                __func__,
                special_tokens_count_from_verification, vocab.id_to_token.size()
            );
        }
    }
}

static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
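The detection heuristic above can be illustrated in isolation: a multi-character vocab entry is treated as special exactly when no two-way split yields halves that are both vocab entries. A simplified, self-contained sketch under that assumption (the toy vocab is hypothetical; the real code also walks UTF-8 boundaries and re-filters multibyte tokens whose text length is 1):

#include <iostream>
#include <string>
#include <unordered_set>

// Returns true when 'token' cannot be split into two substrings that are
// both present in 'vocab' -- the commit's working definition of a special
// token (single-character tokens are skipped by the real code's caller).
static bool is_unsplittable(const std::string & token,
                            const std::unordered_set<std::string> & vocab) {
    for (size_t i = 1; i < token.size(); ++i) {
        if (vocab.count(token.substr(0, i)) && vocab.count(token.substr(i))) {
            return false; // both halves match -> reachable by the tokenizer
        }
    }
    return true;
}

int main() {
    // hypothetical toy vocab
    const std::unordered_set<std::string> vocab = {"He", "llo", "Hello", "<", "s", ">"};

    std::cout << is_unsplittable("Hello", vocab) << "\n"; // 0: "He" + "llo" both match
    std::cout << is_unsplittable("<s>",   vocab) << "\n"; // 1: no two-way split matches
}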
@@ -6464,7 +6562,137 @@ private:
    llm_bigram_bpe::queue work_queue;
};

static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
} FRAGMENT_BUFFER_VARIANT_TYPE;

struct fragment_buffer_variant{
    fragment_buffer_variant(llama_vocab::id _token)
    :
        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
        token(_token),
        raw_text(_dummy),
        offset(0),
        length(0){}
    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
    :
        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
        token((llama_vocab::id)-1),
        raw_text(_raw_text),
        offset(_offset),
        length(_length){
            GGML_ASSERT( _offset >= 0 );
            GGML_ASSERT( _length >= 1 );
            GGML_ASSERT( offset + length <= raw_text.length() );
        }

    const FRAGMENT_BUFFER_VARIANT_TYPE type;
    const llama_vocab::id token;
    const std::string _dummy;
    const std::string & raw_text;
    const uint64_t offset;
    const uint64_t length;
};

// #define PRETOKENIZERDEBUG

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
{
    // for each special token
    for (const auto & st: vocab.special_tokens_cache) {
        const auto & special_token = st.first;
        const auto & special_id    = st.second;

        // for each text fragment
        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
        while (it != buffer.end()) {
            auto & fragment = (*it);

            // if a fragment is text ( not yet processed )
            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                auto * raw_text = &(fragment.raw_text);

                auto raw_text_base_offset = fragment.offset;
                auto raw_text_base_length = fragment.length;

                // loop over the text
                while (true) {
                    // find the first occurrence of a given special token in this fragment
                    //  passing the offset argument only limits the "search area", but match coordinates
                    //  are still relative to the source full raw_text
                    auto match = raw_text->find(special_token, raw_text_base_offset);

                    // no occurrences found, stop processing this fragment for a given special token
                    if (match == std::string::npos) break;

                    // check if match is within bounds of offset <-> length
                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;

#ifdef PRETOKENIZERDEBUG
                    fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
#endif
                    auto source = std::distance(buffer.begin(), it);

                    // if the match is further than the base offset
                    //  then we have some text to the left of it
                    if (match > raw_text_base_offset) {
                        // left
                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
                        const int64_t left_reminder_length = match - raw_text_base_offset;
                        buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);

#ifdef PRETOKENIZERDEBUG
                        fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
#endif
                        it++;
                    }

                    // special token
                    buffer.emplace_after(it, special_id);
                    it++;

                    // right
                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
                        const int64_t right_reminder_offset = match + special_token.length();
                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
                        buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);

#ifdef PRETOKENIZERDEBUG
                        fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
#endif

                        it++;

                        if (source == 0) {
                            buffer.erase_after(buffer.before_begin());
                        } else {
                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
                        }

                        // repeat for the right side
                        raw_text_base_offset = right_reminder_offset;
                        raw_text_base_length = right_reminder_length;

#ifdef PRETOKENIZERDEBUG
                        fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
#endif
                    } else {
                        if (source == 0) {
                            buffer.erase_after(buffer.before_begin());
                        } else {
                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
                        }
                        break;
                    }
                }
            }
            it++;
        }
    }
}

static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
    std::vector<llama_vocab::id> output;

    // OG tokenizer behavior:
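tokenizer_st_partition is easiest to follow on a concrete input. Below is a condensed sketch of the same splitting idea, simplified to a single special token and a std::vector of owning fragments instead of the offset-based std::forward_list used above (the 'partition' and 'fragment' names are illustrative, not part of the commit):

#include <iostream>
#include <string>
#include <vector>

struct fragment { bool is_token; int token_id; std::string text; };

// Split 'text' around every occurrence of 'special', emitting the special
// token's id in place of the matched text -- the core idea of
// tokenizer_st_partition, without the offset bookkeeping.
static std::vector<fragment> partition(const std::string & text,
                                       const std::string & special, int special_id) {
    std::vector<fragment> out;
    size_t pos = 0;
    while (pos < text.size()) {
        const size_t match = text.find(special, pos);
        if (match == std::string::npos) {
            out.push_back({false, -1, text.substr(pos)});            // trailing text only
            break;
        }
        if (match > pos) {
            out.push_back({false, -1, text.substr(pos, match - pos)}); // left remainder
        }
        out.push_back({true, special_id, special});                    // the special token
        pos = match + special.size();                                  // continue on the right
    }
    return out;
}

int main() {
    // hypothetical special token "<sep>" with id 42
    for (const auto & f : partition("foo<sep>bar<sep>", "<sep>", 42)) {
        if (f.is_token) std::cout << "[token " << f.token_id << "]\n";
        else            std::cout << "[text '" << f.text << "']\n";
    }
    // prints: [text 'foo'] [token 42] [text 'bar'] [token 42]
}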
@@ -6480,20 +6708,58 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
        return output;
    }

    std::forward_list<fragment_buffer_variant> fragment_buffer;
    fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );

    if (special) tokenizer_st_partition( vocab, fragment_buffer );

    switch (vocab.type) {
        case LLAMA_VOCAB_TYPE_SPM:
            {
                // without adding this leading whitespace, we do not get the same results as the original tokenizer
                raw_text = " " + raw_text;
                for (const auto & fragment: fragment_buffer)
                {
                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
                    {
                        // without adding this leading whitespace, we do not get the same results as the original tokenizer

                llm_tokenizer_spm tokenizer(vocab);
                llama_escape_whitespace(raw_text);
                tokenizer.tokenize(raw_text, output);
                        // TODO: It's likely possible to get rid of this string copy entirely
                        //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
                        //  and passing 'add space prefix' as bool argument
                        //
                        auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);

#ifdef PRETOKENIZERDEBUG
                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
                        llm_tokenizer_spm tokenizer(vocab);
                        llama_escape_whitespace(raw_text);
                        tokenizer.tokenize(raw_text, output);
                    }
                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                    {
                        output.push_back(fragment.token);
                    }
                }
            } break;
        case LLAMA_VOCAB_TYPE_BPE:
            {
                llm_tokenizer_bpe tokenizer(vocab);
                tokenizer.tokenize(raw_text, output);
                for (const auto & fragment: fragment_buffer)
                {
                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
                    {
                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);

#ifdef PRETOKENIZERDEBUG
                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
                        llm_tokenizer_bpe tokenizer(vocab);
                        tokenizer.tokenize(raw_text, output);
                    }
                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                    {
                        output.push_back(fragment.token);
                    }
                }
            } break;
    }

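Restating the net effect of this hunk: in the new SPM branch each fragment is either raw text, tokenized as before but with the leading-space heuristic applied only when 'special' handling is off, or an already-resolved token id appended directly. A condensed paraphrase (not a verbatim copy of the committed code):

// Condensed paraphrase of the new SPM case (identifiers as in the diff);
// this runs inside llama_tokenize_internal, not as standalone code.
for (const auto & fragment : fragment_buffer) {
    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
        // leading space only when special-token handling is off
        auto raw_text = (special ? "" : " ")
                      + fragment.raw_text.substr(fragment.offset, fragment.length);

        llm_tokenizer_spm tokenizer(vocab);
        llama_escape_whitespace(raw_text);
        tokenizer.tokenize(raw_text, output);
    } else {
        // fragment already holds a resolved special token id
        output.push_back(fragment.token);
    }
}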
@@ -9407,15 +9673,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) {
    return ctx->model.vocab.special_eot_id;
}


int llama_tokenize(
    const struct llama_model * model,
                  const char * text,
                         int   text_len,
                 llama_token * tokens,
                         int   n_max_tokens,
                        bool   add_bos) {
    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
                        bool   add_bos,
                        bool   special) {
    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);

    if (n_max_tokens < (int) res.size()) {
        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
Author: staviq