commit 2eacb4c1bf
parent e17e4b72d1
Author: Georgi Gerganov
Date:   2025-02-19 18:43:49 +02:00

    graph : simplify attention api

    ggml-ci

4 changed files with 47 additions and 75 deletions
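
In short: the two-step attention graph API, where callers first invoked build_attn_kv_store() to write the current K/V into the KV cache and then build_attn_qkv() to compute the attention output, is collapsed into a single build_attn() call that does both. The call-site change (taken from the last file in this diff) looks like:

    // before: two calls per layer
    lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
    cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);

    // after: one call that both stores K/V and computes the attention output
    cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case);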


@@ -2567,55 +2567,13 @@ void llama_context_kv_self::build_attn_inp(
 }
 }

-void llama_context_kv_self::build_attn_kv_store(
-        ggml_context * ctx0,
-         ggml_cgraph * gf,
-         ggml_tensor * k_cur,
-         ggml_tensor * v_cur,
-             int32_t   n_tokens,
-             int64_t   il,
-                bool   worst_case) {
-    const auto & hparams = model.hparams;
-
-    const auto & n_ctx = cparams.n_ctx;
-
-    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    GGML_ASSERT(kv_self.size == n_ctx);
-
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    //cb(k_cache_view, "k_cache_view", il);
-
-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-
-    struct ggml_tensor * v_cache_view = nullptr;
-
-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
-                ( n_ctx)*ggml_element_size(kv_self.v_l[il]),
-                (kv_head)*ggml_element_size(kv_self.v_l[il]));
-
-        v_cur = ggml_transpose(ctx0, v_cur);
-    }
-    //cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-}
-
-ggml_tensor * llama_context_kv_self::build_attn_qkv(
+ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
          ggml_cgraph * gf,
          ggml_tensor * wo,
          ggml_tensor * wo_b,
+         ggml_tensor * k_cur,
+         ggml_tensor * v_cur,
          ggml_tensor * q_cur,
             int32_t   n_tokens,
               float   kq_scale,
@@ -2623,7 +2581,42 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
                bool   worst_case) {
     const auto & hparams = model.hparams;

-    const auto & n_ctx = cparams.n_ctx;
+    const auto & n_ctx = cparams.n_ctx;
+
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    // store to KV cache
+    {
+        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+        GGML_ASSERT(kv_self.size == n_ctx);
+
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
+        //cb(k_cache_view, "k_cache_view", il);
+
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+        struct ggml_tensor * v_cache_view = nullptr;
+
+        if (cparams.flash_attn) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
+                    ( n_ctx)*ggml_element_size(kv_self.v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self.v_l[il]));
+
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        //cb(v_cache_view, "v_cache_view", il);
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }

     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;
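
The "store to KV cache" block above moved verbatim from build_attn_kv_store() into build_attn(). In plain-array terms: K is appended row-wise at row offset kv_head, while V (without flash attention) is scattered into a transposed [n_embd_v_gqa, n_ctx] layout so later reads along the context dimension are contiguous. A self-contained sketch of the two addressing schemes; the function names and the float element type are illustrative, not the ggml implementation:

    #include <cstddef>
    #include <cstring>

    // K cache: [n_ctx][n_embd_k_gqa] row-major, one row per cached token;
    // mirrors ggml_view_1d() with byte offset row_size(n_embd_k_gqa)*kv_head
    void store_k(float * k_cache, const float * k_cur,
                 size_t n_embd_k_gqa, size_t kv_head, size_t n_tokens) {
        std::memcpy(k_cache + kv_head*n_embd_k_gqa, k_cur,
                    n_tokens*n_embd_k_gqa*sizeof(float));
    }

    // V cache (no flash attention): [n_embd_v_gqa][n_ctx], one row per channel;
    // mirrors ggml_view_2d() with row stride n_ctx elements and offset kv_head,
    // combined with the ggml_transpose() of v_cur before the copy
    void store_v_transposed(float * v_cache, const float * v_cur,
                            size_t n_embd_v_gqa, size_t n_ctx,
                            size_t kv_head, size_t n_tokens) {
        for (size_t c = 0; c < n_embd_v_gqa; ++c) {     // channel = cache row
            for (size_t t = 0; t < n_tokens; ++t) {     // token   = cache column
                v_cache[c*n_ctx + kv_head + t] = v_cur[t*n_embd_v_gqa + c];
            }
        }
    }

Transposing at store time lets the later KQV matmul read each V channel contiguously across the whole context; with flash attention the kernel consumes V in its natural per-token layout, which is why that branch uses a plain 1-D view.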
@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
     const int64_t n_head     = hparams.n_head(il);
     const int64_t n_head_kv  = hparams.n_head_kv(il);
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

     struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);
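
Past the cache store, build_attn() proceeds as build_attn_qkv() did: q_cur, reshaped to [n_embd_head, n_head, n_tokens] as elsewhere in llama.cpp, is permuted so that each head becomes a contiguous [n_embd_head, n_tokens] matrix for the KQ matmul. ggml_permute's i-th argument gives the destination position of source axis i; a tiny shape-tracking sketch with hypothetical sizes:

    #include <array>
    #include <cstdio>

    // ggml_permute(ctx, a, p0, p1, p2, p3) sends source axis i to position p_i;
    // this helper only tracks the effect on the shape (a sketch, no ggml needed)
    static std::array<long long, 4> permute_shape(std::array<long long, 4> ne,
                                                  std::array<int, 4> p) {
        std::array<long long, 4> out{};
        for (int i = 0; i < 4; ++i) {
            out[p[i]] = ne[i];
        }
        return out;
    }

    int main() {
        // hypothetical sizes: n_embd_head = 128, n_head = 32, n_tokens = 512
        const auto q = permute_shape({128, 32, 512, 1}, {0, 2, 1, 3});
        // prints [128, 512, 32, 1], i.e. [n_embd_head, n_tokens, n_head, 1]
        std::printf("[%lld, %lld, %lld, %lld]\n", q[0], q[1], q[2], q[3]);
        return 0;
    }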


@@ -376,20 +376,13 @@ public:
                bool   swa,
                bool   worst_case) override;

-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-             ggml_cgraph * gf,
-             ggml_tensor * k_cur,
-             ggml_tensor * v_cur,
-                 int32_t   n_tokens,
-                 int64_t   il,
-                    bool   worst_case) override;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
              ggml_cgraph * gf,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
+             ggml_tensor * k_cur,
+             ggml_tensor * v_cur,
              ggml_tensor * q_cur,
                 int32_t   n_tokens,
                   float   kq_scale,
@@ -443,6 +436,7 @@ protected:
 // a recurrent transformer (i.e. RWKV, Mamba)
 // TODO: temporary reuse of kv_self; in the future, implement a recurrent-specific context with its own cache
 //class llama_context_recurrent : public llama_context {
 class llama_context_recurrent : public llama_context_kv_self {
 public:
     llama_context_recurrent(
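
For orientation, the context hierarchy these headers imply. This is a sketch: llama_context as the base is inferred from the commented-out line above, while the other names appear verbatim in the diff.

    // rough shape of the hierarchy around this change (bodies elided)
    class llama_context { /* base context API */ };

    class llama_context_kv_self : public llama_context {
        /* owns the kv_self cache; implements build_attn() */
    };

    class llama_context_recurrent : public llama_context_kv_self {
        /* RWKV/Mamba; temporarily reuses kv_self, per the TODO above */
    };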


@@ -88,20 +88,13 @@ public:
                bool   swa,
                bool   worst_case) = 0;

-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-             ggml_cgraph * gf,
-             ggml_tensor * k_cur,
-             ggml_tensor * v_cur,
-                 int32_t   n_tokens,
-                 int64_t   il,
-                    bool   worst_case) = 0;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
              ggml_cgraph * gf,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
+             ggml_tensor * k_cur,
+             ggml_tensor * v_cur,
              ggml_tensor * q_cur,
                 int32_t   n_tokens,
                   float   kq_scale,
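
The signature change lands twice in the headers: once as the pure-virtual declaration above, in the abstract graph interface, and once as the override in llama_context_kv_self. A self-contained sketch of that pattern; the enclosing class names here are stand-ins, since the diff does not show them, and the type of il is not visible in these hunks:

    #include <cstdint>

    struct ggml_context;
    struct ggml_cgraph;
    struct ggml_tensor;

    // stand-in for the abstract interface that declares build_attn() = 0
    class graph_builder_i {
    public:
        virtual ~graph_builder_i() = default;

        virtual ggml_tensor * build_attn(
                ggml_context * ctx0,
                ggml_cgraph  * gf,
                ggml_tensor  * wo,
                ggml_tensor  * wo_b,
                ggml_tensor  * k_cur,
                ggml_tensor  * v_cur,
                ggml_tensor  * q_cur,
                int32_t        n_tokens,
                float          kq_scale,
                int            il,         // layer index; exact type assumed
                bool           worst_case) = 0;
    };

    // stand-in for llama_context_kv_self, the one concrete implementation
    class kv_self_context_like : public graph_builder_i {
    public:
        ggml_tensor * build_attn(
                ggml_context * ctx0,
                ggml_cgraph  * gf,
                ggml_tensor  * wo,
                ggml_tensor  * wo_b,
                ggml_tensor  * k_cur,
                ggml_tensor  * v_cur,
                ggml_tensor  * q_cur,
                int32_t        n_tokens,
                float          kq_scale,
                int            il,
                bool           worst_case) override {
            // 1) store k_cur/v_cur into the KV cache (former build_attn_kv_store)
            // 2) build wo*(softmax(kq_scale*q*k^T)*v) + wo_b (former build_attn_qkv)
            return nullptr; // graph construction elided in this sketch
        }
    };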


@@ -4258,13 +4258,7 @@ struct llm_build_context {
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);

-    //build_kv_store(gf, k_cur, v_cur, il);
-    lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
-
-    struct ggml_tensor * cur;
-
-    //cur = build_kqv(gf, wo, wo_b, q_cur, kq_mask, kq_scale, il);
-    cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
+    ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case);

     cb(cur, "kqv_out", il);

     return cur;