	llama : new sampling algorithms (#1126)
* Sample interface and new samplers:
  - locally typical sampling
  - tail free sampling
  - frequency and presence penalties
  - mirostat
* Ignore-EOS fix: -inf should be used for the EOS logit, not 0
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11; clarify llama API documentation; rename the Mirostat parameters to --mirostat_lr and --mirostat_ent; add temperature sampling for Mirostat; simplify the Mirostat sampling API (removed the N and *k parameters)
* Adjust the save-load-state example
* Tests
* Windows build fix
* Windows test fix
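For orientation before the diff, the new flags compose like this; a hypothetical invocation (model path, prompt, and parameter values are placeholders, the flags are the ones added by this commit):

    ./main -m models/7B/ggml-model.bin -p "Hello" --mirostat 2 --mirostat_ent 5.0 --mirostat_lr 0.1 --no-penalize-nl --logit-bias 15043+1

With Mirostat enabled the top-k/top-p style filters are ignored, while the logit bias nudges token 15043 upward at every step.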
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -6,6 +6,8 @@
 #include <string>
 #include <iterator>
 #include <algorithm>
+#include <sstream>
+#include <iostream>
 
 #if defined (_WIN32)
 #include <fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.temp = std::stof(argv[i]);
+        } else if (arg == "--tfs") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.tfs_z = std::stof(argv[i]);
+        } else if (arg == "--typical") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.typical_p = std::stof(argv[i]);
         } else if (arg == "--repeat_last_n") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.repeat_penalty = std::stof(argv[i]);
+        } else if (arg == "--frequency_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.frequency_penalty = std::stof(argv[i]);
+        } else if (arg == "--presence_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.presence_penalty = std::stof(argv[i]);
+        } else if (arg == "--mirostat") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat = std::stoi(argv[i]);
+        } else if (arg == "--mirostat_lr") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_eta = std::stof(argv[i]);
+        } else if (arg == "--mirostat_ent") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_tau = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
-            params.ignore_eos = true;
+            params.logit_bias[llama_token_eos()] = -INFINITY;
+        } else if (arg == "--no-penalize-nl") {
+            params.penalize_nl = false;
+        } else if (arg == "-l" || arg == "--logit-bias") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::stringstream ss(argv[i]);
+            llama_token key;
+            char sign;
+            std::string value_str;
+            try {
+                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+                    params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                } else {
+                    throw std::exception();
+                }
+            } catch (const std::exception &e) {
+                invalid_param = true;
+                break;
+            }
         } else if (arg == "--n_parts") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", (double)params.top_p);
-    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stderr, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+    fprintf(stderr, "  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --presence_penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stderr, "  --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stderr, "  --mirostat N          use Mirostat sampling.\n");
+    fprintf(stderr, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+    fprintf(stderr, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+    fprintf(stderr, "  --mirostat_lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stderr, "  --mirostat_ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, "  -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+    fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
+    fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+    fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
+    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+    fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
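The -l/--logit-bias parsing above accepts TOKEN_ID+BIAS or TOKEN_ID-BIAS, with the sign required: the explicit sign is what lets a single `>> sign` extraction split the token id from the bias without any separator. A standalone sketch of the same stringstream logic (illustrative only: plain int stands in for llama_token, and parse_logit_bias is not a function from the patch):

    #include <cstdio>
    #include <sstream>
    #include <string>
    #include <unordered_map>

    // Parse "15043+1" / "15043-0.5" into (token, bias), mirroring the logic
    // in gpt_params_parse above. Returns false on malformed input.
    static bool parse_logit_bias(const std::string & arg, std::unordered_map<int, float> & out) {
        std::stringstream ss(arg);
        int key;
        char sign;
        std::string value_str;
        try {
            // operator>> stops the int read at the sign character, which is
            // then consumed separately before the remainder becomes the value
            if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
                out[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
                return true;
            }
        } catch (const std::exception &) {
            // std::stof throws on a non-numeric or empty value
        }
        return false;
    }

    int main() {
        std::unordered_map<int, float> bias;
        printf("%d\n", parse_logit_bias("15043+1", bias)); // 1, bias[15043] == 1.0f
        printf("%d\n", parse_logit_bias("15043-1", bias)); // 1, bias[15043] == -1.0f
        printf("%d\n", parse_logit_bias("oops", bias));    // 0, rejected
    }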
--- a/examples/common.h
+++ b/examples/common.h
@@ -8,6 +8,7 @@
 #include <vector>
 #include <random>
 #include <thread>
+#include <unordered_map>
 
 //
 // CLI argument parsing
@@ -17,17 +18,25 @@ struct gpt_params {
     int32_t seed          = -1;   // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict     = 128;  // new tokens to predict
-    int32_t repeat_last_n = 64;   // last n tokens to penalize
     int32_t n_parts       = -1;   // amount of model parts (-1 = determine from model dimensions)
     int32_t n_ctx         = 512;  // context size
     int32_t n_batch       = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep        = 0;    // number of tokens to keep from initial prompt
 
     // sampling parameters
-    int32_t top_k = 40;
-    float   top_p = 0.95f;
-    float   temp  = 0.80f;
-    float   repeat_penalty  = 1.10f;
+    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    int32_t top_k = 0;              // <= 0 to use vocab size
+    float   top_p = 1.0f;           // 1.0 = disabled
+    float   tfs_z = 1.0f;           // 1.0 = disabled
+    float   typical_p = 1.0f;       // 1.0 = disabled
+    float   temp = 1.0f;            // 1.0 = disabled
+    float   repeat_penalty  = 1.0f; // 1.0 = disabled
+    int32_t repeat_last_n = -1;     // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   frequency_penalty = 0.0f; // 0.0 = disabled
+    float   presence_penalty = 0.0f;  // 0.0 = disabled
+    int     mirostat = 0;           // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau = 5.0f;    // target entropy
+    float   mirostat_eta = 0.1f;    // learning rate
 
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
@@ -47,7 +56,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
 
     bool instruct          = false; // instruction mode (used for Alpaca models)
-    bool ignore_eos        = false; // do not stop generating after eos
+    bool penalize_nl       = true;  // consider newlines as a repeatable token
     bool perplexity        = false; // compute perplexity over the prompt
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
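Worth noting: every new default above is the corresponding sampler's disabled value (top_k = 0 meaning full vocabulary, top_p/tfs_z/typical_p/temp/repeat_penalty at 1.0, both alpha penalties at 0.0), so a default-constructed gpt_params leaves the distribution untouched. A minimal sketch of re-creating the pre-patch behavior in caller code (assumes this common.h is on the include path; the values are exactly the ones removed above):

    #include "common.h"

    int main() {
        gpt_params params;          // new defaults: all samplers disabled
        // pre-patch defaults, re-applied explicitly:
        params.top_k          = 40;
        params.top_p          = 0.95f;
        params.temp           = 0.80f;
        params.repeat_penalty = 1.10f;
        params.repeat_last_n  = 64;
        return 0;
    }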
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+            params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
@@ -387,10 +387,19 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const int32_t top_k          = params.top_k;
-            const float   top_p          = params.top_p;
-            const float   temp           = params.temp;
-            const float   repeat_penalty = params.repeat_penalty;
+            const float   temp           = params.temp;
+            const int32_t top_k          = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float   top_p          = params.top_p;
+            const float   tfs_z          = params.tfs_z;
+            const float   typical_p      = params.typical_p;
+            const int32_t repeat_last_n  = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+            const float   repeat_penalty = params.repeat_penalty;
+            const float   alpha_presence = params.presence_penalty;
+            const float   alpha_frequency = params.frequency_penalty;
+            const int     mirostat       = params.mirostat;
+            const float   mirostat_tau   = params.mirostat_tau;
+            const float   mirostat_eta   = params.mirostat_eta;
+            const bool    penalize_nl   = params.penalize_nl;
 
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +411,58 @@ int main(int argc, char ** argv) {
 
             {
                 auto logits = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
 
-                if (params.ignore_eos) {
-                    logits[llama_token_eos()] = 0;
+                // Apply params.logit_bias map
+                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+                    logits[it->first] += it->second;
                 }
 
-                id = llama_sample_top_p_top_k(ctx,
-                        last_n_tokens.data() + n_ctx - params.repeat_last_n,
-                        params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+                }
+
+                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+                // Apply penalties
+                float nl_logit = logits[llama_token_nl()];
+                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+                llama_sample_repetition_penalty(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, alpha_frequency, alpha_presence);
+                if (!penalize_nl) {
+                    logits[llama_token_nl()] = nl_logit;
+                }
+
+                if (temp <= 0) {
+                    // Greedy sampling
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                } else {
+                    if (mirostat == 1) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        const int mirostat_m = 100;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    } else if (mirostat == 2) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    } else {
+                        // Temperature sampling
+                        llama_sample_top_k(ctx, &candidates_p, top_k);
+                        llama_sample_tail_free(ctx, &candidates_p, tfs_z);
+                        llama_sample_typical(ctx, &candidates_p, typical_p);
+                        llama_sample_top_p(ctx, &candidates_p, top_p);
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token(ctx, &candidates_p);
+                    }
+                }
+                // printf("`%d`", candidates_p.size);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
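Pulled out of the loop for readability, the new per-token pipeline orders operations as: logit bias, candidate construction, repetition and frequency/presence penalties, then either greedy sampling, Mirostat, or the filter chain. A condensed sketch (sample_next is an illustrative helper, not part of the patch; it assumes the llama_sample_* API exactly as called above and omits the newline-penalty and logit-bias handling):

    #include "llama.h"
    #include <vector>

    // One sampling step over the current logits, mirroring main.cpp above.
    static llama_token sample_next(
            llama_context * ctx,
            const llama_token * last_tokens, size_t n_last,   // repetition window
            float repeat_penalty, float alpha_frequency, float alpha_presence,
            int32_t top_k, float tfs_z, float typical_p, float top_p, float temp) {
        float * logits = llama_get_logits(ctx);
        const int n_vocab = llama_n_vocab(ctx);

        // rebuild the candidate list from raw logits each step
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // penalties run first, on the still-unsorted candidate list
        llama_sample_repetition_penalty(ctx, &candidates_p, last_tokens, n_last, repeat_penalty);
        llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, last_tokens, n_last,
                                                      alpha_frequency, alpha_presence);

        if (temp <= 0) {
            return llama_sample_token_greedy(ctx, &candidates_p); // plain argmax
        }

        // each filter shrinks candidates_p.size; temperature rescales what remains
        llama_sample_top_k      (ctx, &candidates_p, top_k);
        llama_sample_tail_free  (ctx, &candidates_p, tfs_z);
        llama_sample_typical    (ctx, &candidates_p, typical_p);
        llama_sample_top_p      (ctx, &candidates_p, top_p);
        llama_sample_temperature(ctx, &candidates_p, temp);
        return llama_sample_token(ctx, &candidates_p);            // draw from the rest
    }

With --mirostat 1 or 2 the filter chain is skipped entirely: only temperature is applied and llama_sample_token_mirostat / llama_sample_token_mirostat_v2 picks the token while updating its mu state, as in the hunk above.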
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
     // first run
     printf("\n%s", params.prompt.c_str());
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx2,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx2);
+        auto n_vocab = llama_n_vocab(ctx2);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
Author: Ivan Stepanov