Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	llama : tokenizer fixes (#2549)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies.
convert.py (93 lines changed)
							| @@ -238,21 +238,57 @@ class Params: | ||||
|         return params | ||||
|  | ||||
|  | ||||
| class SentencePieceVocab: | ||||
|     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None: | ||||
|         self.vocabtype = vocabtype | ||||
|         if self.vocabtype == "bpe": | ||||
|           self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read()) | ||||
| class BpeVocab: | ||||
|     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: | ||||
|         self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) | ||||
|         added_tokens: Dict[str, int] | ||||
|         if fname_added_tokens is not None: | ||||
|             added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) | ||||
|         else: | ||||
|             added_tokens = {} | ||||
|         vocab_size: int = len(self.bpe_tokenizer) | ||||
|         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) | ||||
|         actual_ids = sorted(added_tokens.values()) | ||||
|         if expected_ids != actual_ids: | ||||
|             raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}") | ||||
|         items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) | ||||
|         self.added_tokens_list = [text for (text, idx) in items] | ||||
|         self.vocab_size_base: int = vocab_size | ||||
|         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) | ||||
|         self.fname_tokenizer = fname_tokenizer | ||||
|         self.fname_added_tokens = fname_added_tokens | ||||
|  | ||||
|     def bpe_tokens(self) -> Iterable[Tuple[bytes, float]]: | ||||
|         tokenizer = self.bpe_tokenizer | ||||
|         from transformers.models.gpt2 import tokenization_gpt2 | ||||
|         byte_encoder = tokenization_gpt2.bytes_to_unicode() | ||||
|         byte_decoder = {v: k for k, v in byte_encoder.items()} | ||||
|         for i, item in enumerate(tokenizer): | ||||
|             text: bytes = item.encode("utf-8") | ||||
|             score: float = -i | ||||
|             yield text, score | ||||
|  | ||||
|     def added_tokens(self) -> Iterable[Tuple[bytes, float]]: | ||||
|         for text in self.added_tokens_list: | ||||
|             score = -1000.0 | ||||
|             yield text.encode("utf-8"), score | ||||
|  | ||||
|     def all_tokens(self) -> Iterable[Tuple[bytes, float]]: | ||||
|         yield from self.bpe_tokens() | ||||
|         yield from self.added_tokens() | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" | ||||
|  | ||||
|  | ||||
| class SentencePieceVocab: | ||||
|     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: | ||||
|         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) | ||||
|         added_tokens: Dict[str, int] | ||||
|         if fname_added_tokens is not None: | ||||
|             added_tokens = json.load(open(fname_added_tokens)) | ||||
|             added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) | ||||
|         else: | ||||
|             added_tokens = {} | ||||
|         if self.vocabtype == "bpe": | ||||
|           vocab_size: int = len(self.sentencepiece_tokenizer) | ||||
|         else: | ||||
|         vocab_size: int = self.sentencepiece_tokenizer.vocab_size() | ||||
|         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) | ||||
|         actual_ids = sorted(added_tokens.values()) | ||||
| @@ -267,30 +303,9 @@ class SentencePieceVocab: | ||||
|  | ||||
|     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]: | ||||
|         tokenizer = self.sentencepiece_tokenizer | ||||
|         if self.vocabtype == "bpe": | ||||
|           from transformers.models.gpt2 import tokenization_gpt2 | ||||
|           byte_encoder = tokenization_gpt2.bytes_to_unicode() | ||||
|           byte_decoder = {v: k for k, v in byte_encoder.items()} | ||||
|           for i, item in enumerate(tokenizer): | ||||
|             text: bytes | ||||
|             text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]]) | ||||
|             score: float = -i | ||||
|             yield text, score | ||||
|         else: | ||||
|         for i in range(tokenizer.vocab_size()): | ||||
|               text: bytes | ||||
|               if tokenizer.is_unknown(i): | ||||
|                   text = " \u2047 ".encode("utf-8") | ||||
|               elif tokenizer.is_control(i): | ||||
|                   text = b"" | ||||
|               elif tokenizer.is_byte(i): | ||||
|             piece = tokenizer.id_to_piece(i) | ||||
|                   if len(piece) != 6: | ||||
|                       raise Exception(f"Invalid token: {piece}") | ||||
|                   byte_value = int(piece[3:-1], 16) | ||||
|                   text = struct.pack("B", byte_value) | ||||
|               else: | ||||
|                   text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") | ||||
|             text: bytes = piece.encode("utf-8") | ||||
|             score: float = tokenizer.get_score(i) | ||||
|             yield text, score | ||||
|  | ||||
| @@ -319,7 +334,7 @@ class GGMLVocab: | ||||
|         return f"<GGMLVocab with {self.vocab_size} tokens>" | ||||
|  | ||||
|  | ||||
| Vocab = Union[SentencePieceVocab, GGMLVocab] | ||||
| Vocab = Union[BpeVocab, SentencePieceVocab, GGMLVocab] | ||||
|  | ||||
|  | ||||
| def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: | ||||
| @@ -1044,7 +1059,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc | ||||
| def check_vocab_size(params: Params, vocab: Vocab) -> None: | ||||
|     if params.n_vocab != vocab.vocab_size: | ||||
|         # GGMLVocab comes from the same file as the model so shouldn't mismatch: | ||||
|         assert isinstance(vocab, SentencePieceVocab) | ||||
|         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) | ||||
|         if params.n_vocab == vocab.vocab_size_base: | ||||
|             print("Ignoring added_tokens.json since model matches vocab size without it.") | ||||
|             vocab.added_tokens_list = [] | ||||
| @@ -1093,7 +1108,7 @@ class OutputFile: | ||||
|     @staticmethod | ||||
|     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: | ||||
|         of = OutputFile(fname_out) | ||||
|         params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0) | ||||
|         params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, n_kv_head=None) | ||||
|         of = OutputFile(fname_out) | ||||
|         of.write_file_header(params, file_type=GGMLFileType.AllF32) | ||||
|         of.write_vocab(vocab) | ||||
| @@ -1228,7 +1243,7 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel: | ||||
|     return {name: model[name] for name in TENSORS_LIST if name in model} | ||||
|  | ||||
|  | ||||
| def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: | ||||
| def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]: | ||||
|     print(f"vocabtype: {vocabtype}") | ||||
|     # Be extra-friendly and accept either a file or a directory.  Also, if it's | ||||
|     # a directory, it might be the model directory, and tokenizer.model might | ||||
| @@ -1250,8 +1265,12 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: | ||||
|                 "if it's in another directory, pass the directory as --vocab-dir") | ||||
|     added_tokens_path = path.parent / "added_tokens.json" | ||||
|     print(f"Loading vocab file {path}") | ||||
|     return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, | ||||
|                               vocabtype) | ||||
|     if vocabtype == "bpe": | ||||
|         return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None) | ||||
|     elif vocabtype == "spm": | ||||
|         return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) | ||||
|     else: | ||||
|         raise ValueError(f"Unsupported vocabulary type {vocabtype}") | ||||
|  | ||||
|  | ||||
| def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path: | ||||
|   | ||||
| @@ -633,17 +633,6 @@ std::string gpt_random_prompt(std::mt19937 & rng) { | ||||
|     return "The"; | ||||
| } | ||||
|  | ||||
| // TODO: not great allocating this every time | ||||
| std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { | ||||
|     // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars | ||||
|     std::vector<llama_token> res(text.size() + (int) add_bos); | ||||
|     const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); | ||||
|     assert(n >= 0); | ||||
|     res.resize(n); | ||||
|  | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { | ||||
|     auto lparams = llama_context_default_params(); | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #define LLAMA_API_CPP // TODO: eliminate me | ||||
| #include "llama.h" | ||||
|  | ||||
| #include <string> | ||||
| @@ -100,12 +101,6 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params); | ||||
|  | ||||
| std::string gpt_random_prompt(std::mt19937 & rng); | ||||
|  | ||||
| // | ||||
| // Vocab utils | ||||
| // | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos); | ||||
|  | ||||
| // | ||||
| // Model utils | ||||
| // | ||||
|   | ||||
| @@ -67,7 +67,7 @@ int main(int argc, char ** argv) { | ||||
|         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); | ||||
|         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); | ||||
|         for (int i = 0; i < (int) embd_inp.size(); i++) { | ||||
|             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); | ||||
|             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str()); | ||||
|         } | ||||
|         fprintf(stderr, "\n"); | ||||
|     } | ||||
|   | ||||
| @@ -191,10 +191,6 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // tokenize the prompt | ||||
|     std::vector<llama_token> embd_inp; | ||||
|  | ||||
|     // Add a space in front of the first character to match OG llama tokenizer behavior | ||||
|     params.prompt.insert(0, 1, ' '); | ||||
|  | ||||
|     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { | ||||
|         embd_inp = ::llama_tokenize(ctx, params.prompt, true); | ||||
|     } else { | ||||
| @@ -278,7 +274,7 @@ int main(int argc, char ** argv) { | ||||
|         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); | ||||
|         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); | ||||
|         for (int i = 0; i < (int) embd_inp.size(); i++) { | ||||
|             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); | ||||
|             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str()); | ||||
|         } | ||||
|  | ||||
|         if (ctx_guidance) { | ||||
| @@ -286,14 +282,14 @@ int main(int argc, char ** argv) { | ||||
|             fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); | ||||
|             fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); | ||||
|             for (int i = 0; i < (int) guidance_inp.size(); i++) { | ||||
|                 fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i])); | ||||
|                 fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str()); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (params.n_keep > 0) { | ||||
|         fprintf(stderr, "%s: static prompt based on n_keep: '", __func__); | ||||
|             for (int i = 0; i < params.n_keep; i++) { | ||||
|                 fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i])); | ||||
|                 fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str()); | ||||
|             } | ||||
|             fprintf(stderr, "'\n"); | ||||
|         } | ||||
| @@ -662,7 +658,7 @@ int main(int argc, char ** argv) { | ||||
|         // display text | ||||
|         if (input_echo) { | ||||
|             for (auto id : embd) { | ||||
|                 printf("%s", llama_token_to_str(ctx, id)); | ||||
|                 printf("%s", llama_token_to_str(ctx, id).c_str()); | ||||
|             } | ||||
|             fflush(stdout); | ||||
|         } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| #include "ggml.h" | ||||
| #include "build-info.h" | ||||
|  | ||||
| #define LLAMA_API_CPP // TODO: eliminate me | ||||
| #define LLAMA_API_INTERNAL | ||||
| #include "llama.h" | ||||
|  | ||||
|   | ||||
| @@ -45,9 +45,8 @@ int main(int argc, char ** argv) { | ||||
|         llama_free_model(model); | ||||
|         return 1; | ||||
|     } | ||||
|     auto tokens = std::vector<llama_token>(params.n_ctx); | ||||
|     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); | ||||
|  | ||||
|     auto tokens = llama_tokenize(ctx, params.prompt.c_str(), true); | ||||
|     auto n_prompt_tokens = tokens.size(); | ||||
|     if (n_prompt_tokens < 1) { | ||||
|         fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); | ||||
|         llama_free(ctx); | ||||
| @@ -92,7 +91,7 @@ int main(int argc, char ** argv) { | ||||
|         auto next_token_str = llama_token_to_str(ctx, next_token); | ||||
|         last_n_tokens_data.push_back(next_token); | ||||
|  | ||||
|         printf("%s", next_token_str); | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_free(ctx); | ||||
| @@ -152,7 +151,7 @@ int main(int argc, char ** argv) { | ||||
|         auto next_token_str = llama_token_to_str(ctx2, next_token); | ||||
|         last_n_tokens_data.push_back(next_token); | ||||
|  | ||||
|         printf("%s", next_token_str); | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_free(ctx2); | ||||
|   | ||||
| @@ -62,7 +62,7 @@ int main(int argc, char ** argv) { | ||||
|     fprintf(stderr, "\n\n"); | ||||
|  | ||||
|     for (auto id : tokens_list) { | ||||
|         fprintf(stderr, "%s", llama_token_to_str(ctx, id)); | ||||
|         fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str()); | ||||
|     } | ||||
|  | ||||
|     fflush(stderr); | ||||
| @@ -109,7 +109,7 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|  | ||||
|         // print the new token : | ||||
|         printf("%s", llama_token_to_str(ctx, new_token_id)); | ||||
|         printf("%s", llama_token_to_str(ctx, new_token_id).c_str()); | ||||
|         fflush(stdout); | ||||
|  | ||||
|         // push this new token for next evaluation | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| #include "ggml.h" | ||||
| #include "common.h" | ||||
| #include "llama.h" | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
| @@ -1961,7 +1962,7 @@ void print_matrix(struct ggml_tensor * probs) { | ||||
|  | ||||
|  | ||||
| void print_token(struct llama_context * ctx, llama_token token) { | ||||
|     printf("%s", llama_token_to_str(ctx, token)); | ||||
|     printf("%s", llama_token_to_str(ctx, token).c_str()); | ||||
| } | ||||
|  | ||||
| void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) { | ||||
| @@ -2188,11 +2189,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto | ||||
|     f.read_raw(buf.data(), f.size); | ||||
|     buf[f.size] = '\0'; | ||||
|  | ||||
|     out.resize(buf.size()); | ||||
|  | ||||
|     int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false); | ||||
|     if (n_tokens >= 0) { | ||||
|         out.resize(n_tokens); | ||||
|     int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); | ||||
|     if (n_tokens < 0) { | ||||
|         out.resize(-n_tokens); | ||||
|         llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); | ||||
|     } | ||||
|  | ||||
|     bool verify = false; | ||||
| @@ -2200,17 +2200,17 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto | ||||
|         const char * in  = buf.data(); | ||||
|         const char * end = buf.data() + buf.size(); | ||||
|         for (int i = 0; i < (int) out.size(); ++i) { | ||||
|             const char * s = llama_token_to_str(lctx, out[i]); | ||||
|             int len = strlen(s); | ||||
|             std::string s = llama_token_to_str(lctx, out[i]); | ||||
|             int len = s.length(); | ||||
|             if (in >= end) { | ||||
|                 printf("%s: unexpected end of original text.\n", __func__); | ||||
|                 break; | ||||
|             } | ||||
|             const bool matches = (strncmp(in, s, len) == 0); | ||||
|             const bool matches = (strncmp(in, s.c_str(), len) == 0); | ||||
|             if (matches) { | ||||
|                 in += len; | ||||
|             } else { | ||||
|                 printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s); | ||||
|                 printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str()); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
							
								
								
									
llama.cpp (326 lines changed)
							| @@ -7,6 +7,7 @@ | ||||
| #endif | ||||
|  | ||||
| #include "llama-util.h" | ||||
| #define LLAMA_API_CPP // TODO: eliminate me | ||||
| #include "llama.h" | ||||
|  | ||||
| #include "ggml.h" | ||||
| @@ -575,6 +576,7 @@ struct llama_file_loader { | ||||
|             float score = 0.0f; | ||||
|             file.read_raw(&score, sizeof(score)); | ||||
|  | ||||
|             GGML_ASSERT(vocab.token_to_id.find(word) == vocab.token_to_id.end()); | ||||
|             vocab.token_to_id[word] = i; | ||||
|  | ||||
|             auto & tok_score = vocab.id_to_token[i]; | ||||
| @@ -1060,6 +1062,11 @@ static void llama_model_load_internal( | ||||
|     std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap)); | ||||
|  | ||||
|     vocab = std::move(ml->file_loader->vocab); | ||||
|  | ||||
|     if (vocab_only) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     model.hparams = ml->file_loader->hparams; | ||||
|     model.n_gpu_layers = n_gpu_layers; | ||||
|     llama_file_version file_version = ml->file_loader->file_version; | ||||
| @@ -1141,10 +1148,6 @@ static void llama_model_load_internal( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (vocab_only) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     auto & ctx = model.ctx; | ||||
|  | ||||
|     size_t ctx_size; | ||||
| @@ -1940,6 +1943,105 @@ static bool llama_eval_internal( | ||||
| // tokenizer | ||||
| // | ||||
|  | ||||
| static std::string llama_vocab_type(const llama_vocab& vocab) { | ||||
|     return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; | ||||
| } | ||||
|  | ||||
| static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return token >= 259; | ||||
|     else if(llama_vocab_type(vocab) == "bpe") | ||||
|         return token >= 95; | ||||
|     else | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return token == 0; | ||||
|     else | ||||
|         // TODO: improve? | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return token == 1 || token == 2; | ||||
|     else | ||||
|         // TODO: improve? | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return token == 1; | ||||
|     else | ||||
|         // TODO: improve? | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return token == 2; | ||||
|     else | ||||
|         // TODO: improve? | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) { | ||||
|     // TODO: improve? | ||||
|     return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { | ||||
|     // TODO: improve? | ||||
|     return false; | ||||
| } | ||||
|  | ||||
| static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return 3 <= token && token < 259; | ||||
|     else if(llama_vocab_type(vocab) == "bpe") | ||||
|         return 1 <= token && token < 95; | ||||
|     else | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { | ||||
|     if(llama_vocab_type(vocab) == "spm") | ||||
|         return byte + 3; | ||||
|     else if(llama_vocab_type(vocab) == "bpe") | ||||
|         return byte + 32; | ||||
|     else | ||||
|         return false; | ||||
| } | ||||
|  | ||||
| static std::string llama_escape_whitespace(const std::string& text) { | ||||
|     std::string result; | ||||
|     bool escaping = false; | ||||
|     result += "\xe2\x96\x81"; | ||||
|     for (size_t offs = 0; offs < text.length(); ++offs) { | ||||
|         if (text[offs] == ' ') { | ||||
|             if (!escaping) { | ||||
|                 result += "\xe2\x96\x81"; | ||||
|                 escaping = true; | ||||
|             } | ||||
|         } | ||||
|         else { | ||||
|             escaping = false; | ||||
|             result += text[offs]; | ||||
|         } | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| static std::string llama_unescape_whitespace(const std::string& word) { | ||||
|     if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") { | ||||
|         return std::string(" ") + word.substr(3); | ||||
|     } | ||||
|     return word; | ||||
| } | ||||
|  | ||||
| static size_t utf8_len(char src) { | ||||
|     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; | ||||
|     uint8_t highbits = static_cast<uint8_t>(src) >> 4; | ||||
| @@ -1981,10 +2083,11 @@ struct llama_tokenizer { | ||||
|         size_t offs = 0; | ||||
|         while (offs < text.size()) { | ||||
|             llama_sp_symbol sym; | ||||
|             size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); | ||||
|             size_t len = utf8_len(text[offs]); | ||||
|             GGML_ASSERT(offs + len <= text.size()); | ||||
|             sym.text = text.c_str() + offs; | ||||
|             sym.n = char_len; | ||||
|             offs += char_len; | ||||
|             sym.n = len; | ||||
|             offs += len; | ||||
|             sym.prev = index - 1; | ||||
|             sym.next = offs == text.size() ? -1 : index + 1; | ||||
|             index++; | ||||
| @@ -2029,23 +2132,36 @@ struct llama_tokenizer { | ||||
|  | ||||
|         for (int i = 0; i != -1; i = symbols_[i].next) { | ||||
|             auto & symbol = symbols_[i]; | ||||
|             auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); | ||||
|  | ||||
|             if (token == vocab_.token_to_id.end()) { | ||||
|                 // output any symbols that did not form tokens as bytes. | ||||
|                 for (int j = 0; j < (int) symbol.n; ++j) { | ||||
|                     // NOTE: old version, before #2420 - not sure what are the implications of this | ||||
|                     //llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3; | ||||
|                     llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j])); | ||||
|                     output.push_back(token_id); | ||||
|                 } | ||||
|             } else { | ||||
|                 output.push_back((*token).second); | ||||
|             } | ||||
|             resegment(symbol, output); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) { | ||||
|         auto text = std::string(symbol.text, symbol.n); | ||||
|         auto token = vocab_.token_to_id.find(text); | ||||
|  | ||||
|         // Do we need to support is_unused? | ||||
|         if (token != vocab_.token_to_id.end()) { | ||||
|             output.push_back((*token).second); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         const auto p = rev_merge.find(text); | ||||
|  | ||||
|         if (p == rev_merge.end()) { | ||||
|             // output any symbols that did not form tokens as bytes. | ||||
|             for (int j = 0; j < (int)symbol.n; ++j) { | ||||
|                 llama_vocab::id token_id = llama_byte_to_char(vocab_, symbol.text[j]); | ||||
|                 output.push_back(token_id); | ||||
|             } | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         resegment(symbols_[p->second.first], output); | ||||
|         resegment(symbols_[p->second.second], output); | ||||
|     } | ||||
|  | ||||
|     void try_add_bigram(int left, int right) { | ||||
|         if (left == -1 || right == -1) { | ||||
|             return; | ||||
| @@ -2070,18 +2186,22 @@ private: | ||||
|         bigram.score = tok_score.score; | ||||
|         bigram.size = text.size(); | ||||
|         work_queue_.push(bigram); | ||||
|  | ||||
|         // Do we need to support is_unused? | ||||
|         rev_merge[text] = std::make_pair(left, right); | ||||
|     } | ||||
|  | ||||
|     const llama_vocab & vocab_; | ||||
|     std::vector<llama_sp_symbol> symbols_; | ||||
|     llama_sp_bigram::queue work_queue_; | ||||
|     std::map<std::string, std::pair<int, int> > rev_merge; | ||||
| }; | ||||
|  | ||||
| static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { | ||||
| static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { | ||||
|     llama_tokenizer tokenizer(vocab); | ||||
|     std::vector<llama_vocab::id> output; | ||||
|  | ||||
|     if (text.empty()) { | ||||
|     if (raw_text.empty()) { | ||||
|         return output; | ||||
|     } | ||||
|  | ||||
| @@ -2089,6 +2209,13 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co | ||||
|         output.push_back(llama_token_bos()); | ||||
|     } | ||||
|  | ||||
|     std::string text; | ||||
|     if (escape) { | ||||
|         text = llama_escape_whitespace(raw_text); | ||||
|     } else { | ||||
|         text = raw_text; | ||||
|     } | ||||
|  | ||||
|     tokenizer.tokenize(text, output); | ||||
|     return output; | ||||
| } | ||||
| @@ -2670,15 +2797,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c | ||||
|  | ||||
|     for (size_t i = 0; i < candidates->size; ++i) { | ||||
|         const llama_token id  = candidates->data[i].id; | ||||
|         const char *      str = llama_token_to_str(ctx, id); | ||||
|         std::string       str = llama_token_to_str(ctx, id); | ||||
|         if (id == eos) { | ||||
|             if (!allow_eos) { | ||||
|                 candidates->data[i].logit = -INFINITY; | ||||
|             } | ||||
|         } else if (*str == 0) { | ||||
|         } else if (str.empty()) { | ||||
|             candidates->data[i].logit = -INFINITY; | ||||
|         } else { | ||||
|             candidates_decoded.push_back(decode_utf8(str)); | ||||
|             candidates_decoded.push_back(decode_utf8(str.c_str())); | ||||
|             candidates_grammar.push_back({ i, candidates_decoded.back().data() }); | ||||
|         } | ||||
|     } | ||||
| @@ -2879,9 +3006,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar | ||||
|         LLAMA_ASSERT(false); | ||||
|     } | ||||
|  | ||||
|     const char * str = llama_token_to_str(ctx, token); | ||||
|     std::string str = llama_token_to_str(ctx, token); | ||||
|     // Note terminating 0 in decoded string | ||||
|     auto code_points = decode_utf8(str); | ||||
|     auto code_points = decode_utf8(str.c_str()); | ||||
|     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | ||||
|         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); | ||||
|     } | ||||
| @@ -4132,7 +4259,8 @@ int llama_tokenize_with_model( | ||||
|                  llama_token * tokens, | ||||
|                          int   n_max_tokens, | ||||
|                         bool   add_bos) { | ||||
|     auto res = llama_tokenize(model->vocab, text, add_bos); | ||||
|     auto escape = llama_vocab_type(model->vocab) == "spm"; | ||||
|     auto res = llama_tokenize(model->vocab, text, add_bos, escape); | ||||
|  | ||||
|     if (n_max_tokens < (int) res.size()) { | ||||
|         LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); | ||||
| @@ -4155,6 +4283,62 @@ int llama_tokenize( | ||||
|     return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); | ||||
| } | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize( | ||||
|         struct llama_context * ctx, | ||||
|            const std::string & text, | ||||
|                         bool   add_bos) { | ||||
|     int length = text.length() + add_bos; | ||||
|     std::vector<llama_token> result(length); | ||||
|     length = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); | ||||
|     if (length < 0) { | ||||
|         result.resize(-length); | ||||
|         int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); | ||||
|         assert(check == -length); | ||||
|         GGML_UNUSED(check); | ||||
|     } else { | ||||
|         result.resize(length); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| int llama_tokenize_bpe( | ||||
|         struct llama_context * ctx, | ||||
|                   const char * text, | ||||
|                  llama_token * tokens, | ||||
|                          int   n_max_tokens, | ||||
|                         bool   add_bos) { | ||||
|     auto res = llama_tokenize(ctx->model.vocab, text, add_bos, false); | ||||
|  | ||||
|     if (n_max_tokens < (int) res.size()) { | ||||
|         LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); | ||||
|         return -((int) res.size()); | ||||
|     } | ||||
|  | ||||
|     for (size_t i = 0; i < res.size(); i++) { | ||||
|         tokens[i] = res[i]; | ||||
|     } | ||||
|  | ||||
|     return res.size(); | ||||
| } | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize_bpe( | ||||
|         struct llama_context * ctx, | ||||
|            const std::string & text, | ||||
|                         bool   add_bos) { | ||||
|     int length = text.length() + add_bos; | ||||
|     std::vector<llama_token> result(length); | ||||
|     length = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); | ||||
|     if (length < 0) { | ||||
|         result.resize(-length); | ||||
|         int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); | ||||
|         assert(check == -length); | ||||
|         GGML_UNUSED(check); | ||||
|     } else { | ||||
|         result.resize(length); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| int llama_n_vocab_from_model(const struct llama_model * model) { | ||||
|     return model->vocab.id_to_token.size(); | ||||
| } | ||||
| @@ -4208,16 +4392,88 @@ float * llama_get_embeddings(struct llama_context * ctx) { | ||||
|     return ctx->embedding.data(); | ||||
| } | ||||
|  | ||||
| const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { | ||||
|     if (token >= llama_n_vocab_from_model(model)) { | ||||
|         return nullptr; | ||||
| int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) { | ||||
|     if (0 <= token && token < llama_n_vocab_from_model(model)) { | ||||
|         if (llama_is_normal_token(model->vocab, token)) { | ||||
|             std::string result = model->vocab.id_to_token[token].tok; | ||||
|             if(llama_vocab_type(model->vocab) == "spm") { | ||||
|                 result = llama_unescape_whitespace(result); | ||||
|             } | ||||
|  | ||||
|     return model->vocab.id_to_token[token].tok.c_str(); | ||||
|             if(result.length() > length) { | ||||
|                 return - result.length(); | ||||
|             } | ||||
|             strcpy(str, result.c_str()); | ||||
|             return result.length(); | ||||
|         } else if (llama_is_unknown_token(model->vocab, token)) { | ||||
|             if(3 > length) { | ||||
|                 return -3; | ||||
|             } | ||||
|             strcpy(str, "\xe2\x96\x85"); | ||||
|             return 3; | ||||
|         } else if (llama_is_control_token(model->vocab, token)) { | ||||
|             ; | ||||
|         } else if (llama_is_byte_token(model->vocab, token)) { | ||||
|             if(1 > length) { | ||||
|                 return -1; | ||||
|             } | ||||
|             str[0] = llama_byte_to_char(model->vocab, token); | ||||
|             str[1] = 0x00; | ||||
|             return 1; | ||||
|         } | ||||
|     } | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { | ||||
|     return llama_token_to_str_with_model(&ctx->model, token); | ||||
| int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) { | ||||
|     return llama_token_to_str_with_model(&ctx->model, token, str, length); | ||||
| } | ||||
|  | ||||
| std::string llama_token_to_str( | ||||
|         const struct llama_context * ctx, | ||||
|                        llama_token   token) { | ||||
|     std::string result; | ||||
|     int length = 8; | ||||
|     result.resize(length); | ||||
|     length = llama_token_to_str(ctx, token, (char *)result.data(), result.length()); | ||||
|     if (length < 0) { | ||||
|         result.resize(-length); | ||||
|         int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length()); | ||||
|         assert(check == -length); | ||||
|         GGML_UNUSED(check); | ||||
|     } else { | ||||
|         result.resize(length); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) { | ||||
|     if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) { | ||||
|         std::string result = ctx->model.vocab.id_to_token[token].tok; | ||||
|         if (result.length() > length) { | ||||
|             return - result.length(); | ||||
|         } | ||||
|         strcpy(str, result.c_str()); | ||||
|         return result.length(); | ||||
|     } | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| std::string llama_token_to_str_bpe( | ||||
|     const struct llama_context * ctx, | ||||
|                    llama_token   token) { | ||||
|     std::string result; | ||||
|     int length = 8; | ||||
|     result.resize(length); | ||||
|     length = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length()); | ||||
|     if (length < 0) { | ||||
|         result.resize(-length); | ||||
|         int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length()); | ||||
|         assert(check == -length); | ||||
|         GGML_UNUSED(check); | ||||
|     } else { | ||||
|         result.resize(length); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| llama_token llama_token_bos() { | ||||
|   | ||||
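The new C entry points above (llama_tokenize, llama_token_to_str, llama_token_to_str_bpe) no longer return pointers; they fill a caller-provided buffer and report a negative value when it is too small, with the magnitude being the required size. Below is a minimal caller-side sketch of that convention, not taken from the commit: ctx is assumed to be an initialized llama_context, the helper name first_piece and the initial buffer sizes are illustrative, and the +1 on the retry is a precaution for the terminating NUL written by the implementation.

    #include "llama.h"

    #include <string>
    #include <vector>

    // Tokenize `text`, then recover the text of its first token, using the
    // negative-return / resize-and-retry convention of the new C API.
    static std::string first_piece(llama_context * ctx, const std::string & text) {
        std::vector<llama_token> toks(text.size() + 1);                  // initial guess
        int n = llama_tokenize(ctx, text.c_str(), toks.data(), (int) toks.size(), /*add_bos=*/true);
        if (n < 0) {                                                     // -n is the required token count
            toks.resize(-n);
            n = llama_tokenize(ctx, text.c_str(), toks.data(), (int) toks.size(), true);
        }
        toks.resize(n);
        if (toks.empty()) {
            return "";
        }

        std::vector<char> buf(8);                                        // same initial guess as the wrapper above
        int len = llama_token_to_str(ctx, toks[0], buf.data(), (int) buf.size());
        if (len < 0) {                                                   // -len is the required string length
            buf.resize(-len + 1);                                        // +1 for the trailing NUL (precaution)
            len = llama_token_to_str(ctx, toks[0], buf.data(), (int) buf.size());
        }
        return std::string(buf.data(), len);
    }

The std::string / std::vector overloads added in this commit wrap exactly this pattern, which is why most call sites in the examples switch to them instead.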
							
								
								
									
llama.h (60 lines changed)
							| @@ -336,6 +336,13 @@ extern "C" { | ||||
|                              int   n_max_tokens, | ||||
|                             bool   add_bos); | ||||
|  | ||||
|     LLAMA_API int llama_tokenize_bpe( | ||||
|             struct llama_context * ctx, | ||||
|                       const char * text, | ||||
|                      llama_token * tokens, | ||||
|                              int   n_max_tokens, | ||||
|                             bool   add_bos); | ||||
|  | ||||
|     LLAMA_API int llama_tokenize_with_model( | ||||
|         const struct llama_model * model, | ||||
|                       const char * text, | ||||
| @@ -377,14 +384,23 @@ extern "C" { | ||||
|     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); | ||||
|  | ||||
|     // Token Id -> String. Uses the vocabulary in the provided context | ||||
|     LLAMA_API const char * llama_token_to_str( | ||||
|     LLAMA_API int llama_token_to_str( | ||||
|             const struct llama_context * ctx, | ||||
|                            llama_token   token); | ||||
|                            llama_token   token, | ||||
|                                   char * str, | ||||
|                                   int    length); | ||||
|  | ||||
|     LLAMA_API const char * llama_token_to_str_with_model( | ||||
|     LLAMA_API int llama_token_to_str_bpe( | ||||
|             const struct llama_context * ctx, | ||||
|                            llama_token   token, | ||||
|                                   char * str, | ||||
|                                   int    length); | ||||
|  | ||||
|     LLAMA_API int llama_token_to_str_with_model( | ||||
|               const struct llama_model * model, | ||||
|                            llama_token   token); | ||||
|  | ||||
|                            llama_token   token, | ||||
|                                   char * str, | ||||
|                                   int    length); | ||||
|     // Special tokens | ||||
|     LLAMA_API llama_token llama_token_bos();  // beginning-of-sentence | ||||
|     LLAMA_API llama_token llama_token_eos();  // end-of-sentence | ||||
| @@ -472,15 +488,43 @@ extern "C" { | ||||
| } | ||||
| #endif | ||||
|  | ||||
| // Internal API to be implemented by llama.cpp and used by tests/benchmarks only | ||||
| #ifdef LLAMA_API_INTERNAL | ||||
| // C++ API, will be moving to common.h soon (TM) | ||||
| #ifdef LLAMA_API_CPP | ||||
|  | ||||
| #include <vector> | ||||
| #include <string> | ||||
|  | ||||
| // | ||||
| // Vocab utils | ||||
| // | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize( | ||||
|         struct llama_context * ctx, | ||||
|            const std::string & text, | ||||
|                         bool   add_bos); | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize_bpe( | ||||
|         struct llama_context * ctx, | ||||
|            const std::string & text, | ||||
|                         bool   add_bos); | ||||
|  | ||||
| std::string llama_token_to_str( | ||||
|         const struct llama_context * ctx, | ||||
|                        llama_token   token); | ||||
|  | ||||
| std::string llama_token_to_str_bpe( | ||||
|     const struct llama_context * ctx, | ||||
|                    llama_token   token); | ||||
|  | ||||
| // Internal API to be implemented by llama.cpp and used by tests/benchmarks only | ||||
| #ifdef LLAMA_API_INTERNAL | ||||
|  | ||||
| struct ggml_tensor; | ||||
|  | ||||
| const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx); | ||||
|  | ||||
| #endif | ||||
| #endif // LLAMA_API_CPP | ||||
|  | ||||
| #endif // LLAMA_API_INTERNAL | ||||
|  | ||||
| #endif // LLAMA_H | ||||
|   | ||||
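Note that the std::vector / std::string overloads above live behind LLAMA_API_CPP, so a translation unit has to define that macro before including llama.h to see them (which is why the #define appears in common.h, llama.cpp itself, and the test sources in this commit). A short sketch, not from the commit, of a consumer opting in and using the C++ helpers; dump_prompt_tokens and the assumption of an already created context are illustrative.

    #define LLAMA_API_CPP   // opt in to the C++ helper declarations ("will be moving to common.h soon (TM)")
    #include "llama.h"

    #include <cstdio>
    #include <string>
    #include <vector>

    // Print every token id of a prompt together with its piece of text,
    // mirroring the fprintf loops updated in the examples above.
    static void dump_prompt_tokens(llama_context * ctx, const std::string & prompt) {
        std::vector<llama_token> toks = llama_tokenize(ctx, prompt, /*add_bos=*/true);
        for (llama_token id : toks) {
            // llama_token_to_str now returns std::string, hence the .c_str() calls added throughout
            printf("%6d -> '%s'\n", id, llama_token_to_str(ctx, id).c_str());
        }
    }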
							
								
								
									
										
models/ggml-vocab-aquila.bin (new binary file, not shown)
							
								
								
									
										
models/ggml-vocab-llama.bin (new binary file, not shown)
										
Binary file not shown.
							| @@ -1,4 +1,19 @@ | ||||
| function(llama_add_test source) | ||||
| function(llama_build_executable source) | ||||
|     get_filename_component(TEST_TARGET ${source} NAME_WE) | ||||
|     add_executable(${TEST_TARGET} ${source}) | ||||
|     install(TARGETS ${TEST_TARGET} RUNTIME) | ||||
|     target_link_libraries(${TEST_TARGET} PRIVATE llama) | ||||
| endfunction() | ||||
|  | ||||
| function(llama_test_executable name source) | ||||
|     get_filename_component(TEST_TARGET ${source} NAME_WE) | ||||
|     # add_executable(${TEST_TARGET} ${source}) | ||||
|     # install(TARGETS ${TEST_TARGET} RUNTIME) | ||||
|     # target_link_libraries(${TEST_TARGET} PRIVATE llama) | ||||
|     add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN}) | ||||
| endfunction() | ||||
|  | ||||
| function(llama_build_and_test_executable source) | ||||
|     get_filename_component(TEST_TARGET ${source} NAME_WE) | ||||
|     add_executable(${TEST_TARGET} ${source}) | ||||
|     install(TARGETS ${TEST_TARGET} RUNTIME) | ||||
| @@ -6,11 +21,15 @@ function(llama_add_test source) | ||||
|     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN}) | ||||
| endfunction() | ||||
|  | ||||
| # llama_add_test(test-double-float.cpp) # SLOW | ||||
| llama_add_test(test-quantize-fns.cpp) | ||||
| llama_add_test(test-quantize-perf.cpp) | ||||
| llama_add_test(test-sampling.cpp) | ||||
| llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) | ||||
| llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) | ||||
| llama_add_test(test-grad0.cpp) # SLOW | ||||
| # llama_add_test(test-opt.cpp) # SLOW | ||||
| # llama_build_and_test_executable(test-double-float.cpp) # SLOW | ||||
| llama_build_and_test_executable(test-quantize-fns.cpp) | ||||
| llama_build_and_test_executable(test-quantize-perf.cpp) | ||||
| llama_build_and_test_executable(test-sampling.cpp) | ||||
| llama_build_executable(test-tokenizer-0.cpp) | ||||
| llama_test_executable(test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin) | ||||
| llama_build_executable(test-tokenizer-1.cpp) | ||||
| llama_test_executable(test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin) | ||||
| llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.bin) | ||||
| llama_build_and_test_executable(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) | ||||
| llama_build_and_test_executable(test-grad0.cpp) # SLOW | ||||
| # llama_build_and_test_executable(test-opt.cpp) # SLOW | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| #define LLAMA_API_CPP // TODO: eliminate me | ||||
| #include "llama.h" | ||||
|  | ||||
| #include <cstdio> | ||||
| @@ -5,15 +6,39 @@ | ||||
| #include <map> | ||||
| #include <vector> | ||||
|  | ||||
| static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) { | ||||
|     std::string result; | ||||
|     for (int i = 0; i < tokens.size(); ++i) { | ||||
|         result += llama_token_to_str(ctx, tokens[i]); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| static const std::map<std::string, std::vector<llama_token>> & k_tests() | ||||
| { | ||||
|     static std::map<std::string, std::vector<llama_token>> _k_tests = { | ||||
|         { "Hello World",        { 1,  10994,   2787, }, }, | ||||
|         { " Hello World",       { 1,  15043,   2787, }, }, | ||||
|         { " Hello World!",      { 1,  15043,   2787,  29991, }, }, | ||||
|         { " this is 🦙.cpp",    { 1,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, }, | ||||
|         { "w048 7tuijk dsdfhu", { 1,  29893,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, }, | ||||
|         { "нещо на Български",  { 1,    821,   4851,    665,   1386,  29713,   1305, }, }, | ||||
|         { " ",                      {1,    259, }, }, | ||||
|         { "\t",                     { 1,    29871,   12, }, }, | ||||
|         { "\n",                     { 1,    29871,   13, }, }, | ||||
|         { "\t\n",                   { 1,    29871,   12,     13, }, }, | ||||
|         { "Hello world",            { 1,  15043,   3186, }, }, | ||||
|         { " Hello world",           { 1,  29871,  15043,   3186, }, }, | ||||
|         { "Hello World",            { 1,  15043,   2787, }, }, | ||||
|         { " Hello World",           { 1,  29871,  15043,   2787, }, }, | ||||
|         { " Hello World!",          { 1,  29871,  15043,   2787,  29991, }, }, | ||||
|         { " this is 🦙.cpp",        { 1,  29871,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, }, | ||||
|         { "w048 7tuijk dsdfhu",     { 1,    281,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, }, | ||||
|         { "нещо на Български",      { 1,   1538,   4851,    665,   1386,  29713,   1305, }, }, | ||||
|         { "កាន់តែពិសេសអាចខលចេញ",   { 1,  29871,  31849,  31324,  31934,    228,    162,    142,    228,    161,     | ||||
|                                         146,    228,    162,    133,    228,    161,    153,    228,    161,    186,   | ||||
|                                         31708,    228,    162,    132,  31708,    228,    161,    165,  31324,    228,     | ||||
|                                         161,    136,    228,    161,    132,    228,    161,    158,    228,    161,     | ||||
|                                         136,    228,    162,    132,    228,    161,    140, }, }, | ||||
|         { "🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", | ||||
|             { 1,  29871,    243,    162,    157,    131,    313,   8945,  29897,  29871,     | ||||
|               243,    162,    155,    185,  30722,    243,    162,    143,    174,  30598,     | ||||
|               313,  20787,    953,   3848,    275,  16125,    630,  29897,  29871,  31681,     | ||||
|               313,   6194,    953,  29877,   2397,    393,    756,    967,   1914,   5993,  29897, }, }, | ||||
|      }; | ||||
|     return _k_tests; | ||||
| }; | ||||
| @@ -65,9 +90,9 @@ int main(int argc, char **argv) { | ||||
|     } | ||||
|  | ||||
|     for (const auto & test_kv : k_tests()) { | ||||
|         std::vector<llama_token> res(test_kv.first.size()); | ||||
|         const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true); | ||||
|         res.resize(n); | ||||
|         std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first.c_str(), true); | ||||
|         fprintf(stderr, "%s : '%s' tokenized to '%s'\n",  | ||||
|             __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str()); | ||||
|  | ||||
|         bool correct = res.size() == test_kv.second.size(); | ||||
|  | ||||
|   | ||||
							
								
								
									
tests/test-tokenizer-1.cpp (new file, 122 lines)
							| @@ -0,0 +1,122 @@ | ||||
| #define LLAMA_API_CPP // TODO: eliminate me | ||||
| #include "llama.h" | ||||
|  | ||||
| #include <cassert> | ||||
| #include <cstdio> | ||||
| #include <cstring> | ||||
| #include <string> | ||||
| #include <codecvt> | ||||
| #include <map> | ||||
| #include <vector> | ||||
|  | ||||
| static std::string vocab_type(llama_context* ctx) { | ||||
|     return llama_n_vocab(ctx) == 32000 ? "spm": "bpe"; | ||||
| } | ||||
|  | ||||
| static std::string escape_whitespace(const std::string& text) { | ||||
|     std::string result; | ||||
|     bool escaping = false; | ||||
|     result += "\xe2\x96\x81"; | ||||
|     for (size_t offs = 0; offs < text.length(); ++offs) { | ||||
|         if (text[offs] == ' ') { | ||||
|             if (!escaping) { | ||||
|                 result += "\xe2\x96\x81"; | ||||
|                 escaping = true; | ||||
|             } | ||||
|         } | ||||
|         else { | ||||
|             escaping = false; | ||||
|             result += text[offs]; | ||||
|         } | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) { | ||||
|     std::string result; | ||||
|     for (int i = 0; i < tokens.size(); ++i) { | ||||
|         result += llama_token_to_str(ctx, tokens[i]); | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| int main(int argc, char **argv) { | ||||
|     if (argc < 2) { | ||||
|         fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]); | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     const std::string fname = argv[1]; | ||||
|  | ||||
|     fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); | ||||
|  | ||||
|     llama_model * model; | ||||
|     llama_context * ctx; | ||||
|  | ||||
|     llama_backend_init(false); | ||||
|  | ||||
|     // load the vocab | ||||
|     { | ||||
|         auto lparams = llama_context_default_params(); | ||||
|  | ||||
|         lparams.vocab_only = true; | ||||
|  | ||||
|         model = llama_load_model_from_file(fname.c_str(), lparams); | ||||
|  | ||||
|         if (model == NULL) { | ||||
|             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); | ||||
|             return 1; | ||||
|         } | ||||
|  | ||||
|         ctx = llama_new_context_with_model(model, lparams); | ||||
|  | ||||
|         if (ctx == NULL) { | ||||
|             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); | ||||
|             llama_free_model(model); | ||||
|             return 1; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const int n_vocab = llama_n_vocab(ctx); | ||||
|  | ||||
|     for (int i = 0; i < n_vocab; ++i) { | ||||
|         std::string forward = llama_token_to_str_bpe(ctx, i); | ||||
|         std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false); | ||||
|         if (tokens.size() == 1) { | ||||
|             if (i != tokens[0]) { | ||||
|                 std::string backward = llama_token_to_str(ctx, tokens[0]); | ||||
|                 fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",  | ||||
|                     __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); | ||||
|                 return 2; | ||||
|             } | ||||
|         } else { | ||||
|             if ((vocab_type(ctx) == "spm" && i <= 258) ||  | ||||
|                 (vocab_type(ctx) == "bpe" && (i == 0 || i >= 100000))) { | ||||
|                 fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",  | ||||
|                     __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); | ||||
|             } else { | ||||
|                 fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",  | ||||
|                     __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); | ||||
|                 return 2; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     std::wstring_convert<typename std::codecvt_utf8<wchar_t>, wchar_t> converter; | ||||
|     for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) { | ||||
|         std::wstring wstr(1, ch); | ||||
|         std::string str = converter.to_bytes(wstr); | ||||
|         std::vector<llama_token> tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); | ||||
|         if (tokens.size() == 1) { | ||||
|             fprintf(stderr, "%s : info: %s tokenized to %d \n",  | ||||
|                 __func__, str.c_str(), tokens[0]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     llama_free_model(model); | ||||
|     llama_free(ctx); | ||||
|  | ||||
|     llama_backend_free(); | ||||
|  | ||||
|     return 0; | ||||
| } | ||||
Author: goerch