kv-cache : use ggml_set_rows

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-19 19:26:47 +03:00
parent 1f647b5992
commit 79dac3c861
4 changed files with 89 additions and 18 deletions

View File

@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
// Populate the graph-input tensors for unified KV-cache attention from the
// current micro-batch. Both tensors are optional: each is filled only when it
// was created at graph-build time (see build_attn_inp_kv_unified).
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
// Destination row indices used by the ggml_set_rows-based KV-cache write path
// (presumably one cache slot index per token in ubatch — confirm against
// set_input_kv_idxs).
if (self_kv_idxs) {
mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
}
// KQ attention mask; cparams.causal_attn selects causal vs non-causal masking.
if (self_kq_mask) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kv_idxs) {
mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
}
if (self_kv_idxs_swa) {
mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
}
if (self_kq_mask) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -1192,6 +1204,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
const auto n_kv = mctx_cur->get_n_kv();
inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
ggml_set_input(inp->self_kv_idxs);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
@@ -1224,8 +1239,10 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & kv_idxs = inp->get_kv_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
@@ -1278,8 +1295,10 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1383,8 +1402,8 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
}
const auto & kq_mask = inp->get_kq_mask();
@@ -1419,6 +1438,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
ggml_set_input(inp->self_kv_idxs);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
@@ -1431,6 +1453,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
ggml_set_input(inp->self_kv_idxs_swa);
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);