	speculative : initial example
		| @@ -317,6 +317,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | ||||
|                 break; | ||||
|             } | ||||
|             params.model = argv[i]; | ||||
|         } else if (arg == "-md" || arg == "--model-draft") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.model_draft = argv[i]; | ||||
|         } else if (arg == "-a" || arg == "--alias") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
| @@ -669,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | ||||
|     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n"); | ||||
|     fprintf(stdout, "  -m FNAME, --model FNAME\n"); | ||||
|     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str()); | ||||
|     fprintf(stdout, "  -md FNAME, --model-draft FNAME\n"); | ||||
|     fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str()); | ||||
|     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n"); | ||||
|     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n"); | ||||
|     fprintf(stdout, "\n"); | ||||
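With both flags wired up, the new example below can be pointed at two models at once, e.g. `./speculative -m <target-model.gguf> -md <draft-model.gguf> -p "<prompt>"` (the paths and prompt are illustrative; `speculative` is the target name defined in the new CMakeLists.txt further down).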
| @@ -832,6 +840,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| // | ||||
| // Sampling utils | ||||
| // | ||||
|  | ||||
| llama_token llama_sample_token( | ||||
|                   struct llama_context * ctx, | ||||
|                   struct llama_context * ctx_guidance, | ||||
|                   struct llama_grammar * grammar, | ||||
|                const struct gpt_params & params, | ||||
|         const std::vector<llama_token> & last_tokens, | ||||
|          std::vector<llama_token_data> & candidates, | ||||
|                                    int   idx) { | ||||
|     const int n_ctx   = llama_n_ctx(ctx); | ||||
|     const int n_vocab = llama_n_vocab(ctx); | ||||
|  | ||||
|     const float   temp            = params.temp; | ||||
|     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k; | ||||
|     const float   top_p           = params.top_p; | ||||
|     const float   tfs_z           = params.tfs_z; | ||||
|     const float   typical_p       = params.typical_p; | ||||
|     const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; | ||||
|     const float   repeat_penalty  = params.repeat_penalty; | ||||
|     const float   alpha_presence  = params.presence_penalty; | ||||
|     const float   alpha_frequency = params.frequency_penalty; | ||||
|     const int     mirostat        = params.mirostat; | ||||
|     const float   mirostat_tau    = params.mirostat_tau; | ||||
|     const float   mirostat_eta    = params.mirostat_eta; | ||||
|     const bool    penalize_nl     = params.penalize_nl; | ||||
|  | ||||
|     llama_token id = 0; | ||||
|  | ||||
|     float * logits = llama_get_logits(ctx) + idx * n_vocab; | ||||
|  | ||||
|     // Apply params.logit_bias map | ||||
|     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { | ||||
|         logits[it->first] += it->second; | ||||
|     } | ||||
|  | ||||
|     candidates.clear(); | ||||
|     for (llama_token token_id = 0; token_id < n_vocab; token_id++) { | ||||
|         candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); | ||||
|     } | ||||
|  | ||||
|     llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; | ||||
|  | ||||
|     if (ctx_guidance) { | ||||
|         llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); | ||||
|     } | ||||
|  | ||||
|     // apply penalties | ||||
|     if (!last_tokens.empty()) { | ||||
|         const float nl_logit = logits[llama_token_nl(ctx)]; | ||||
|         const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); | ||||
|  | ||||
|         llama_sample_repetition_penalty(ctx, &cur_p, | ||||
|                 last_tokens.data() + last_tokens.size() - last_n_repeat, | ||||
|                 last_n_repeat, repeat_penalty); | ||||
|         llama_sample_frequency_and_presence_penalties(ctx, &cur_p, | ||||
|                 last_tokens.data() + last_tokens.size() - last_n_repeat, | ||||
|                 last_n_repeat, alpha_frequency, alpha_presence); | ||||
|  | ||||
|         if (!penalize_nl) { | ||||
|             for (size_t idx = 0; idx < cur_p.size; idx++) { | ||||
|                 if (cur_p.data[idx].id == llama_token_nl(ctx)) { | ||||
|                     cur_p.data[idx].logit = nl_logit; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (grammar != NULL) { | ||||
|         llama_sample_grammar(ctx, &cur_p, grammar); | ||||
|     } | ||||
|  | ||||
|     if (temp <= 0) { | ||||
|         // Greedy sampling | ||||
|         id = llama_sample_token_greedy(ctx, &cur_p); | ||||
|     } else { | ||||
|         if (mirostat == 1) { | ||||
|             static float mirostat_mu = 2.0f * mirostat_tau; | ||||
|             const int mirostat_m = 100; | ||||
|             llama_sample_temperature(ctx, &cur_p, temp); | ||||
|             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); | ||||
|         } else if (mirostat == 2) { | ||||
|             static float mirostat_mu = 2.0f * mirostat_tau; | ||||
|             llama_sample_temperature(ctx, &cur_p, temp); | ||||
|             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); | ||||
|         } else { | ||||
|             // Temperature sampling | ||||
|             llama_sample_top_k      (ctx, &cur_p, top_k, 1); | ||||
|             llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1); | ||||
|             llama_sample_typical    (ctx, &cur_p, typical_p, 1); | ||||
|             llama_sample_top_p      (ctx, &cur_p, top_p, 1); | ||||
|             llama_sample_temperature(ctx, &cur_p, temp); | ||||
|  | ||||
|             { | ||||
|                 const int n_top = 10; | ||||
|                 LOG("top %d candidates:\n", n_top); | ||||
|  | ||||
|                 for (int i = 0; i < n_top; i++) { | ||||
|                     const llama_token id = cur_p.data[i].id; | ||||
|                     LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             id = llama_sample_token(ctx, &cur_p); | ||||
|  | ||||
|             LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); | ||||
|         } | ||||
|     } | ||||
|     // printf("`%d`", candidates_p.size); | ||||
|  | ||||
|     if (grammar != NULL) { | ||||
|         llama_grammar_accept_token(ctx, grammar, id); | ||||
|     } | ||||
|  | ||||
|     return id; | ||||
| } | ||||
|  | ||||
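For context, here is a minimal sketch (not part of this commit's diff) of how an example program might drive the new helper once a model and context have been set up with `llama_init_from_gpt_params` and the prompt has been evaluated. It assumes `common.h`, `llama.h`, and `<cstdio>` are included, and that `n_past` and `n_predict` are tracked by the caller.

```cpp
// Minimal generation-loop sketch built on the new helper (assumptions noted above).
static void generate_sketch(llama_context * ctx, const gpt_params & params, int n_past, int n_predict) {
    std::vector<llama_token>      last_tokens(llama_n_ctx(ctx), 0);
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx));

    for (int i = 0; i < n_predict; ++i) {
        // no guidance context, no grammar; sample from the last evaluated position (idx = 0)
        const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates);

        // keep the window used by the repetition/frequency penalties up to date
        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);

        if (id == llama_token_eos(ctx)) {
            break;
        }

        printf("%s", llama_token_to_piece(ctx, id).c_str());
        fflush(stdout);

        // feed the sampled token back so the next iteration has fresh logits
        llama_eval(ctx, &id, 1, n_past++, params.n_threads);
    }
}
```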
| // | ||||
| // YAML utils | ||||
| // | ||||
|  | ||||
| // returns true if successful, false otherwise | ||||
| bool create_directory_with_parents(const std::string & path) { | ||||
| #ifdef _WIN32 | ||||
| @@ -1070,6 +1202,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l | ||||
|     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); | ||||
|     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); | ||||
|     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); | ||||
|     fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); | ||||
|     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false"); | ||||
|     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); | ||||
|     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers); | ||||
|   | ||||
| @@ -63,6 +63,7 @@ struct gpt_params { | ||||
|     float       cfg_scale         = 1.f;   // How strong is guidance | ||||
|  | ||||
|     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path | ||||
|     std::string model_draft       = "";                              // draft model for speculative sampling | ||||
|     std::string model_alias       = "unknown"; // model alias | ||||
|     std::string prompt            = ""; | ||||
|     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state | ||||
| @@ -156,6 +157,40 @@ std::string llama_detokenize_bpe( | ||||
|                          llama_context * ctx, | ||||
|         const std::vector<llama_token> & tokens); | ||||
|  | ||||
| // | ||||
| // Sampling utils | ||||
| // | ||||
|  | ||||
| // this is a common sampling function used across the examples for convenience | ||||
| // it can serve as a starting point for implementing your own sampling function | ||||
| // | ||||
| // required: | ||||
| //  - ctx:    context to use for sampling | ||||
| //  - params: sampling parameters | ||||
| // | ||||
| // optional: | ||||
| //  - ctx_guidance:  context to use for classifier-free guidance, ignore if NULL | ||||
| //  - grammar:       grammar to use for sampling, ignore if NULL | ||||
| //  - last_tokens:   needed for repetition penalty, ignore if empty | ||||
| //  - idx:           sample from llama_get_logits(ctx) + idx * n_vocab | ||||
| // | ||||
| // returns: | ||||
| //  - token:      sampled token | ||||
| //  - candidates: vector of candidate tokens | ||||
| // | ||||
| llama_token llama_sample_token( | ||||
|                   struct llama_context * ctx, | ||||
|                   struct llama_context * ctx_guidance, | ||||
|                   struct llama_grammar * grammar, | ||||
|                const struct gpt_params & params, | ||||
|         const std::vector<llama_token> & last_tokens, | ||||
|          std::vector<llama_token_data> & candidates, | ||||
|                                    int   idx = 0); | ||||
|  | ||||
| // | ||||
| // YAML utils | ||||
| // | ||||
|  | ||||
| bool create_directory_with_parents(const std::string & path); | ||||
| void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data); | ||||
| void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data); | ||||
|   | ||||
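The `idx` parameter is the piece the new speculative example relies on: when the context keeps logits for every token of a batch (the example forces this through the `params.perplexity = true` hack), a single `llama_eval` over the batch leaves one row of logits per position, and `idx` selects which row to sample from. A rough sketch under those assumptions, with the state vectors prepared as in the loop above (not part of the diff):

```cpp
// Sketch only: sample one token per position of an already-evaluated batch.
// Assumes the context keeps logits for every token (logits_all), which the new
// example enables via `params.perplexity = true` before loading the target model.
static void sample_each_position(llama_context * ctx, const gpt_params & params,
                                 const std::vector<llama_token> & batch, int n_past,
                                 std::vector<llama_token> & last_tokens,
                                 std::vector<llama_token_data> & candidates) {
    // one call produces logits for every position in the batch
    llama_eval(ctx, batch.data(), (int) batch.size(), n_past, params.n_threads);

    for (int i = 0; i < (int) batch.size(); ++i) {
        // idx = i reads from llama_get_logits(ctx) + i * llama_n_vocab(ctx)
        const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates, i);

        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);

        // a real caller (like the speculative example) would now compare `id`
        // against the drafted token proposed for this position
    }
}
```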
| @@ -23,6 +23,7 @@ else() | ||||
|     add_subdirectory(train-text-from-scratch) | ||||
|     add_subdirectory(convert-llama2c-to-ggml) | ||||
|     add_subdirectory(simple) | ||||
|     add_subdirectory(speculative) | ||||
|     add_subdirectory(embd-input) | ||||
|     add_subdirectory(llama-bench) | ||||
|     add_subdirectory(beam-search) | ||||
|   | ||||
| @@ -116,7 +116,7 @@ int main(int argc, char ** argv) { | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     log_set_target(log_filename_generator("main", "log")); | ||||
|     LOG_TEE("Log start\n"); | ||||
|     log_dump_cmdline(argc,argv); | ||||
|     log_dump_cmdline(argc, argv); | ||||
| #endif // LOG_DISABLE_LOGS | ||||
|  | ||||
|     // TODO: Dump params ? | ||||
| @@ -425,8 +425,9 @@ int main(int argc, char ** argv) { | ||||
|     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); | ||||
|     LOG_TEE("\n\n"); | ||||
|  | ||||
|     struct llama_grammar * grammar = NULL; | ||||
|     grammar_parser::parse_state parsed_grammar; | ||||
|     llama_grammar *             grammar = NULL; | ||||
|  | ||||
|     if (!params.grammar.empty()) { | ||||
|         parsed_grammar = grammar_parser::parse(params.grammar.c_str()); | ||||
|         // will be empty (default) if there are parse errors | ||||
| @@ -450,8 +451,8 @@ int main(int argc, char ** argv) { | ||||
|     } | ||||
|  | ||||
|     // TODO: replace with ring-buffer | ||||
|     std::vector<llama_token> last_n_tokens(n_ctx); | ||||
|     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); | ||||
|     std::vector<llama_token> last_tokens(n_ctx); | ||||
|     std::fill(last_tokens.begin(), last_tokens.end(), 0); | ||||
|  | ||||
|     if (params.interactive) { | ||||
|         const char *control_message; | ||||
| @@ -492,6 +493,11 @@ int main(int argc, char ** argv) { | ||||
|     std::vector<llama_token> embd; | ||||
|     std::vector<llama_token> embd_guidance; | ||||
|  | ||||
|     const int n_vocab = llama_n_vocab(ctx); | ||||
|  | ||||
|     std::vector<llama_token_data> candidates; | ||||
|     candidates.reserve(n_vocab); | ||||
|  | ||||
|     while ((n_remain != 0 && !is_antiprompt) || params.interactive) { | ||||
|         // predict | ||||
|         if (embd.size() > 0) { | ||||
| @@ -529,8 +535,8 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); | ||||
|  | ||||
|                 // insert n_left/2 tokens at the start of embd from last_n_tokens | ||||
|                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); | ||||
|                 // insert n_left/2 tokens at the start of embd from last_tokens | ||||
|                 embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size()); | ||||
|  | ||||
|                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); | ||||
|  | ||||
| @@ -629,20 +635,6 @@ int main(int argc, char ** argv) { | ||||
|         embd_guidance.clear(); | ||||
|  | ||||
|         if ((int) embd_inp.size() <= n_consumed && !is_interacting) { | ||||
|             const float   temp            = params.temp; | ||||
|             const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; | ||||
|             const float   top_p           = params.top_p; | ||||
|             const float   tfs_z           = params.tfs_z; | ||||
|             const float   typical_p       = params.typical_p; | ||||
|             const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; | ||||
|             const float   repeat_penalty  = params.repeat_penalty; | ||||
|             const float   alpha_presence  = params.presence_penalty; | ||||
|             const float   alpha_frequency = params.frequency_penalty; | ||||
|             const int     mirostat        = params.mirostat; | ||||
|             const float   mirostat_tau    = params.mirostat_tau; | ||||
|             const float   mirostat_eta    = params.mirostat_eta; | ||||
|             const bool    penalize_nl     = params.penalize_nl; | ||||
|  | ||||
|             // optionally save the session on first sample (for faster prompt loading next time) | ||||
|             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { | ||||
|                 need_to_save_session = false; | ||||
| @@ -651,98 +643,12 @@ int main(int argc, char ** argv) { | ||||
|                 LOG("saved session to %s\n", path_session.c_str()); | ||||
|             } | ||||
|  | ||||
|             llama_token id = 0; | ||||
|             const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); | ||||
|  | ||||
|             { | ||||
|                 auto logits  = llama_get_logits(ctx); | ||||
|                 auto n_vocab = llama_n_vocab(ctx); | ||||
|             last_tokens.erase(last_tokens.begin()); | ||||
|             last_tokens.push_back(id); | ||||
|  | ||||
|                 // Apply params.logit_bias map | ||||
|                 for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { | ||||
|                     logits[it->first] += it->second; | ||||
|                 } | ||||
|  | ||||
|                 std::vector<llama_token_data> candidates; | ||||
|                 candidates.reserve(n_vocab); | ||||
|                 for (llama_token token_id = 0; token_id < n_vocab; token_id++) { | ||||
|                     candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); | ||||
|                 } | ||||
|  | ||||
|                 llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; | ||||
|  | ||||
|                 if (ctx_guidance) { | ||||
|                     llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); | ||||
|                 } | ||||
|  | ||||
|                 // Apply penalties | ||||
|                 float nl_logit = logits[llama_token_nl(ctx)]; | ||||
|                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); | ||||
|                 llama_sample_repetition_penalty(ctx, &cur_p, | ||||
|                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, | ||||
|                     last_n_repeat, repeat_penalty); | ||||
|                 llama_sample_frequency_and_presence_penalties(ctx, &cur_p, | ||||
|                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, | ||||
|                     last_n_repeat, alpha_frequency, alpha_presence); | ||||
|                 if (!penalize_nl) { | ||||
|                     for (size_t idx = 0; idx < cur_p.size; idx++) { | ||||
|                         if (cur_p.data[idx].id == llama_token_nl(ctx)) { | ||||
|                             cur_p.data[idx].logit = nl_logit; | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 if (grammar != NULL) { | ||||
|                     llama_sample_grammar(ctx, &cur_p, grammar); | ||||
|                 } | ||||
|  | ||||
|                 if (temp <= 0) { | ||||
|                     // Greedy sampling | ||||
|                     id = llama_sample_token_greedy(ctx, &cur_p); | ||||
|                 } else { | ||||
|                     if (mirostat == 1) { | ||||
|                         static float mirostat_mu = 2.0f * mirostat_tau; | ||||
|                         const int mirostat_m = 100; | ||||
|                         llama_sample_temperature(ctx, &cur_p, temp); | ||||
|                         id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); | ||||
|                     } else if (mirostat == 2) { | ||||
|                         static float mirostat_mu = 2.0f * mirostat_tau; | ||||
|                         llama_sample_temperature(ctx, &cur_p, temp); | ||||
|                         id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); | ||||
|                     } else { | ||||
|                         // Temperature sampling | ||||
|                         llama_sample_top_k      (ctx, &cur_p, top_k, 1); | ||||
|                         llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1); | ||||
|                         llama_sample_typical    (ctx, &cur_p, typical_p, 1); | ||||
|                         llama_sample_top_p      (ctx, &cur_p, top_p, 1); | ||||
|                         llama_sample_temperature(ctx, &cur_p, temp); | ||||
|  | ||||
|                         { | ||||
|                             const int n_top = 10; | ||||
|                             LOG("top %d candidates:\n", n_top); | ||||
|  | ||||
|                             for (int i = 0; i < n_top; i++) { | ||||
|                                 const llama_token id = cur_p.data[i].id; | ||||
|                                 LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); | ||||
|                             } | ||||
|                         } | ||||
|  | ||||
|                         id = llama_sample_token(ctx, &cur_p); | ||||
|  | ||||
|                         LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); | ||||
|                     } | ||||
|                 } | ||||
|                 // printf("`%d`", candidates_p.size); | ||||
|  | ||||
|                 if (grammar != NULL) { | ||||
|                     llama_grammar_accept_token(ctx, grammar, id); | ||||
|                 } | ||||
|  | ||||
|                 last_n_tokens.erase(last_n_tokens.begin()); | ||||
|                 last_n_tokens.push_back(id); | ||||
|  | ||||
|                 LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens)); | ||||
|             } | ||||
|             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens)); | ||||
|  | ||||
|             embd.push_back(id); | ||||
|  | ||||
| @@ -758,8 +664,8 @@ int main(int argc, char ** argv) { | ||||
|             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); | ||||
|             while ((int) embd_inp.size() > n_consumed) { | ||||
|                 embd.push_back(embd_inp[n_consumed]); | ||||
|                 last_n_tokens.erase(last_n_tokens.begin()); | ||||
|                 last_n_tokens.push_back(embd_inp[n_consumed]); | ||||
|                 last_tokens.erase(last_tokens.begin()); | ||||
|                 last_tokens.push_back(embd_inp[n_consumed]); | ||||
|                 ++n_consumed; | ||||
|                 if ((int) embd.size() >= params.n_batch) { | ||||
|                     break; | ||||
| @@ -792,7 +698,7 @@ int main(int argc, char ** argv) { | ||||
|             // check for reverse prompt | ||||
|             if (params.antiprompt.size()) { | ||||
|                 std::string last_output; | ||||
|                 for (auto id : last_n_tokens) { | ||||
|                 for (auto id : last_tokens) { | ||||
|                     last_output += llama_token_to_piece(ctx, id); | ||||
|                 } | ||||
|  | ||||
| @@ -823,7 +729,7 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|  | ||||
|             // deal with end of text token in interactive mode | ||||
|             if (last_n_tokens.back() == llama_token_eos(ctx)) { | ||||
|             if (last_tokens.back() == llama_token_eos(ctx)) { | ||||
|                 LOG("found EOS token\n"); | ||||
|  | ||||
|                 if (params.interactive) { | ||||
| @@ -925,7 +831,7 @@ int main(int argc, char ** argv) { | ||||
|                     if (grammar != NULL) { | ||||
|                         llama_grammar_free(grammar); | ||||
|  | ||||
|                         std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules()); | ||||
|                         std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); | ||||
|                         grammar = llama_grammar_init( | ||||
|                             grammar_rules.data(), grammar_rules.size(), | ||||
|                             parsed_grammar.symbol_ids.at("root")); | ||||
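Note that the main.cpp changes above are essentially a refactor: the long inline sampling block is replaced by a call to the shared llama_sample_token helper, and last_n_tokens is renamed to last_tokens to match the helper's parameter name, with the sampling behavior itself left unchanged.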
examples/speculative/CMakeLists.txt (new file, 8 lines)
							| @@ -0,0 +1,8 @@ | ||||
| set(TARGET speculative) | ||||
| add_executable(${TARGET} speculative.cpp) | ||||
| install(TARGETS ${TARGET} RUNTIME) | ||||
| target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) | ||||
| target_compile_features(${TARGET} PRIVATE cxx_std_11) | ||||
| if(TARGET BUILD_INFO) | ||||
|   add_dependencies(${TARGET} BUILD_INFO) | ||||
| endif() | ||||
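This mirrors the other example CMakeLists files; once the subdirectory is registered in examples/CMakeLists.txt (see the hunk above), the binary builds through the usual CMake workflow, e.g. `cmake --build . --target speculative` from the build directory.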
							
								
								
									
examples/speculative/speculative.cpp (new file, 227 lines)
							| @@ -0,0 +1,227 @@ | ||||
| #ifndef _GNU_SOURCE | ||||
| #define _GNU_SOURCE | ||||
| #endif | ||||
|  | ||||
| #include "build-info.h" | ||||
|  | ||||
| #include "common.h" | ||||
| #include "llama.h" | ||||
|  | ||||
| #include <cmath> | ||||
| #include <cstdio> | ||||
| #include <string> | ||||
| #include <vector> | ||||
|  | ||||
| int main(int argc, char ** argv) { | ||||
|     gpt_params params; | ||||
|  | ||||
|     if (gpt_params_parse(argc, argv, params) == false) { | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     if (params.model_draft.empty()) { | ||||
|         fprintf(stderr, "%s: error: --model-draft is required\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     log_set_target(log_filename_generator("speculative", "log")); | ||||
|     LOG_TEE("Log start\n"); | ||||
|     log_dump_cmdline(argc, argv); | ||||
| #endif // LOG_DISABLE_LOGS | ||||
|  | ||||
|     // init llama.cpp | ||||
|     llama_backend_init(params.numa); | ||||
|  | ||||
|     llama_model * model_tgt = NULL; | ||||
|     llama_model * model_dft = NULL; | ||||
|  | ||||
|     llama_context * ctx_tgt = NULL; | ||||
|     llama_context * ctx_dft = NULL; | ||||
|  | ||||
|     // load the target model | ||||
|     params.perplexity = true; // HACK: enable logits_all = true | ||||
|     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); | ||||
|  | ||||
|     // load the draft model | ||||
|     params.model = params.model_draft; | ||||
|     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); | ||||
|  | ||||
|     // tokenize the prompt | ||||
|     std::vector<llama_token> inp; | ||||
|     inp = ::llama_tokenize(ctx_tgt, params.prompt, true); | ||||
|  | ||||
|     const int max_context_size     = llama_n_ctx(ctx_tgt); | ||||
|     const int max_tokens_list_size = max_context_size - 4; | ||||
|  | ||||
|     if ((int) inp.size() > max_tokens_list_size) { | ||||
|         fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     fprintf(stderr, "\n\n"); | ||||
|  | ||||
|     for (auto id : inp) { | ||||
|         fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str()); | ||||
|     } | ||||
|  | ||||
|     fflush(stderr); | ||||
|  | ||||
|     // eval the prompt with both models | ||||
|     llama_eval(ctx_tgt,  inp.data(), int(inp.size() - 1), 0, params.n_threads); | ||||
|     llama_eval(ctx_tgt, &inp.back(),      1, inp.size() - 1, params.n_threads); | ||||
|     llama_eval(ctx_dft,  inp.data(),     int(inp.size()), 0, params.n_threads); | ||||
|  | ||||
|     // the 2 models should have the same vocab | ||||
|     const int n_ctx   = llama_n_ctx(ctx_tgt); | ||||
|     const int n_vocab = llama_n_vocab(ctx_tgt); | ||||
|     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); | ||||
|  | ||||
|     // how many tokens to draft each time | ||||
|     const int n_draft = 16; | ||||
|  | ||||
|     int n_predict = 0; | ||||
|     int n_drafted = 0; | ||||
|     int n_accept  = 0; | ||||
|  | ||||
|     int n_past_tgt = inp.size(); | ||||
|     int n_past_dft = inp.size(); | ||||
|  | ||||
|     std::vector<llama_token> drafted; | ||||
|  | ||||
|     std::vector<llama_token> last_tokens(n_ctx); | ||||
|     std::fill(last_tokens.begin(), last_tokens.end(), 0); | ||||
|  | ||||
|     for (auto & id : inp) { | ||||
|         last_tokens.erase(last_tokens.begin()); | ||||
|         last_tokens.push_back(id); | ||||
|     } | ||||
|  | ||||
|     std::vector<llama_token_data> candidates; | ||||
|     candidates.reserve(n_vocab); | ||||
|  | ||||
|     // used to determine end of generation | ||||
|     bool has_eos = false; | ||||
|  | ||||
|     const auto t_gen_start = ggml_time_us(); | ||||
|  | ||||
|     while (true) { | ||||
|         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); | ||||
|  | ||||
|         // sample from the drafted tokens if any | ||||
|         int i_dft = 0; | ||||
|         while (true) { | ||||
|             const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft); | ||||
|  | ||||
|             last_tokens.erase(last_tokens.begin()); | ||||
|             last_tokens.push_back(id); | ||||
|  | ||||
|             //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens)); | ||||
|  | ||||
|             const std::string token_str = llama_token_to_piece(ctx_tgt, id); | ||||
|             printf("%s", token_str.c_str()); | ||||
|             fflush(stdout); | ||||
|  | ||||
|             if (id == llama_token_eos(ctx_tgt)) { | ||||
|                 has_eos = true; | ||||
|             } | ||||
|  | ||||
|             ++n_predict; | ||||
|  | ||||
|             if (i_dft < (int) drafted.size() && id == drafted[i_dft]) { | ||||
|                 LOG("drafted token %d accepted\n", id); | ||||
|                 ++n_accept; | ||||
|                 ++n_past_tgt; | ||||
|                 ++n_past_dft; | ||||
|                 ++i_dft; | ||||
|  | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             // the drafted token was rejected or we are out of drafted tokens | ||||
|             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); | ||||
|             ++n_past_dft; | ||||
|  | ||||
|             drafted.clear(); | ||||
|             drafted.push_back(id); | ||||
|  | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         if (n_predict > params.n_predict || has_eos) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         // sample n_draft tokens from the draft model picking the best token | ||||
|         int n_past_cur = n_past_dft; | ||||
|         for (int i = 0; i < n_draft; ++i) { | ||||
|             float * logits = llama_get_logits(ctx_dft); | ||||
|  | ||||
|             candidates.clear(); | ||||
|             for (llama_token token_id = 0; token_id < n_vocab; token_id++) { | ||||
|                 candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); | ||||
|             } | ||||
|  | ||||
|             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; | ||||
|  | ||||
|             // computes softmax and sorts the candidates | ||||
|             llama_sample_softmax(ctx_dft, &cur_p); | ||||
|  | ||||
|             for (int i = 0; i < 3; ++i) { | ||||
|                 LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p); | ||||
|             } | ||||
|  | ||||
|             // too low probability, stop drafting | ||||
|             if (cur_p.data[0].p < 2*cur_p.data[1].p) { | ||||
|                 break; | ||||
|             } | ||||
|  | ||||
|             drafted.push_back(cur_p.data[0].id); | ||||
|             ++n_drafted; | ||||
|  | ||||
|             if (i < n_draft - 1) { | ||||
|                 // evaluate the drafted token on the draft model | ||||
|                 llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); | ||||
|                 ++n_past_cur; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // evaluate the target model on the drafted tokens | ||||
|         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); | ||||
|         ++n_past_tgt; | ||||
|  | ||||
|         drafted.erase(drafted.begin()); | ||||
|     } | ||||
|  | ||||
|     auto t_gen_end = ggml_time_us(); | ||||
|  | ||||
|     LOG_TEE("\n\n"); | ||||
|  | ||||
|     LOG_TEE("generated %d tokens in %.3f seconds, speed: %.3f t/s\n", n_predict, (t_gen_end - t_gen_start) / 1e6f, n_predict / ((t_gen_end - t_gen_start) / 1e6f)); | ||||
|  | ||||
|     // TODO: make sure these numbers are computed correctly | ||||
|     LOG_TEE("\n"); | ||||
|     LOG_TEE("n_draft   = %d\n", n_draft); | ||||
|     LOG_TEE("n_predict = %d\n", n_predict); | ||||
|     LOG_TEE("n_drafted = %d\n", n_drafted); | ||||
|     LOG_TEE("n_accept  = %d\n", n_accept); | ||||
|     LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted); | ||||
|  | ||||
|     LOG_TEE("\ndraft:\n"); | ||||
|     llama_print_timings(ctx_dft); | ||||
|  | ||||
|     LOG_TEE("\ntarget:\n"); | ||||
|     llama_print_timings(ctx_tgt); | ||||
|  | ||||
|     llama_free(ctx_tgt); | ||||
|     llama_free_model(model_tgt); | ||||
|  | ||||
|     llama_free(ctx_dft); | ||||
|     llama_free_model(model_dft); | ||||
|  | ||||
|     llama_backend_free(); | ||||
|  | ||||
|     fprintf(stderr, "\n\n"); | ||||
|  | ||||
|     return 0; | ||||
| } | ||||
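Taken together, the example implements the basic speculative decoding loop: the draft model greedily proposes up to n_draft tokens, stopping early whenever its top candidate is not at least twice as probable as the runner-up; the target model then evaluates the whole proposal in a single batched llama_eval; and proposed tokens are accepted one at a time for as long as the target samples exactly the token the draft suggested. The first disagreement discards the remaining draft, and drafting restarts from the target's own token. The reported accept rate is 100 * n_accept / n_drafted, so a run that drafted 200 tokens and saw 150 of them accepted would print accept = 75.000%.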
Author: Georgi Gerganov