Merge branch 'master' into compilade/refactor-kv-cache

This commit is contained in:
Francis Couture-Harpin
2025-07-03 16:03:56 -04:00
89 changed files with 895 additions and 4844 deletions

View File

@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -332,7 +335,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_one::set_input(const llama_ubatch *) {
void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1155,8 +1159,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
const auto n_kv = mctx_cur->get_n_kv();
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1187,8 +1193,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
@@ -1247,11 +1256,15 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache
if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}
if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1343,8 +1356,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1355,8 +1370,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;