cont : fix speculative decoding initialization

This commit is contained in:
Georgi Gerganov
2025-10-23 14:33:53 +03:00
parent 4dcf0a6d02
commit 2ec7cda706

View File

@@ -2379,6 +2379,10 @@ struct server_context {
llama_batch_free(batch); llama_batch_free(batch);
} }
// Context size assigned to a single slot.
// When params_base.kv_unified is set, every slot gets the full n_ctx;
// otherwise n_ctx is divided evenly (integer division) by params_base.n_parallel.
int32_t n_ctx_slot() const {
    if (params_base.kv_unified) {
        return n_ctx;
    }

    return n_ctx / params_base.n_parallel;
}
bool load_model(const common_params & params) { bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.path.c_str()); SRV_INF("loading model '%s'\n", params.model.path.c_str());
@@ -2407,7 +2411,7 @@ struct server_context {
params_dft.devices = params_base.speculative.devices; params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model; params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? n_ctx_slot() : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1; params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k; params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,8 +2499,6 @@ struct server_context {
} }
void init() { void init() {
const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
for (int i = 0; i < params_base.n_parallel; i++) { for (int i = 0; i < params_base.n_parallel; i++) {
@@ -2504,7 +2506,7 @@ struct server_context {
slot.id = i; slot.id = i;
slot.ctx = ctx; slot.ctx = ctx;
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot();
slot.mctx = mctx; slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr; slot.prompt.tokens.has_mtmd = mctx != nullptr;
@@ -2527,7 +2529,7 @@ struct server_context {
} }
} }
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int) { slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task(); queue_tasks.pop_deferred_task();