Georgi Gerganov
2025-08-28 13:55:21 +03:00
parent 8a4280ce43
commit 4317d5abf5
2 changed files with 13 additions and 12 deletions


@@ -1338,10 +1338,6 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
// when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
gf_res_prev->reset();
// store the n_outputs as it is, and restore it afterwards
// TODO: not sure if needed, might simplify in the future by removing this
const auto save_n_outputs = this->n_outputs;
this->n_outputs = n_outputs;
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
@@ -1355,8 +1351,6 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
auto * gf = model.build_graph(gparams);
this->n_outputs = save_n_outputs;
// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
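
For reference, the lines removed above implemented a plain save/override/restore of the n_outputs member around the graph build (the TODO already suggested dropping it). A minimal sketch of that removed pattern, using the names from the hunks above:

    // removed pattern: temporarily override the member while building the
    // reserve graph, then put the old value back afterwards
    const auto save_n_outputs = this->n_outputs;
    this->n_outputs = n_outputs;            // size the dummy graph for the requested outputs

    auto * gf = model.build_graph(gparams); // built with the overridden value

    this->n_outputs = save_n_outputs;       // restore the caller-visible value

After this commit, graph_reserve() no longer touches this->n_outputs at all.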


@@ -1339,8 +1339,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = ubatch.n_seqs_unq;
// note: there is no KV cache, so the mask is square with size n_tokens/n_stream
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens/n_stream, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
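
To make the new mask shape concrete, here is a small standalone sketch of the resulting dimensions. The token/stream counts and the padding value below are illustrative assumptions, and pad_up only mirrors the rounding that GGML_PAD performs:

    #include <cstdint>
    #include <cstdio>

    // round x up to a multiple of n (same rounding as GGML_PAD)
    static int64_t pad_up(int64_t x, int64_t n) { return (x + n - 1) / n * n; }

    int main() {
        // illustrative numbers, not taken from the diff:
        const int64_t n_tokens = 8;  // tokens in the ubatch
        const int64_t n_stream = 2;  // ubatch.n_seqs_unq
        const int64_t kq_pad   = 64; // stand-in for GGML_KQ_MASK_PAD

        // after this change: one square mask per stream
        // ne = { n_tokens/n_stream, GGML_PAD(n_tokens/n_stream, kq_pad), 1, n_stream }
        printf("new kq_mask ne = [%lld, %lld, 1, %lld]\n",
               (long long) (n_tokens/n_stream),
               (long long) pad_up(n_tokens/n_stream, kq_pad),
               (long long) n_stream);

        // before: a single mask covering all tokens in the batch
        // ne = { n_tokens, GGML_PAD(n_tokens, kq_pad), 1, 1 }
        printf("old kq_mask ne = [%lld, %lld, 1, 1]\n",
               (long long) n_tokens,
               (long long) pad_up(n_tokens, kq_pad));
        return 0;
    }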
@@ -1370,14 +1373,18 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
if (ubatch.equal_seqs()) {
    GGML_ASSERT(k_cur->ne[2] % ubatch.n_seqs_unq == 0);
    GGML_ASSERT(k_cur->ne[3] == 1);
    k = ggml_reshape_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2]/ubatch.n_seqs_unq, ubatch.n_seqs_unq);
    v = ggml_reshape_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2]/ubatch.n_seqs_unq, ubatch.n_seqs_unq);
}
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
cb(cur, "kqv_out", il);