From 4a1054b55259cb3d43929121294e0ac28a632435 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 24 Feb 2025 11:18:40 +0200
Subject: [PATCH] context : reuse built_attn_mha

ggml-ci
---
 src/llama-context.cpp | 214 +++++++++++++-----------------------------
 src/llama-context.h   |  17 ++--
 src/llama-graph.cpp   |   6 --
 src/llama-graph.h     |   3 -
 src/llama-model.cpp   |  36 ++++++-
 5 files changed, 109 insertions(+), 167 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f73d4b9bf4..e05afb5646 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1721,50 +1721,67 @@ void llama_context::build_attn_inp(
 ggml_tensor * llama_context::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
-    const auto & hparams = model.hparams;
-
-    const auto & n_ctx = cparams.n_ctx;
-
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    GGML_UNUSED(il);

     const auto & kq_mask = inp.kq_mask_cnv;

-    const int64_t n_head = hparams.n_head(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    const auto n_kv = n_tokens;
-
-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3));
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
     //cb(k, "k", il);

+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(v, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    return cur;
+}
+
+ggml_tensor * llama_context::build_attn_mha(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        bool v_trans,
+        float kq_scale) {
+    const auto & hparams = model.hparams;
+
+    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    //const int64_t n_head = hparams.n_head(il);
+    //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    //const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head = q->ne[2];
+    const auto n_kv = k->ne[1];
+
     struct ggml_tensor * cur;

-    //if (cparams.flash_attn) {
-    if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);

-        GGML_ASSERT(kq_b == nullptr);
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
-        v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
@@ -1774,7 +1791,6 @@ ggml_tensor * llama_context::build_attn(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1802,22 +1818,17 @@ ggml_tensor * llama_context::build_attn(
         }

         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);

-        // split cached v into n_head heads
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens)));
-
-        v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv);
-        //cb(v, "v", il);
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }

         struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);

         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);

         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1827,18 +1838,6 @@ ggml_tensor * llama_context::build_attn(

     ggml_build_forward_expand(gf, cur);

-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
     return cur;
 }

@@ -3274,13 +3273,10 @@ void llama_context_kv_self::build_attn_inp(
 ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     const auto & hparams = model.hparams;
@@ -3290,6 +3286,10 @@ ggml_tensor * llama_context_kv_self::build_attn(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
     // store to KV cache
     {
         GGML_ASSERT(!kv_self.recurrent);
@@ -3308,7 +3308,7 @@ ggml_tensor * llama_context_kv_self::build_attn(

         struct ggml_tensor * v_cache_view = nullptr;

-        if (cparams.flash_attn) {
+        if (!v_trans) {
            v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
         } else {
             // note: the V cache is transposed when not using flash attention
@@ -3351,16 +3351,15 @@ ggml_tensor * llama_context_kv_self::build_attn(

     const auto n_kv = kv_self.n;

-    const int64_t n_head = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);

     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k =
+    ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
                 n_embd_head_k, n_kv, n_head_kv,
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
@@ -3368,100 +3367,19 @@ ggml_tensor * llama_context_kv_self::build_attn(
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                 0);
     //cb(k, "k", il);

-    struct ggml_tensor * cur;
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                0);

-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        GGML_ASSERT(kq_b == nullptr);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                    0);
-        //cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh (ctx0, kq);
-            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        if (kq_b) {
-            kq = ggml_add(ctx0, kq, kq_b);
-        }
-
-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        //cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
-
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-        }
-    }
-
-    ggml_build_forward_expand(gf, cur);
-
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
+    struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);

     return cur;
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 2945cbabe4..5b63b3b06d 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -261,17 +261,25 @@ public:
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;

 protected:
+    virtual ggml_tensor * build_attn_mha(
+            ggml_context * ctx0,
+            ggml_cgraph * gf,
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            bool v_trans,
+            float kq_scale);
+
     virtual ggml_tensor * build_inp_self_k_shift(
             ggml_context * ctx0);

@@ -472,13 +480,10 @@ public:
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c058ee2498..99eb326205 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -7,24 +7,18 @@ llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
 ggml_tensor * llama_graph_i::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     GGML_UNUSED(ctx0);
     GGML_UNUSED(gf);
-    GGML_UNUSED(wo);
-    GGML_UNUSED(wo_b);
     GGML_UNUSED(q_cur);
     GGML_UNUSED(k_cur);
     GGML_UNUSED(v_cur);
     GGML_UNUSED(kq_b);
-    GGML_UNUSED(n_tokens);
     GGML_UNUSED(kq_scale);
     GGML_UNUSED(il);

diff --git a/src/llama-graph.h b/src/llama-graph.h
index ee56f08396..c84c254934 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -107,13 +107,10 @@ public:
     virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il);

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 1e34ed8038..e8057f4687 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4265,18 +4265,32 @@ struct llm_build_context {
             struct ggml_tensor * q_cur,
             struct ggml_tensor * k_cur,
             struct ggml_tensor * v_cur,
-            int32_t n_tokens,
+            int32_t n_tokens, // TODO: remove
             float kq_scale,
             int il) {
+        GGML_UNUSED(n_tokens);
+
         // these nodes are added to the graph together so that they are not reordered
         // by doing so, the number of splits in the graph is reduced
         ggml_build_forward_expand(gf, q_cur);
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);

-        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, nullptr, n_tokens, kq_scale, il);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il);
         cb(cur, "kqv_out", il);

+        if (wo) {
+            cur = lgf->build_lora_mm(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            //cb(cur, "kqv_wo", il);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
         return cur;
     }

@@ -4288,18 +4302,32 @@ struct llm_build_context {
             struct ggml_tensor * k_cur,
             struct ggml_tensor * v_cur,
             struct ggml_tensor * kq_b,
-            int32_t n_tokens,
+            int32_t n_tokens, // TODO: remove
             float kq_scale,
             int il) {
+        GGML_UNUSED(n_tokens);
+
         // these nodes are added to the graph together so that they are not reordered
         // by doing so, the number of splits in the graph is reduced
         ggml_build_forward_expand(gf, q_cur);
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);

-        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, n_tokens, kq_scale, il);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il);
         cb(cur, "kqv_out", il);

+        if (wo) {
+            cur = lgf->build_lora_mm(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            //cb(cur, "kqv_wo", il);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
         return cur;
     }
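
Note (not part of the commit above): the refactor funnels both the KV-cache path and the no-cache path through the single build_attn_mha helper, and the v_trans flag tells that helper whether V arrives in its natural [n_embd_head_v, n_kv] layout or in the transposed layout used by the V cache when flash attention is disabled. The sketch below is a minimal, self-contained C++ illustration of why the two layouts are interchangeable; it does not use ggml, and every name in it (Mat, attn_v_rows, attn_v_cols) is purely illustrative rather than taken from llama.cpp. It computes single-head scaled dot-product attention from both V layouts and checks that the results agree, which is the equivalence build_attn_mha relies on when it reconciles the layouts internally.

// illustration only: attention computed from a row-major V and from a transposed V
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>; // row-major: Mat[i][j]

// numerically stable softmax over one row of scores
static void softmax_inplace(std::vector<float> & row) {
    float mx = row[0];
    for (float x : row) mx = std::max(mx, x);
    float sum = 0.0f;
    for (float & x : row) { x = std::exp(x - mx); sum += x; }
    for (float & x : row) x /= sum;
}

// attention with V stored as [n_kv, d_v] (the layout used when v_trans == false)
static Mat attn_v_rows(const Mat & Q, const Mat & K, const Mat & V, float scale) {
    const size_t n_tok = Q.size(), n_kv = K.size(), d_v = V[0].size();
    Mat out(n_tok, std::vector<float>(d_v, 0.0f));
    for (size_t i = 0; i < n_tok; ++i) {
        std::vector<float> p(n_kv, 0.0f);
        for (size_t j = 0; j < n_kv; ++j)
            for (size_t c = 0; c < Q[i].size(); ++c) p[j] += Q[i][c] * K[j][c];
        for (float & x : p) x *= scale;
        softmax_inplace(p);
        for (size_t j = 0; j < n_kv; ++j)
            for (size_t c = 0; c < d_v; ++c) out[i][c] += p[j] * V[j][c];
    }
    return out;
}

// attention with V stored transposed as [d_v, n_kv] (the layout used when v_trans == true)
static Mat attn_v_cols(const Mat & Q, const Mat & K, const Mat & Vt, float scale) {
    const size_t n_tok = Q.size(), n_kv = K.size(), d_v = Vt.size();
    Mat out(n_tok, std::vector<float>(d_v, 0.0f));
    for (size_t i = 0; i < n_tok; ++i) {
        std::vector<float> p(n_kv, 0.0f);
        for (size_t j = 0; j < n_kv; ++j)
            for (size_t c = 0; c < Q[i].size(); ++c) p[j] += Q[i][c] * K[j][c];
        for (float & x : p) x *= scale;
        softmax_inplace(p);
        for (size_t c = 0; c < d_v; ++c)
            for (size_t j = 0; j < n_kv; ++j) out[i][c] += p[j] * Vt[c][j];
    }
    return out;
}

int main() {
    const float scale = 1.0f / std::sqrt(2.0f); // 1/sqrt(d_k) with d_k = 2
    const Mat Q = {{0.1f, 0.9f}, {0.7f, 0.3f}};               // [n_tokens, d_k]
    const Mat K = {{0.2f, 0.8f}, {0.5f, 0.5f}, {0.9f, 0.1f}}; // [n_kv, d_k]
    const Mat V = {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}; // [n_kv, d_v]
    Mat Vt(2, std::vector<float>(3)); // transposed copy of V: [d_v, n_kv]
    for (size_t j = 0; j < 3; ++j)
        for (size_t c = 0; c < 2; ++c) Vt[c][j] = V[j][c];

    const Mat a = attn_v_rows(Q, K, V,  scale);
    const Mat b = attn_v_cols(Q, K, Vt, scale);
    for (size_t i = 0; i < a.size(); ++i)
        for (size_t c = 0; c < a[i].size(); ++c)
            assert(std::fabs(a[i][c] - b[i][c]) < 1e-6f);
    std::printf("both V layouts give identical attention output\n");
    return 0;
}

Compiled with, for example, g++ -std=c++17 and run, the program prints the confirmation message. In the patch itself the same reconciliation is expressed in ggml terms: the flash-attention branch applies ggml_transpose when v_trans is set, while the soft-max branch applies ggml_cont(ggml_transpose(v)) when it is not, before multiplying V by softmax(K*Q).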