diff --git a/common/common.h b/common/common.h
index 78c568a7bc..54b7849b17 100644
--- a/common/common.h
+++ b/common/common.h
@@ -507,6 +507,10 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
+
+    bool has_speculative() const {
+        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
+    }
 };
 
 // call once at the start of a program if it uses libcommon
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 678aad93b8..f5089bef24 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2400,7 +2400,7 @@ struct server_context {
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
 
-        if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
+        if (params_base.has_speculative()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
 
             auto params_dft = params_base;
@@ -2476,7 +2476,7 @@ struct server_context {
                 SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
             }
 
-            if (!params_base.speculative.model.path.empty()) {
+            if (params_base.has_speculative()) {
                 SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
                 return false;
             }
@@ -2520,6 +2520,7 @@ struct server_context {
             if (model_dft) {
                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
+                // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
                     SRV_ERR("%s", "failed to create draft context\n");
@@ -2825,6 +2826,7 @@ struct server_context {
         }
 
         // initialize draft batch
+        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
         if (slot.ctx_dft) {
             llama_batch_free(slot.batch_spec);
 
@@ -4291,6 +4293,8 @@ struct server_context {
         }
         // do speculative decoding
+        // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+        //       perform the speculative drafting for all sequences at the same time in a single batch
         for (auto & slot : slots) {
             if (!slot.is_processing() || !slot.can_speculate()) {
                 continue;
             }
@@ -4445,8 +4449,10 @@ int main(int argc, char ** argv) {
 
     // TODO: should we have a separate n_parallel parameter for the server?
    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
-    if (params.n_parallel == 1 && params.kv_unified == false) {
-        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+    // TODO: this is a common configuration that is suitable for most local use cases
+    //       however, overriding the parameters is a bit confusing - figure out something more intuitive
+    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
 
         params.n_parallel = 4;
         params.kv_unified = true;
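
Note: the [TAG_SERVER_SPEC_REWORK] TODOs above all mark the same planned refactor: instead of every slot owning its own draft context (slot.ctx_dft) and draft batch (slot.batch_spec), the server would keep a single shared draft llama_context and perform the drafting step for all active slots in one batch, one sequence id per slot. Below is a minimal sketch of that direction, using the existing llama_batch helpers from common.h; the draft_slot struct, its fields, and the draft_step function are hypothetical stand-ins for the server's per-slot state, not the actual server code.

#include "llama.h"
#include "common.h"

#include <vector>

// hypothetical per-slot drafting state - a stand-in for server_slot
struct draft_slot {
    llama_seq_id seq_id;   // unique sequence id per slot
    llama_token  last_tok; // last token accepted by the target model
    llama_pos    n_past;   // current position in the shared draft context
    bool         active;   // slot.is_processing() && slot.can_speculate()
};

// one drafting step for all active slots via a single llama_decode call;
// ctx_dft is the shared draft context, created once with llama_init_from_model
static void draft_step(llama_context * ctx_dft, std::vector<draft_slot> & slots) {
    llama_batch batch_dft = llama_batch_init((int32_t) slots.size(), 0, (int32_t) slots.size());

    common_batch_clear(batch_dft);
    for (const auto & slot : slots) {
        if (!slot.active) {
            continue;
        }
        // one row per slot: extend the slot's draft sequence by its last token,
        // requesting logits so a draft token can be sampled for that sequence
        common_batch_add(batch_dft, slot.last_tok, slot.n_past, { slot.seq_id }, true);
    }

    if (batch_dft.n_tokens > 0 && llama_decode(ctx_dft, batch_dft) == 0) {
        // sample one draft token per sequence from the logits here; repeating
        // this step up to speculative.n_max times builds a draft per slot,
        // which the target model then verifies as the per-slot code does today
    }

    llama_batch_free(batch_dft);
}

Under this (assumed) design, the per-slot batch_spec of size n_max + 1 would collapse into one batch sized by the number of slots, and slot isolation would fall out of the per-row seq_id tags rather than separate draft contexts.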