	llama : add llama_sampling API + move grammar in libllama
ggml-ci
Author: Georgi Gerganov
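For orientation, the caller-side pattern that the server code below migrates to looks roughly as follows. This is a minimal sketch assembled only from the calls visible in this diff (`gpt_sampling_params`, `llama_sampling_init`/`_accept`/`_sample`/`_free`); the helper function name, the prompt/decoding plumbing, and the logits-index handling are illustrative assumptions, not part of the commit.

```cpp
// Sketch only: the call pattern implied by this commit; decoding details are elided.
#include <vector>

#include "common.h" // assumed home of gpt_sampling_params, as in the server code below
#include "llama.h"

static void sample_n_tokens(llama_context * ctx, llama_model * model,
                            const std::vector<llama_token> & prompt_tokens, int n_predict) {
    gpt_sampling_params sparams; // replaces the old llama_sampling_params

    // the sampler is now created from the model instead of wrapping a llama_sampling_context
    llama_sampling * smpl = llama_sampling_init(model, sparams);

    // push the prompt into the sampler without applying the grammar,
    // mirroring the slot warm-up loop in the diff
    for (const llama_token tok : prompt_tokens) {
        llama_sampling_accept(smpl, tok, /*apply_grammar=*/false); // the llama_context argument is gone
    }

    for (int i = 0; i < n_predict; ++i) {
        // ... llama_decode() the current batch here ...

        // idx selects which batch position's logits to sample from
        // (the server passes slot.i_batch - i); 0 is a placeholder for a single-token batch
        const llama_token id = llama_sampling_sample(smpl, ctx, /*idx=*/0); // the cfg-context argument is gone

        llama_sampling_accept(smpl, id, /*apply_grammar=*/true);

        // ... detokenize `id`, check stop conditions, append it to the next batch ...
    }

    llama_sampling_free(smpl); // frees the llama_sampling object, not a llama_sampling_context
}
```

The server changes below follow this same shape: each slot now owns a `llama_sampling *` (`slot.smpl`) in place of the old `llama_sampling_context *`.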
@@ -470,8 +470,6 @@ node index.js
 
     `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
 
-    `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
-
     `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
 
     `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
@@ -724,7 +722,6 @@ Example:
             "stopping_word": ""
         },
         "penalize_nl": true,
-        "penalty_prompt_tokens": [],
         "presence_penalty": 0.0,
         "prompt": "Say hello to llama.cpp",
         "repeat_last_n": 64,
@@ -748,8 +745,7 @@ Example:
         "tfs_z": 1.0,
         "top_k": 40,
         "top_p": 0.949999988079071,
-        "typical_p": 1.0,
-        "use_penalty_prompt_tokens": false
+        "typical_p": 1.0
     }
 ]
 ```
@@ -3,7 +3,6 @@
 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -173,11 +172,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
 
+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling * smpl = nullptr;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -636,8 +637,8 @@ struct server_context {
 
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                llama_sampling_free(slot.smpl);
             }
         }
 
@@ -864,8 +865,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -882,7 +883,7 @@ struct server_context {
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
-        slot.sparams.typical_p         = json_value(data, "typical_p",         default_sparams.typical_p);
+        slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
         slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -904,7 +905,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema                = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar       = json_schema_to_grammar(schema);
@@ -954,56 +956,11 @@ struct server_context {
             }
         }
 
-        // penalize user-provided tokens
-        {
-            slot.sparams.penalty_prompt_tokens.clear();
-            slot.sparams.use_penalty_prompt_tokens = false;
-
-            const auto & penalty_prompt = data.find("penalty_prompt");
-
-            if (penalty_prompt != data.end()) {
-                if (penalty_prompt->is_string()) {
-                    const auto penalty_prompt_string = penalty_prompt->get<std::string>();
-                    slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
-                    if (slot.params.n_predict > 0) {
-                        slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens",  slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-                else if (penalty_prompt->is_array()) {
-                    const auto n_tokens = penalty_prompt->size();
-                    slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
-                    const int n_vocab = llama_n_vocab(model);
-                    for (const auto & penalty_token : *penalty_prompt) {
-                        if (penalty_token.is_number_integer()) {
-                            const auto tok = penalty_token.get<llama_token>();
-                            if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.penalty_prompt_tokens.push_back(tok);
-                            }
-                        }
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens",  slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-            }
-        }
-
         {
             slot.sparams.logit_bias.clear();
 
             if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }
 
             const auto & logit_bias = data.find("logit_bias");
@@ -1024,12 +981,12 @@ struct server_context {
                         if (el[0].is_number_integer()) {
                             llama_token tok = el[0].get<llama_token>();
                             if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.logit_bias[tok] = bias;
+                                slot.sparams.logit_bias.push_back({tok, bias});
                             }
                         } else if (el[0].is_string()) {
                             auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                             for (auto tok : toks) {
-                                slot.sparams.logit_bias[tok] = bias;
+                                slot.sparams.logit_bias.push_back({tok, bias});
                             }
                         }
                     }
@@ -1051,26 +1008,27 @@ struct server_context {
         }
 
         {
-            const auto & samplers_sequence = data.find("samplers");
-            if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+            const auto & samplers = data.find("samplers");
+            if (samplers != data.end() && samplers->is_array()) {
                 std::vector<std::string> sampler_names;
-                for (const auto & sampler_name : *samplers_sequence) {
+                for (const auto & sampler_name : *samplers) {
                     if (sampler_name.is_string()) {
                         sampler_names.emplace_back(sampler_name);
                     }
                 }
-                slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
+                slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false);
             } else {
-                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+                slot.sparams.samplers = default_sparams.samplers;
             }
         }
 
         {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                llama_sampling_free(slot.smpl);
             }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
-            if (slot.ctx_sampling == nullptr) {
+
+            slot.smpl = llama_sampling_init(model, slot.sparams);
+            if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
@@ -1159,11 +1117,6 @@ struct server_context {
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
-        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
-            // we can change penalty_prompt_tokens because it is always created from scratch each request
-            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
-        }
-
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = false;
         for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1281,13 +1234,10 @@ struct server_context {
     }
 
     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias   =             slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
-        std::vector<std::string> samplers_sequence;
-        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
-        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
-            samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+        std::vector<std::string> samplers;
+        samplers.reserve(slot.sparams.samplers.size());
+        for (const auto & sampler : slot.sparams.samplers) {
+            samplers.emplace_back(llama_sampling_type_to_str(sampler));
         }
 
         return json {
@@ -1302,13 +1252,11 @@ struct server_context {
             {"top_p",                     slot.sparams.top_p},
             {"min_p",                     slot.sparams.min_p},
             {"tfs_z",                     slot.sparams.tfs_z},
-            {"typical_p",                 slot.sparams.typical_p},
+            {"typical_p",                 slot.sparams.typ_p},
             {"repeat_last_n",             slot.sparams.penalty_last_n},
             {"repeat_penalty",            slot.sparams.penalty_repeat},
             {"presence_penalty",          slot.sparams.penalty_present},
             {"frequency_penalty",         slot.sparams.penalty_freq},
-            {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
-            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
             {"mirostat",                  slot.sparams.mirostat},
             {"mirostat_tau",              slot.sparams.mirostat_tau},
             {"mirostat_eta",              slot.sparams.mirostat_eta},
@@ -1317,13 +1265,13 @@ struct server_context {
             {"max_tokens",                slot.params.n_predict}, // User configured n_predict
             {"n_keep",                    slot.params.n_keep},
             {"n_discard",                 slot.params.n_discard},
-            {"ignore_eos",                ignore_eos},
+            {"ignore_eos",                slot.sparams.ignore_eos},
             {"stream",                    slot.params.stream},
-            {"logit_bias",                slot.sparams.logit_bias},
+          //{"logit_bias",                slot.sparams.logit_bias},
             {"n_probs",                   slot.sparams.n_probs},
             {"min_keep",                  slot.sparams.min_keep},
             {"grammar",                   slot.sparams.grammar},
-            {"samplers",                  samplers_sequence}
+            {"samplers",                  samplers},
         };
     }
 
@@ -2139,7 +2087,7 @@ struct server_context {
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                             }
 
-                            llama_sampling_reset(slot.ctx_sampling);
+                            llama_sampling_reset(slot.smpl);
 
                             if (!slot.params.cache_prompt) {
                                 slot.n_past_se = 0;
@@ -2152,7 +2100,7 @@ struct server_context {
 
                                 // push the prompt into the sampling context (do not apply grammar)
                                 for (int i = 0; i < slot.n_past; ++i) {
-                                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                                    llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false);
                                 }
                             }
                         }
@@ -2205,7 +2153,7 @@ struct server_context {
                         slot.n_past_se = 0;
                         slot.ga_i = 0;
                         // TODO: is the system prompt ever in the sampling context?
-                        llama_sampling_reset(slot.ctx_sampling);
+                        llama_sampling_reset(slot.smpl);
                     }
 
                     // remove the non-common part from the cache
@@ -2382,9 +2330,9 @@ struct server_context {
                 }
 
                 completion_token_output result;
-                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+                llama_sampling_accept(slot.smpl, id, true);
 
                 slot.n_decoded += 1;
                 if (slot.n_decoded == 1) {
@@ -2393,34 +2341,17 @@ struct server_context {
                     metrics.on_prompt_eval(slot);
                 }
 
-                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
 
-                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-                if (n_probs > 0) {
-                    const size_t n_valid = slot.ctx_sampling->n_valid;
+                const auto * cur_p = llama_sampling_get_candidates(slot.smpl);
 
-                    // Make sure at least n_probs top tokens are at the front of the vector:
-                    if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                    }
-
-                    if (slot.sparams.temp == 0.0f) {
-                        // With greedy sampling the probabilities have possibly not been calculated.
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i == 0 ? 1.0f : 0.0f
-                            });
-                        }
-                    } else {
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                            });
-                        }
-                    }
+                // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643
+                //       fix if necessary
+                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
                 }
 
                 if (!process_token(result, slot)) {
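A note on the TODO in the final hunk: the new loop runs `i` up to `slot.sparams.n_probs` and only guards the probability with `i >= cur_p->size`, while `cur_p->data[i].id` is still read unconditionally, so it can index past the candidate array when fewer than `n_probs` candidates remain. A bounds-clamped variant would look like the sketch below; this is illustrative only, assumes the same surrounding context as the hunk, and is not part of the commit.

```cpp
// Illustrative only: clamp the reported probabilities to the candidates that actually exist.
const auto * cur_p   = llama_sampling_get_candidates(slot.smpl);
const size_t n_probs = std::min((size_t) slot.sparams.n_probs, cur_p->size);

for (size_t i = 0; i < n_probs; ++i) {
    result.probs.push_back({
        cur_p->data[i].id,
        cur_p->data[i].p,
    });
}
```

Depending on the desired behavior, one could instead keep emitting zero-probability entries for the filtered-out positions, as the pre-refactor code did.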