server : do not default to multiple slots with speculative decoding (#17017)

* server : do not default to multiple slots with speculative decoding

* cont : fix
Author: Georgi Gerganov
Date: 2025-11-05 14:32:55 +02:00
Committed by: GitHub
Parent: 2f0c2db43e
Commit: 13b339bcd9

2 changed files with 14 additions and 4 deletions

common/common.h

@@ -507,6 +507,10 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
+
+    bool has_speculative() const {
+        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
+    }
 };

 // call once at the start of a program if it uses libcommon
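For reference, the new has_speculative() helper just reports whether a draft model has been configured, either as a local path or as a Hugging Face repo. Below is a minimal standalone sketch of the same predicate; params_stub is an illustrative stand-in, not the real common_params struct.

    #include <cassert>
    #include <string>

    // Simplified stand-in for the relevant fields of common_params.
    struct params_stub {
        struct {
            struct {
                std::string path;    // local draft model file
                std::string hf_repo; // draft model from a Hugging Face repo
            } model;
        } speculative;

        // Mirrors the helper added in this commit: speculative decoding is
        // requested if either source of a draft model is configured.
        bool has_speculative() const {
            return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
        }
    };

    int main() {
        params_stub p;
        assert(!p.has_speculative());               // no draft model configured

        p.speculative.model.path = "draft.gguf";
        assert(p.has_speculative());                // local path counts

        p.speculative.model.path.clear();
        p.speculative.model.hf_repo = "org/draft-model";
        assert(p.has_speculative());                // HF repo counts as well
        return 0;
    }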

tools/server/server.cpp

@@ -2400,7 +2400,7 @@ struct server_context {
     add_bos_token = llama_vocab_get_add_bos(vocab);

-    if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
+    if (params_base.has_speculative()) {
         SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());

         auto params_dft = params_base;
@@ -2476,7 +2476,7 @@ struct server_context {
         SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
     }

-    if (!params_base.speculative.model.path.empty()) {
+    if (params_base.has_speculative()) {
         SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
         return false;
     }
@@ -2520,6 +2520,7 @@ struct server_context {
     if (model_dft) {
         slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

+        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
         slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
         if (slot.ctx_dft == nullptr) {
             SRV_ERR("%s", "failed to create draft context\n");
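As a side note, each slot's draft batch is sized speculative.n_max + 1, presumably to hold the token already sampled from the target model plus up to n_max drafted tokens. A rough sketch of that allocation with the public llama_batch API follows; the hard-coded n_draft_max value is only for illustration, the server takes it from params_base.speculative.n_max.

    #include "llama.h"

    // Illustrative sketch of the per-slot draft batch allocation,
    // mirroring llama_batch_init(params_base.speculative.n_max + 1, 0, 1).
    int main() {
        const int32_t n_draft_max = 16; // stand-in value for illustration only

        // room for the sampled token plus up to n_draft_max drafted tokens;
        // no embeddings; a single sequence per slot
        llama_batch batch_spec = llama_batch_init(n_draft_max + 1, 0, 1);

        // ... the speculative loop would append the sampled and drafted tokens here ...

        llama_batch_free(batch_spec);
        return 0;
    }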
@@ -2825,6 +2826,7 @@ struct server_context {
     }

     // initialize draft batch
+    // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
     if (slot.ctx_dft) {
         llama_batch_free(slot.batch_spec);
@@ -4291,6 +4293,8 @@ struct server_context {
     }

     // do speculative decoding
+    // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+    //       perform the speculative drafting for all sequences at the same time in a single batch
     for (auto & slot : slots) {
         if (!slot.is_processing() || !slot.can_speculate()) {
             continue;
@@ -4445,8 +4449,10 @@ int main(int argc, char ** argv) {
     // TODO: should we have a separate n_parallel parameter for the server?
     //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
-    if (params.n_parallel == 1 && params.kv_unified == false) {
-        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+    // TODO: this is a common configuration that is suitable for most local use cases
+    //       however, overriding the parameters is a bit confusing - figure out something more intuitive
+    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
         params.n_parallel = 4;
         params.kv_unified = true;
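Putting the pieces together: the server now skips its 4-slot / unified-KV default whenever a draft model is configured, since each slot currently creates its own draft context for speculative decoding. A compact sketch of the resulting startup decision is shown below; server_params_stub and apply_server_defaults are illustrative names, not the server's actual code.

    #include <cstdio>
    #include <string>

    // Simplified stand-ins for the server parameters involved in the new default.
    struct server_params_stub {
        int  n_parallel = 1;        // number of slots
        bool kv_unified = false;    // unified KV cache across sequences
        std::string draft_model;    // empty if speculative decoding is not configured

        bool has_speculative() const { return !draft_model.empty(); }
    };

    // Mirrors the updated logic: only apply the 4-slot/unified-KV default when the
    // user did not ask for parallelism, did not ask for a unified cache, and is
    // not using speculative decoding.
    void apply_server_defaults(server_params_stub & params) {
        if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
            std::printf("setting n_parallel = 4 and kv_unified = true\n");
            params.n_parallel = 4;
            params.kv_unified = true;
        }
    }

    int main() {
        server_params_stub plain;           // no draft model -> default applies
        apply_server_defaults(plain);       // n_parallel becomes 4, kv_unified true

        server_params_stub spec;
        spec.draft_model = "draft.gguf";    // draft model configured -> keep one slot
        apply_server_defaults(spec);        // no change: n_parallel stays 1
        return 0;
    }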