From 2eacb4c1bfe01839f579e8aac3068f8758c26874 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 19 Feb 2025 18:43:49 +0200
Subject: [PATCH] graph : simplify attention api

ggml-ci
---
 src/llama-context.cpp | 87 +++++++++++++++++++------------------
 src/llama-context.h   | 14 ++-----
 src/llama-graph.h     | 13 ++-----
 src/llama-model.cpp   |  8 +---
 4 files changed, 47 insertions(+), 75 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index b571c9343f..818702143e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2567,55 +2567,13 @@ void llama_context_kv_self::build_attn_inp(
     }
 }
 
-void llama_context_kv_self::build_attn_kv_store(
-        ggml_context * ctx0,
-        ggml_cgraph * gf,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        int32_t n_tokens,
-        int64_t il,
-        bool worst_case) {
-    const auto & hparams = model.hparams;
-
-    const auto & n_ctx = cparams.n_ctx;
-
-    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    GGML_ASSERT(kv_self.size == n_ctx);
-
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    //cb(k_cache_view, "k_cache_view", il);
-
-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-
-    struct ggml_tensor * v_cache_view = nullptr;
-
-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
-                ( n_ctx)*ggml_element_size(kv_self.v_l[il]),
-                (kv_head)*ggml_element_size(kv_self.v_l[il]));
-
-        v_cur = ggml_transpose(ctx0, v_cur);
-    }
-    //cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-}
-
-ggml_tensor * llama_context_kv_self::build_attn_qkv(
+ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
         ggml_tensor * q_cur,
         int32_t n_tokens,
         float kq_scale,
@@ -2623,7 +2581,42 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         bool worst_case) {
     const auto & hparams = model.hparams;
 
-    const auto & n_ctx = cparams.n_ctx;
+    const auto & n_ctx = cparams.n_ctx;
+
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    // store to KV cache
+    {
+        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+        GGML_ASSERT(kv_self.size == n_ctx);
+
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
+        //cb(k_cache_view, "k_cache_view", il);
+
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+        struct ggml_tensor * v_cache_view = nullptr;
+
+        if (cparams.flash_attn) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
+                    ( n_ctx)*ggml_element_size(kv_self.v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self.v_l[il]));
+
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        //cb(v_cache_view, "v_cache_view", il);
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }
+
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;
@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
     const int64_t n_head    = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
     struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);
diff --git a/src/llama-context.h b/src/llama-context.h
index 133eb8b36f..fb241adf1d 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -376,20 +376,13 @@ public:
             bool swa,
             bool worst_case) override;
 
-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * gf,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) override;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
             ggml_tensor * q_cur,
             int32_t n_tokens,
             float kq_scale,
@@ -443,6 +436,7 @@ protected:
 // a recurrent transformer (ie.e RWKV, Mamba)
 // TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
+//class llama_context_recurrent : public llama_context {
 class llama_context_recurrent : public llama_context_kv_self {
 public:
     llama_context_recurrent(
diff --git a/src/llama-graph.h b/src/llama-graph.h
index b9456e3d1c..9adfc6f231 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -88,20 +88,13 @@ public:
             bool swa,
             bool worst_case) = 0;
 
-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * gf,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
+           ggml_tensor * k_cur,
+           ggml_tensor * v_cur,
            ggml_tensor * q_cur,
            int32_t n_tokens,
            float kq_scale,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 09fd63f61c..a22720c3ad 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4258,13 +4258,7 @@ struct llm_build_context {
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);
 
-        //build_kv_store(gf, k_cur, v_cur, il);
-        lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
-
-        struct ggml_tensor * cur;
-
-        //cur = build_kqv(gf, wo, wo_b, q_cur, kq_mask, kq_scale, il);
-        cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case);
         cb(cur, "kqv_out", il);
 
         return cur;
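
Reviewer note (not part of the patch): the call-site hunk in llama-model.cpp above is the whole user-facing effect of this refactor. The snippet below only restates that hunk for convenience, with the old two-step hooks kept as comments for comparison; all names (lgf, ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case, cb) are the existing ones from llm_build_context.

    // before: one hook to append K/V to the cache, a second one to compute attention
    //lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
    //cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);

    // after: a single build_attn() call that writes k_cur/v_cur into the KV cache
    // and returns the attention result with the output projection (wo, wo_b) applied
    ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur,
                                        n_tokens, kq_scale, il, worst_case);
    cb(cur, "kqv_out", il);

Apart from the rename, the implementation in llama-context.cpp is unchanged: the former body of build_attn_kv_store now lives in the "store to KV cache" block at the top of build_attn.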