From 372fa3a894757cdd844a27141c6396718fce4f4c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 23 Feb 2025 11:38:59 +0200 Subject: [PATCH] cont : enc should work now, next is dec ggml-ci --- src/llama-context.cpp | 188 +++++++++++++++++++---------- src/llama-context.h | 41 ++++--- src/llama-graph.cpp | 2 + src/llama-graph.h | 5 + src/llama-model.cpp | 274 +++++++++++++++++++++--------------------- 5 files changed, 293 insertions(+), 217 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9b341aa182..d98f4662c2 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -10,21 +10,64 @@ #include #include +// +// helpers +// + +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + + return relative_bucket; +} + // // llama_context // llama_context::llama_context( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype) : llama_graph_i(gtype), model(model) { - LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); + LLAMA_LOG_INFO("%s: constructing llama_context, gtype = %d\n", __func__, gtype); t_start_us = model.t_start_us; t_load_us = model.t_load_us; + switch (gtype) { + case LLAMA_GRAPH_TYPE_DEFAULT: + case LLAMA_GRAPH_TYPE_DECODER: + { + } break; + case LLAMA_GRAPH_TYPE_ENCODER: + { + params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; + params.embeddings = true; + } break; + } + const auto & hparams = model.hparams; cparams.n_seq_max = std::max(1u, params.n_seq_max); @@ -45,20 +88,6 @@ llama_context::llama_context( cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; - // with causal attention, the batch size is limited by the context size - cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - - // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask - // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) - // ref: https://github.com/ggerganov/llama.cpp/pull/5021 - // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self - if (cparams.n_batch < GGML_KQ_MASK_PAD) { - LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); - cparams.n_batch = GGML_KQ_MASK_PAD; - } - - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? 
params.n_batch : params.n_ubatch); - cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : hparams.n_ctx_train; @@ -95,6 +124,20 @@ llama_context::llama_context( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } + // with causal attention, the batch size is limited by the context size + cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; + + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); @@ -102,6 +145,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -1207,6 +1251,23 @@ void llama_context::input_set(const llama_ubatch & ubatch) { } } + if (inp.pos_bucket) { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp.pos_bucket->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + int32_t * data = (int32_t *) inp.pos_bucket->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, true); + } + } + } + } + GGML_ASSERT( // (!a || b) is a logical implication (a -> b) // !hparams.causal_attn -> !cparams.causal_attn @@ -1604,6 +1665,15 @@ ggml_tensor * llama_context::build_inp_pos( return inp.pos; } +ggml_tensor * llama_context::build_inp_pos_bucket( + ggml_context * ctx0, + int32_t n_tokens) { + inp.pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); + ggml_set_input(inp.pos_bucket); + + return inp.pos_bucket; +} + ggml_tensor * llama_context::build_inp_out_ids( ggml_context * ctx0) { const int32_t n_out_ids = n_outputs; @@ -1656,6 +1726,7 @@ ggml_tensor * llama_context::build_attn( ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il) { @@ -1690,6 +1761,8 @@ ggml_tensor * llama_context::build_attn( GGML_UNUSED(model); GGML_UNUSED(n_ctx); + GGML_ASSERT(kq_b == nullptr); + struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3)); v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv); @@ -1720,10 +1793,14 @@ ggml_tensor * 
llama_context::build_attn( if (hparams.attn_soft_cap) { kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); - kq = ggml_tanh(ctx0, kq); + kq = ggml_tanh (ctx0, kq); kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); } + if (kq_b) { + kq = ggml_add(ctx0, kq, kq_b); + } + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); //cb(kq, "kq_soft_max_ext", il); @@ -2281,7 +2358,7 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_ llama_context_kv_self::llama_context_kv_self( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype) : llama_context(model, params, gtype), kv_self(model.hparams) { @@ -3053,53 +3130,19 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { } } - if (inp_pos_bucket) { + if (inp.self_pos_bucket) { const int64_t n_tokens = ubatch.n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_pos_bucket->buffer)); GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing - static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { - // TODO move to hparams if a T5 variant appears that uses a different value - const int64_t max_distance = 128; + int32_t * data = (int32_t *) inp.self_pos_bucket->data; - if (bidirectional) { - n_buckets >>= 1; - } - - const int64_t max_exact = n_buckets >> 1; - - int32_t relative_position = x - y; - int32_t relative_bucket = 0; - if (bidirectional) { - relative_bucket += (relative_position > 0) * n_buckets; - relative_position = abs(relative_position); - } else { - relative_position = -std::min(relative_position, 0); - } - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); - relative_bucket += (relative_position < max_exact ? 
relative_position : relative_position_if_large); - return relative_bucket; - }; - - int32_t * data = (int32_t *) inp_pos_bucket->data; - - if (!is_encoding) { - const int64_t n_kv = kv_self.n; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding); - } - } - } - } else { - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding); - } + const int64_t n_kv = kv_self.n; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false); } } } @@ -3146,7 +3189,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { ggml_cgraph * llama_context_kv_self::graph_init() { inp_embd_enc = nullptr; - inp_pos_bucket = nullptr; inp_kq_mask_cross = nullptr; inp = {}; @@ -3161,6 +3203,17 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) return inp.self_k_shift; } +ggml_tensor * llama_context_kv_self::build_inp_pos_bucket( + ggml_context * ctx0, + int32_t n_tokens) { + const auto n_kv = kv_self.n; + + inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + ggml_set_input(inp.self_pos_bucket); + + return inp.self_pos_bucket; +} + void llama_context_kv_self::build_attn_inp( ggml_context * ctx0, int32_t n_tokens, @@ -3199,6 +3252,7 @@ ggml_tensor * llama_context_kv_self::build_attn( ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il) { @@ -3293,6 +3347,8 @@ ggml_tensor * llama_context_kv_self::build_attn( GGML_UNUSED(model); GGML_UNUSED(n_ctx); + GGML_ASSERT(kq_b == nullptr); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx0, kv_self.v_l[il], @@ -3329,10 +3385,14 @@ ggml_tensor * llama_context_kv_self::build_attn( if (hparams.attn_soft_cap) { kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); - kq = ggml_tanh(ctx0, kq); + kq = ggml_tanh (ctx0, kq); kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); } + if (kq_b) { + kq = ggml_add(ctx0, kq, kq_b); + } + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); //cb(kq, "kq_soft_max_ext", il); @@ -3753,7 +3813,7 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq llama_context_recurrent::llama_context_recurrent( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype) : llama_context(model, params, gtype), kv_self(model.hparams) { @@ -4629,7 +4689,7 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s llama_context_enc_dec::llama_context_enc_dec( const llama_model & model, - const llama_context_params & params) : + llama_context_params params) : llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER), ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) { LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__); diff --git a/src/llama-context.h b/src/llama-context.h index 7cc982e10b..3e9baabfb5 100644 --- a/src/llama-context.h +++ 
b/src/llama-context.h @@ -25,7 +25,7 @@ struct llama_context : public llama_graph_i { public: llama_context( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype); virtual ~llama_context(); @@ -142,12 +142,13 @@ protected: struct { // base input tensors - ggml_tensor * tokens; // I32 [n_batch] - ggml_tensor * embd; // F32 [n_embd, n_batch] - ggml_tensor * pos; // I32 [n_batch] - ggml_tensor * out_ids; // I32 [n_outputs] - ggml_tensor * mean; // F32 [n_batch, n_batch] - ggml_tensor * cls; // I32 [n_batch] + ggml_tensor * tokens; // I32 [n_batch] + ggml_tensor * embd; // F32 [n_embd, n_batch] + ggml_tensor * pos; // I32 [n_batch] + ggml_tensor * pos_bucket; // I32 [n_batch, n_batch] + ggml_tensor * out_ids; // I32 [n_outputs] + ggml_tensor * mean; // F32 [n_batch, n_batch] + ggml_tensor * cls; // I32 [n_batch] // KQ mask input tensors ggml_tensor * kq_mask; // F32 [n_tokens, n_batch] @@ -233,6 +234,10 @@ protected: ggml_context * ctx0, int32_t n_tokens); + virtual ggml_tensor * build_inp_pos_bucket( + ggml_context * ctx0, + int32_t n_tokens); + virtual ggml_tensor * build_inp_out_ids( ggml_context * ctx0); @@ -258,6 +263,7 @@ protected: ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il); @@ -389,7 +395,7 @@ class llama_context_kv_self : public llama_context { public: llama_context_kv_self( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype); virtual ~llama_context_kv_self(); @@ -414,10 +420,11 @@ protected: virtual void input_set(const llama_ubatch & ubatch) override; struct { - ggml_tensor * self_kq_mask; // F32 [kv_size, n_batch] - ggml_tensor * self_kq_mask_cnv; // [kv_size, n_batch] - ggml_tensor * self_kq_mask_swa; // F32 [kv_size, n_batch] - ggml_tensor * self_kq_mask_swa_cnv; // [kv_size, n_batch] + ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch] + ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv; // [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa_cnv; // [n_kv, n_batch] ggml_tensor * self_k_shift; // I32 [kv_size] } inp; @@ -433,6 +440,10 @@ protected: virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override; + virtual ggml_tensor * build_inp_pos_bucket( + ggml_context * ctx0, + int32_t n_tokens) override; + virtual void build_attn_inp( ggml_context * ctx0, int32_t n_tokens, @@ -447,6 +458,7 @@ protected: ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il) override; @@ -470,7 +482,6 @@ protected: std::vector> seq_ids_enc; struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] - struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch] virtual ggml_tensor * build_inp_embd_enc( @@ -502,7 +513,7 @@ class llama_context_recurrent : public llama_context { public: llama_context_recurrent( const llama_model & model, - const llama_context_params & params, + llama_context_params params, llama_graph_type gtype); virtual ~llama_context_recurrent(); @@ -616,7 +627,7 @@ class llama_context_enc_dec : public llama_context { public: llama_context_enc_dec( const llama_model & model, - const llama_context_params & params); + llama_context_params params); virtual ~llama_context_enc_dec(); diff --git 
a/src/llama-graph.cpp b/src/llama-graph.cpp index af2c94be7f..3ac96908d6 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,6 +12,7 @@ ggml_tensor * llama_graph_i::build_attn( ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il) { @@ -22,6 +23,7 @@ ggml_tensor * llama_graph_i::build_attn( GGML_UNUSED(q_cur); GGML_UNUSED(k_cur); GGML_UNUSED(v_cur); + GGML_UNUSED(kq_b); GGML_UNUSED(n_tokens); GGML_UNUSED(kq_scale); GGML_UNUSED(il); diff --git a/src/llama-graph.h b/src/llama-graph.h index 82d2dc7362..5df90e76d5 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -83,6 +83,10 @@ public: ggml_context * ctx0, int32_t n_tokens) = 0; + virtual ggml_tensor * build_inp_pos_bucket( + ggml_context * ctx0, + int32_t n_tokens) = 0; + virtual ggml_tensor * build_inp_out_ids( ggml_context * ctx0) = 0; @@ -108,6 +112,7 @@ public: ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, int32_t n_tokens, float kq_scale, int il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c862502d3c..1e34ed8038 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1432,7 +1432,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // skip unused tensors if (info.op == GGML_OP_NONE) { - LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str()); + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + ml.size_data -= nbytes; ml.n_created++; return nullptr; @@ -3952,6 +3955,14 @@ struct llm_build_context { return lgf->build_lora_mm_id(ctx0, w, cur, ids); } + // TODO: tmp + struct ggml_tensor * build_pos_bucket() { + ggml_tensor * cur = lgf->build_inp_pos_bucket(ctx0, n_tokens); + cb(cur, "pos_bucket", -1); + + return cur; + } + // TODO: tmp struct ggml_tensor * build_inp_embd_enc() { ggml_tensor * cur = lgf->build_inp_embd_enc(ctx0); @@ -4263,7 +4274,30 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, n_tokens, kq_scale, il); + ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, nullptr, n_tokens, kq_scale, il); + cb(cur, "kqv_out", il); + + return cur; + } + + struct ggml_tensor * build_attn_with_kq_b( + struct ggml_cgraph * gf, + struct ggml_tensor * wo, + struct ggml_tensor * wo_b, + struct ggml_tensor * q_cur, + struct ggml_tensor * k_cur, + struct ggml_tensor * v_cur, + struct ggml_tensor * kq_b, + int32_t n_tokens, + float kq_scale, + int il) { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, n_tokens, kq_scale, il); cb(cur, "kqv_out", il); return cur; @@ -4364,37 +4398,24 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); } - //struct ggml_tensor * build_pos_bucket(bool causal) { - // if (causal) { - // lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); - // } else { - // lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); - // } + struct ggml_tensor * build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * 
attn_rel_b) { + struct ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]); + cb(pos_bucket_1d, "pos_bucket_1d", -1); - // ggml_set_input(lctx.inp_pos_bucket); - // cb(lctx.inp_pos_bucket, "pos_bucket", -1); + struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d); + cb(pos_bias, "pos_bias", -1); - // return lctx.inp_pos_bucket; - //} + pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]); + cb(pos_bias, "pos_bias", -1); - //struct ggml_tensor * build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) { - // struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0); - // cb(pos_bucket_1d, "pos_bucket_1d", -1); + pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3); + cb(pos_bias, "pos_bias", -1); - // struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d); - // cb(pos_bias, "pos_bias", -1); + pos_bias = ggml_cont(ctx0, pos_bias); + cb(pos_bias, "pos_bias", -1); - // pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0); - // cb(pos_bias, "pos_bias", -1); - - // pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3); - // cb(pos_bias, "pos_bias", -1); - - // pos_bias = ggml_cont(ctx0, pos_bias); - // cb(pos_bias, "pos_bias", -1); - - // return pos_bias; - //} + return pos_bias; + } void build_llama(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -9614,132 +9635,104 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); } - //void build_t5_enc(ggml_cgraph * gf) { - // const int64_t n_embd_head = hparams.n_embd_head_v; - // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + void build_t5_enc(ggml_cgraph * gf) { + const int64_t n_embd_head = hparams.n_embd_head_v; - // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // struct ggml_tensor * cur; - // struct ggml_tensor * inpL; + struct ggml_tensor * cur; + struct ggml_tensor * inpL; - // inpL = build_inp_embd(model.tok_embd); + inpL = build_inp_embd(model.tok_embd); - // GGML_ASSERT(lctx.is_encoding); - // struct ggml_tensor * pos_bucket_enc = build_pos_bucket(false); + struct ggml_tensor * pos_bucket_enc = build_pos_bucket(); - // // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - // struct ggml_tensor * KQ_mask_enc = build_inp_kq_mask(false); + lgf->build_attn_inp(ctx0, n_tokens, false, false); - // for (int il = 0; il < n_layer; ++il) { - // struct ggml_tensor * inpSA = inpL; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; - // // norm - // cur = build_norm(inpL, - // model.layers[il].attn_norm_enc, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "attn_norm", il); + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - // // self-attention - // { - // struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); - // cb(Qcur, "Qcur", il); + // self-attention + { + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); - // struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); - // cb(Kcur, "Kcur", il); + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, 
cur); + cb(Kcur, "Kcur", il); - // struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); - // cb(Vcur, "Vcur", il); + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); - // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - // struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + struct ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - // struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - // cb(kq, "kq", il); + cur = build_attn_with_kq_b(gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, kq_b, n_tokens, 1.0f, il); + cb(cur, "kqv_out", il); + } - // struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - // struct ggml_tensor * pos_bias = build_pos_bias(pos_bucket_enc, attn_rel_b); - // struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); - // cb(kq_b, "kq_b", il); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } - // kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias); - // cb(kq, "kq_soft_max_ext", il); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - // struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); - // cb(v, "v", il); + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); - // struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); - // cb(kqv, "kqv", il); + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } - // struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - // cb(kqv_merged, "kqv_merged", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); - // cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - // cb(cur, "kqv_merged_cont", il); + cur = lgf->build_cvec(ctx0, cur, il); + cb(cur, "l_out", il); - // ggml_build_forward_expand(gf, cur); + // input for next layer + inpL = cur; + } - // cur = build_lora_mm(model.layers[il].wo_enc, cur); - // cb(cur, "kqv_out", il); - // } + cur = inpL; + cb(cur, "result_embd", -1); - // if (il == n_layer - 1) { - // // skip computing output for unused tokens - // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - // cur = ggml_get_rows(ctx0, cur, inp_out_ids); - // inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - // } + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); - // struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - // cb(ffn_inp, "ffn_inp", il); + cb(cur, "result_norm", -1); + res.t_embd = cur; - // // feed-forward network - // { - // cur = build_norm(ffn_inp, - // model.layers[il].ffn_norm_enc, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "ffn_norm", il); - - // // T5 uses relu, flan-T5 uses gelu-gated - // cur = build_ffn(cur, - // model.layers[il].ffn_up_enc, NULL, NULL, - // model.layers[il].ffn_gate_enc, NULL, NULL, - // model.layers[il].ffn_down_enc, NULL, NULL, - // NULL, - // model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - // model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - // il); - // cb(cur, "ffn_out", il); - // } - - // cur = ggml_add(ctx0, cur, ffn_inp); - // cb(cur, "ffn_out", il); - - // ggml_tensor * layer_dir = cvec.tensor_for(il); - // if (layer_dir != nullptr) { - // cur = ggml_add(ctx0, cur, layer_dir); - // } - // cb(cur, "l_out", il); - - // // input for next layer - // inpL = cur; - // } - - // cur = inpL; - // cb(cur, "result_embd", -1); - - // cur = build_norm(cur, - // model.output_norm_enc, NULL, - // LLM_NORM_RMS, -1); - // - // cb(cur, "result_norm", -1); - // res.t_embd = cur; - - // ggml_build_forward_expand(gf, cur); - //} + ggml_build_forward_expand(gf, cur); + } //void build_t5_dec(ggml_cgraph * gf) { // const int64_t n_embd_head = hparams.n_embd_head_v; @@ -11091,14 +11084,19 @@ llama_graph_result llama_model::build_graph( { llm.build_bitnet(gf); } break; - //case LLM_ARCH_T5: - // { - // if (lctx.is_encoding) { - // llm.build_t5_enc(gf); - // } else { - // llm.build_t5_dec(gf); - // } - // } break; + case LLM_ARCH_T5: + { + switch (lgf->get_type()) { + case LLAMA_GRAPH_TYPE_ENCODER: + llm.build_t5_enc(gf); + break; + case LLAMA_GRAPH_TYPE_DECODER: + //llm.build_t5_dec(gf); + break; + default: + GGML_ABORT("invalid graph type"); + }; + } break; //case LLM_ARCH_T5ENCODER: // { // llm.build_t5_enc(gf);
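
Editor's note (not part of the patch): below is a minimal standalone sketch of the T5-style `llama_relative_position_bucket` helper that this commit hoists out of `llama_context_kv_self::input_set` into a file-scope function at the top of `src/llama-context.cpp`, so that both the cache-less encoder context (bidirectional, `n_tokens x n_tokens` grid) and the KV-self decoder context (causal, `n_kv x n_tokens` grid) can fill their `pos_bucket` inputs with it. For the sketch to compile on its own, `llama_pos` is replaced with plain `int32_t` (it is an `int32_t` typedef in `llama.h`) and the `n_rel_attn_bkts` value in `main` is a hypothetical placeholder standing in for `hparams.n_rel_attn_bkts`; the bucketing logic itself mirrors the code added by the patch.

```cpp
// Standalone sketch of the relative position bucketing used for the T5 positional bias.
// Small offsets get one bucket each ("exact" buckets); larger offsets up to max_distance
// share logarithmically spaced buckets.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    // fixed in the patch; would move to hparams if a T5 variant used a different value
    const int64_t max_distance = 128;

    if (bidirectional) {
        // half of the buckets are reserved for "key after query" offsets
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        relative_bucket  += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        // causal: only non-positive offsets (keys at or before the query) are kept
        relative_position = -std::min(relative_position, 0);
    }

    // log-spaced bucket index; only used below when relative_position >= max_exact
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0f * relative_position / max_exact) * (n_buckets - max_exact) / logf(1.0f * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);

    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}

int main() {
    // encoder case (bidirectional), matching how llama_context::input_set fills the
    // n_tokens x n_tokens inp.pos_bucket tensor
    const uint64_t n_rel_attn_bkts = 32; // hypothetical value; the patch reads hparams.n_rel_attn_bkts

    for (int32_t j = 0; j < 4; ++j) {         // query position
        for (int32_t i = 0; i < 4; ++i) {     // key position
            printf("%2d ", relative_position_bucket(i, j, n_rel_attn_bkts, /*bidirectional=*/true));
        }
        printf("\n");
    }

    return 0;
}
```

In the patch these bucket indices are then turned into an additive attention bias: `build_pos_bias` reshapes the `pos_bucket` tensor to 1D, gathers rows of `attn_rel_b` with `ggml_get_rows`, and reshapes/permutes the result into the `kq_b` tensor that the new `build_attn` overloads add to `kq` before the softmax.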