mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : improve BERT tokenization (#5740)
* implement nfd for stripping accents in wpm tokenizer * sort nfd map; reuse iterator * use builtin tolower * add locale include * Simplify to_lower cases Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
This commit is contained in:
		
							
								
								
									
										137
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										137
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -68,10 +68,12 @@ | ||||
| #include <cstdio> | ||||
| #include <cstring> | ||||
| #include <ctime> | ||||
| #include <cwctype> | ||||
| #include <forward_list> | ||||
| #include <fstream> | ||||
| #include <functional> | ||||
| #include <initializer_list> | ||||
| #include <locale> | ||||
| #include <map> | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
| @@ -8941,37 +8943,46 @@ struct llm_tokenizer_wpm { | ||||
|     } | ||||
|  | ||||
|     std::vector<std::string> preprocess(const std::string & text) { | ||||
|         std::string ori_str = normalize(text); | ||||
|         uint64_t ori_size = ori_str.size(); | ||||
|         // normalalization form D | ||||
|         std::vector<uint32_t> codepoints = codepoints_from_utf8(text); | ||||
|         std::vector<uint32_t> nfd_codepoints; | ||||
|         for (uint32_t code : codepoints) { | ||||
|             auto it = nfd_map.find(code); | ||||
|             if (it != nfd_map.end()) { | ||||
|                 for (uint32_t c : it->second) { | ||||
|                     nfd_codepoints.push_back(c); | ||||
|                 } | ||||
|             } else { | ||||
|                 nfd_codepoints.push_back(code); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // single punct / single symbol / single digit | ||||
|         // baseline: add whitespace on the left and right of punct and chinese characters | ||||
|         std::vector<std::string> words; | ||||
|         // strip accents, strip control, uniformize whitespace, | ||||
|         // to lowercase, pad chinese characters, pad punctuation | ||||
|         std::string new_str = ""; | ||||
|         uint64_t i = 0; | ||||
|         while (i < ori_size) { | ||||
|             int utf_char_len = utf8_len(ori_str[i]); | ||||
|             if ((utf_char_len == 1) && ispunct(ori_str[i])) { | ||||
|                 new_str += " "; | ||||
|                 new_str += ori_str[i]; | ||||
|                 new_str += " "; | ||||
|                 i += 1; | ||||
|         for (uint32_t code : nfd_codepoints) { | ||||
|             int type = codepoint_type(code); | ||||
|             if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) { | ||||
|                 continue; | ||||
|             } | ||||
|             else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) { | ||||
|                 new_str += " "; | ||||
|                 new_str += ori_str.substr(i, 3); | ||||
|                 new_str += " "; | ||||
|                 i += 3; | ||||
|             code = to_lower(code); | ||||
|             if (type == CODEPOINT_TYPE_WHITESPACE) { | ||||
|                 code = ' '; | ||||
|             } | ||||
|             else { | ||||
|                 new_str += ori_str[i]; | ||||
|                 i += 1; | ||||
|             std::string s = codepoint_to_utf8(code); | ||||
|             if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) { | ||||
|                 new_str += " "; | ||||
|                 new_str += s; | ||||
|                 new_str += " "; | ||||
|             } else { | ||||
|                 new_str += s; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // split by whitespace | ||||
|         uint64_t l = 0; | ||||
|         uint64_t r = 0; | ||||
|         std::vector<std::string> words; | ||||
|         while (r < new_str.size()) { | ||||
|             // if is whitespace | ||||
|             if (isspace(new_str[r])) { | ||||
| @@ -8989,47 +9000,20 @@ struct llm_tokenizer_wpm { | ||||
|         return words; | ||||
|     } | ||||
|  | ||||
|     std::string normalize(const std::string & text) { | ||||
|         // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 | ||||
|         std::string text2 = strip_accents(text); | ||||
|         for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { | ||||
|             char c = text2[i]; | ||||
|             if (c >= 'A' && c <= 'Z') { | ||||
|                 text2[i] = c - 'A' + 'a'; | ||||
|             } | ||||
|     uint32_t to_lower(uint32_t code) { | ||||
| #if defined(_WIN32) | ||||
|         if (code > 0xFFFF) { | ||||
|             return code; | ||||
|         } | ||||
|         return text2; | ||||
| #endif | ||||
|         return std::tolower(wchar_t(code), std::locale("en_US.UTF-8")); | ||||
|     } | ||||
|  | ||||
|     bool is_chinese_char(const std::string & str) { | ||||
|         int len = str.length(); | ||||
|         unsigned int codepoint = 0; | ||||
|         int num_bytes = 0; | ||||
|         int i = 0; | ||||
|         unsigned char ch = static_cast<unsigned char>(str[i]); | ||||
|         if (ch <= 0x7f) { | ||||
|             codepoint = ch; | ||||
|             num_bytes = 1; | ||||
|         } else if ((ch >> 5) == 0x06) { | ||||
|             codepoint = ch & 0x1f; | ||||
|             num_bytes = 2; | ||||
|         } else if ((ch >> 4) == 0x0e) { | ||||
|             codepoint = ch & 0x0f; | ||||
|             num_bytes = 3; | ||||
|         } else if ((ch >> 3) == 0x1e) { | ||||
|             codepoint = ch & 0x07; | ||||
|             num_bytes = 4; | ||||
|         } | ||||
|         for (int j = 1; j < num_bytes; ++j) { | ||||
|             if (i + j >= len) { | ||||
|                 return false; // incomplete UTF-8 character | ||||
|             } | ||||
|             unsigned char next_ch = static_cast<unsigned char>(str[i + j]); | ||||
|             if ((next_ch >> 6) != 0x02) { | ||||
|                 return false; // invalid trailing byte | ||||
|             } | ||||
|             codepoint = (codepoint << 6) | (next_ch & 0x3f); | ||||
|         } | ||||
|     bool is_ascii_punct(uint32_t code) { | ||||
|         return code < 256 && ispunct(code); | ||||
|     } | ||||
|  | ||||
|     bool is_chinese_char(uint32_t codepoint) { | ||||
|         if ((codepoint >= 0x4E00  && codepoint <= 0x9FFF)  || | ||||
|             (codepoint >= 0x3400  && codepoint <= 0x4DBF)  || | ||||
|             (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || | ||||
| @@ -9045,41 +9029,6 @@ struct llm_tokenizer_wpm { | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
|     std::string strip_accents(const std::string & input_string) { | ||||
|         std::string resultString; | ||||
|         std::map<std::string, char> accent_map = { | ||||
|             {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, | ||||
|             {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, | ||||
|             {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, | ||||
|             {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'}, | ||||
|             {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'}, | ||||
|             {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, | ||||
|             {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, | ||||
|             {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, | ||||
|             {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, | ||||
|         }; | ||||
|  | ||||
|         for (size_t i = 0; i <  input_string.length();) { | ||||
|             int len = utf8_len(input_string[i]); | ||||
|             std::string curChar = input_string.substr(i, len); | ||||
|             auto iter = accent_map.find(curChar); | ||||
|             if (iter != accent_map.end()) { | ||||
|                 resultString += iter->second; | ||||
|             } else { | ||||
|                 resultString += curChar; | ||||
|             } | ||||
|             i += len; | ||||
|         } | ||||
|  | ||||
|         return resultString; | ||||
|     } | ||||
|  | ||||
|     static size_t utf8_len(char src) { | ||||
|         const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; | ||||
|         uint8_t highbits = static_cast<uint8_t>(src) >> 4; | ||||
|         return lookup[highbits]; | ||||
|     } | ||||
|  | ||||
|     const llama_vocab & vocab; | ||||
| }; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Douglas Hanley
					Douglas Hanley