llama : fix fattn reserve call n_seqs parameter (#15699)

ggml-ci
Diego Devesa
2025-08-31 08:47:05 -07:00
committed by GitHub
parent 9777032dcc
commit 274966226f


@@ -281,9 +281,15 @@ llama_context::llama_context(
         }

         cross.v_embd.clear();

+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
         // resolve automatic Flash Attention use
         if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
             if (!gf) {
                 throw std::runtime_error("failed to split graph for Flash Attention check");
             }
@@ -324,11 +330,6 @@ llama_context::llama_context(
         }

         // reserve worst-case graph
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
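
Note on the change: before this commit the Flash Attention auto-detection reserved a graph with the placeholder arguments (1, 1, 0); it now reuses the same worst-case n_seqs and n_outputs as the main worst-case graph reservation, which is why those computations were moved above the check. The following is a minimal, self-contained sketch of how the worst-case values are derived from the context parameters; the cparams_t struct, the example values, and the main() driver are illustrative assumptions, not the real llama.cpp definitions.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical mirror of the fields used in the diff above.
    struct cparams_t {
        bool     kv_unified;
        uint32_t n_seq_max;
        uint32_t n_ctx;
        uint32_t n_ubatch;
    };

    int main() {
        cparams_t cparams = { /*kv_unified=*/false, /*n_seq_max=*/4, /*n_ctx=*/8192, /*n_ubatch=*/512 };

        // Same logic as the diff: with a unified KV cache the worst case is a single
        // sequence; otherwise it is the maximum number of parallel sequences.
        const uint32_t n_seqs   = cparams.kv_unified ? 1 : cparams.n_seq_max;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // These worst-case values (together with n_outputs) are what graph_reserve()
        // now receives for the Flash Attention check, instead of the placeholders (1, 1, 0).
        std::printf("worst-case: n_tokens = %u, n_seqs = %u\n",
                    (unsigned) n_tokens, (unsigned) n_seqs);
        return 0;
    }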