Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-07 09:57:00 +00:00)
server : do not default to multiple slots with speculative decoding (#17017)
* server : do not default to multiple slots with speculative decoding
* cont : fix
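In short, the commit adds a common_params::has_speculative() helper and uses it to skip the server's automatic n_parallel = 4 / kv_unified = true default whenever a draft model is configured. Below is a minimal standalone C++ sketch of that decision logic; the struct and function names (speculative_params, server_params, apply_server_defaults) are simplified stand-ins for illustration, not the actual llama.cpp types.

// Minimal sketch of the new default-override logic (illustrative only;
// the real code lives in common_params and the server's main()).
#include <cstdio>
#include <string>

struct speculative_params {
    std::string model_path;    // local draft model file, if any
    std::string model_hf_repo; // draft model HF repo, if any
};

struct server_params {
    int  n_parallel = 1;
    bool kv_unified = false;
    speculative_params speculative;

    // mirrors common_params::has_speculative() added by this commit
    bool has_speculative() const {
        return !speculative.model_path.empty() || !speculative.model_hf_repo.empty();
    }
};

// mirrors the updated block in main(): the multi-slot default is applied
// only when no draft model is configured
void apply_server_defaults(server_params & params) {
    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
        std::printf("setting n_parallel = 4 and kv_unified = true\n");
        params.n_parallel = 4;
        params.kv_unified = true;
    }
}

int main() {
    server_params plain;                        // no draft model -> defaults to 4 slots
    server_params spec;
    spec.speculative.model_path = "draft.gguf"; // draft model configured -> stays at 1 slot

    apply_server_defaults(plain);
    apply_server_defaults(spec);

    std::printf("plain: n_parallel=%d, spec: n_parallel=%d\n", plain.n_parallel, spec.n_parallel);
    return 0;
}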
@@ -507,6 +507,10 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
+
+    bool has_speculative() const {
+        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
+    }
 };
 
 // call once at the start of a program if it uses libcommon
@@ -2400,7 +2400,7 @@ struct server_context {
 
     add_bos_token = llama_vocab_get_add_bos(vocab);
 
-    if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
+    if (params_base.has_speculative()) {
         SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
 
         auto params_dft = params_base;
@@ -2476,7 +2476,7 @@ struct server_context {
             SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
         }
 
-        if (!params_base.speculative.model.path.empty()) {
+        if (params_base.has_speculative()) {
             SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
             return false;
         }
@@ -2520,6 +2520,7 @@ struct server_context {
             if (model_dft) {
                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
+                // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
                     SRV_ERR("%s", "failed to create draft context\n");
@@ -2825,6 +2826,7 @@ struct server_context {
         }
 
         // initialize draft batch
+        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
         if (slot.ctx_dft) {
             llama_batch_free(slot.batch_spec);
 
@@ -4291,6 +4293,8 @@ struct server_context {
         }
 
         // do speculative decoding
+        // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+        // perform the speculative drafting for all sequences at the same time in a single batch
         for (auto & slot : slots) {
             if (!slot.is_processing() || !slot.can_speculate()) {
                 continue;
@@ -4445,8 +4449,10 @@ int main(int argc, char ** argv) {
 
     // TODO: should we have a separate n_parallel parameter for the server?
     // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
-    if (params.n_parallel == 1 && params.kv_unified == false) {
-        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+    // TODO: this is a common configuration that is suitable for most local use cases
+    // however, overriding the parameters is a bit confusing - figure out something more intuitive
+    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
 
         params.n_parallel = 4;
         params.kv_unified = true;
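A note on the surrounding TODOs and the new log hint, inferred from the diff itself rather than from any external documentation: server.cpp still creates a separate draft context (slot.ctx_dft) and draft batch (slot.batch_spec) for every slot, and the [TAG_SERVER_SPEC_REWORK] markers flag a planned rework toward a single draft llama_context shared across all slots. Until that lands, defaulting to n_parallel = 4 when a draft model is configured would replicate the draft-model state per slot, which is presumably why the automatic n_parallel = 4 / kv_unified = true override now skips configurations where has_speculative() is true. The updated warning also notes that passing -kvu (i.e. setting kv_unified = true explicitly) prevents the override from firing, since the condition requires params.kv_unified == false.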