From 492f628c586ebc1932a30fb9823cc0f788f6eefc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 23 Oct 2025 14:51:26 +0300
Subject: [PATCH] context : fix n_ctx_per_seq computation

---
 src/llama-context.cpp | 14 ++++++--------
 src/llama-model.cpp   |  2 +-
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f6192a36e0..949d157c86 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -112,11 +112,9 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq());
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
@@ -125,14 +123,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
+    if (n_ctx_per_seq() < hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
+    if (n_ctx_per_seq() > hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -454,7 +452,7 @@ uint32_t llama_context::n_ctx() const {
 }
 
 uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+    return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
 }
 
 uint32_t llama_context::n_batch() const {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 04239181c7..36dcdd33ed 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -6712,7 +6712,7 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_per_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
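
Not part of the patch: a minimal standalone sketch of the corrected per-sequence context computation, using only the fields that appear in the diff (cparams.n_ctx, cparams.n_seq_max, cparams.kv_unified) but with a hypothetical free-standing helper instead of the llama_context member. The idea, as reflected in the changed lines above, is that with a unified KV cache all sequences share the full context window, so dividing n_ctx by n_seq_max only applies to the split (per-sequence) case.

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the post-fix logic of
// llama_context::n_ctx_per_seq(): with a unified KV cache every sequence
// may use the whole context window; otherwise the window is split evenly
// across n_seq_max sequences.
static uint32_t n_ctx_per_seq(uint32_t n_ctx, uint32_t n_seq_max, bool kv_unified) {
    return kv_unified ? n_ctx : n_ctx / n_seq_max;
}

int main() {
    // Example: n_ctx = 8192 with 4 sequences.
    printf("split   : %u\n", n_ctx_per_seq(8192, 4, false)); // 2048
    printf("unified : %u\n", n_ctx_per_seq(8192, 4, true));  // 8192
    return 0;
}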