From 4317d5abf5b554aba4c677441a851111b055b20c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 28 Aug 2025 13:55:21 +0300
Subject: [PATCH] wip

---
 src/llama-context.cpp |  6 ------
 src/llama-graph.cpp   | 19 +++++++++++++------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6b20161a38..0972d33b18 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1338,10 +1338,6 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
     gf_res_prev->reset();
 
-    // store the n_outputs as it is, and restore it afterwards
-    // TODO: not sure if needed, might simplify in the future by removing this
-    const auto save_n_outputs = this->n_outputs;
-
     this->n_outputs = n_outputs;
 
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
@@ -1355,8 +1351,6 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
     auto * gf = model.build_graph(gparams);
 
-    this->n_outputs = save_n_outputs;
-
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1f2fc3ab62..ba56f0c205 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1339,8 +1339,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    const auto n_tokens = ubatch.n_tokens;
+    const auto n_stream = ubatch.n_seqs_unq;
+
+    // note: there is no KV cache, so the mask is square with size n_tokens/n_stream
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens/n_stream, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1370,14 +1373,18 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    // [TAG_NO_CACHE_PAD]
-    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
-    assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
-
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
+    if (ubatch.equal_seqs()) {
+        GGML_ASSERT(k_cur->ne[2] % ubatch.n_seqs_unq == 0);
+        GGML_ASSERT(k_cur->ne[3] == 1);
+
+        k = ggml_reshape_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2]/ubatch.n_seqs_unq, ubatch.n_seqs_unq);
+        v = ggml_reshape_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2]/ubatch.n_seqs_unq, ubatch.n_seqs_unq);
+    }
+
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
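
Outside the patch itself, here is a minimal standalone sketch of what the new ggml_reshape_4d calls do: the K/V tensors built for the no-cache path are split along the token dimension into one stream per unique sequence. The sizes below (2 sequences of 4 tokens, head size 4, 2 heads) are hypothetical values chosen for illustration, not taken from the patch.

// illustrative sketch only: splitting a no-cache K tensor of shape
// [head dim, n_head, n_tokens, 1] into one stream per sequence,
// mirroring the reshape added in build_attn above
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // metadata only, no tensor data is allocated
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_seqs_unq = 2; // stands in for ubatch.n_seqs_unq
    const int64_t n_tokens   = 8; // stands in for ubatch.n_tokens

    // K as produced for the no-cache path: [head dim, n_head, n_tokens, 1]
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 2, n_tokens, 1);

    // split the token dimension into n_seqs_unq streams
    struct ggml_tensor * k_split = ggml_reshape_4d(ctx, k,
            k->ne[0], k->ne[1], k->ne[2]/n_seqs_unq, n_seqs_unq);

    printf("k:       [%lld, %lld, %lld, %lld]\n",
            (long long) k->ne[0], (long long) k->ne[1], (long long) k->ne[2], (long long) k->ne[3]);
    printf("k_split: [%lld, %lld, %lld, %lld]\n",
            (long long) k_split->ne[0], (long long) k_split->ne[1], (long long) k_split->ne[2], (long long) k_split->ne[3]);

    ggml_free(ctx);
    return 0;
}

Since ggml_reshape_4d only creates a view, the split costs no copy; the per-stream KQ mask built in build_attn_inp_no_cache has matching dimensions [n_tokens/n_stream, padded n_tokens/n_stream, 1, n_stream].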