context : reuse build_attn_mha
ggml-ci
@@ -1721,50 +1721,67 @@ void llama_context::build_attn_inp(
 ggml_tensor * llama_context::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
-    const auto & hparams = model.hparams;
-
-    const auto & n_ctx = cparams.n_ctx;
-
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    GGML_UNUSED(il);

     const auto & kq_mask = inp.kq_mask_cnv;

-    const int64_t n_head    = hparams.n_head(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    const auto n_kv = n_tokens;
-
-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3));
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
     //cb(k, "k", il);

+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    return cur;
+}
+
+ggml_tensor * llama_context::build_attn_mha(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        bool v_trans,
+        float kq_scale) {
+    const auto & hparams = model.hparams;
+
+    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    //const int64_t n_head    = hparams.n_head(il);
+    //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    //const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];
+
     struct ggml_tensor * cur;

-    //if (cparams.flash_attn) {
-    if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);

-        GGML_ASSERT(kq_b == nullptr);
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
-        v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

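The new `build_attn_mha` deliberately takes no size arguments: with ggml's fastest-varying-dimension-first shapes, everything it needs can be read back off the permuted Q/K/V tensors, which is also why `n_tokens` disappears from the `build_attn` signatures. A quick illustration of the convention (not part of the commit):

```cpp
// Assumed shapes after the 0, 2, 1, 3 permute (illustrative):
//   q: [n_embd_head,   n_tokens, n_head,    n_batch]
//   k: [n_embd_head,   n_kv,     n_head_kv, n_batch]
//   v: [n_embd_head_v, n_kv,     n_head_kv, n_batch]   (v_trans == false)
//   v: [n_kv, n_embd_head_v,     n_head_kv, n_batch]   (v_trans == true)
// hence, exactly as in the new code:
const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
const auto n_tokens      = q->ne[1];
const auto n_head        = q->ne[2];
const auto n_kv          = k->ne[1];
```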
@@ -1774,7 +1791,6 @@ ggml_tensor * llama_context::build_attn(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here

@@ -1802,22 +1818,17 @@ ggml_tensor * llama_context::build_attn(
         }

         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);

-        // split cached v into n_head heads
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens)));
-        v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv);
-        //cb(v, "v", il);
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }

         struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);

         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);

         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU

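The `}` at the top of this hunk closes the last of the logit-adjustment blocks (Grok scaling, soft-capping, the optional `kq_b` bias) that sit in the unshown context; they appear in full in the deleted `llama_context_kv_self` code further down. Numerically, soft-capping maps logits through `c * tanh(kq / c)`: near-identity for small `|kq|`, saturating at `±c`. A standalone illustration, assuming `c = 30` as in the Grok comment below:

```cpp
// Standalone illustration of logit soft-capping: kq' = c * tanh(kq / c).
// c = f_attn_logit_softcapping; Grok folds its 0.08838834764831845 output
// multiplier into the same form: kq' = 30 * tanh(kq * 0.08838834764831845 / 30).
#include <cmath>
#include <cstdio>

int main() {
    const float c = 30.0f; // assumed cap, matching the Grok comment in this diff
    for (float kq : {1.0f, 10.0f, 50.0f, 200.0f}) {
        std::printf("kq = %6.1f -> capped = %7.3f\n", kq, c * std::tanh(kq / c));
    }
    return 0; // prints ~1.000, 9.645, 27.933, 30.000: identity-ish, then saturation
}
```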
@@ -1827,18 +1838,6 @@ ggml_tensor * llama_context::build_attn(

     ggml_build_forward_expand(gf, cur);

-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
     return cur;
 }

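With this hunk the output projection leaves the context-level `build_attn` for good: the helper now returns the merged attention heads and the graph builder applies `wo`/`wo_b` itself, as the `llm_build_context` hunks at the bottom show. In outline (a sketch of the new call-site pattern, not a literal excerpt):

```cpp
// before: cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, n_tokens, kq_scale, il);
// after:  the helper returns softmax(scale * K^T Q + mask) * V merged across heads,
//         and the caller owns the (LoRA-aware) output projection:
ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il);
if (wo) {
    cur = lgf->build_lora_mm(ctx0, wo, cur); // output projection
}
if (wo_b) {
    cur = ggml_add(ctx0, cur, wo_b);         // output bias
}
```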
@@ -3274,13 +3273,10 @@ void llama_context_kv_self::build_attn_inp(
 ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     const auto & hparams = model.hparams;

@@ -3290,6 +3286,10 @@ ggml_tensor * llama_context_kv_self::build_attn(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
     // store to KV cache
     {
         GGML_ASSERT(!kv_self.recurrent);

@@ -3308,7 +3308,7 @@ ggml_tensor * llama_context_kv_self::build_attn(

         struct ggml_tensor * v_cache_view = nullptr;

-        if (cparams.flash_attn) {
+        if (!v_trans) {
             v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
         } else {
             // note: the V cache is transposed when not using flash attention

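Why the branch: without flash attention the V cache is stored transposed so that, at read time, all cached values of one embedding dimension are contiguous, which is the access pattern `ggml_mul_mat(ctx0, v, kq)` wants; the flash kernel instead consumes V in its natural per-token layout. A plain-array sketch of the two layouts (illustrative, mirroring the view strides further down):

```cpp
#include <cstdint>

// Non-transposed cache (flash attention): token rows stored contiguously.
float v_natural(const float * buf, int64_t n_embd_v, int64_t token, int64_t dim) {
    return buf[token*n_embd_v + dim]; // stride between tokens = n_embd_v
}

// Transposed cache (regular attention): one column of n_ctx slots per dim, so
// the n_kv values of a single dim are contiguous for ggml_mul_mat(v, kq).
float v_transposed(const float * buf, int64_t n_ctx, int64_t token, int64_t dim) {
    return buf[dim*n_ctx + token];    // stride between dims = n_ctx
}
```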
@@ -3351,16 +3351,15 @@ ggml_tensor * llama_context_kv_self::build_attn(

     const auto n_kv = kv_self.n;

-    const int64_t n_head    = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);

     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k =
+    ggml_tensor * k =
             ggml_view_3d(ctx0, kv_self.k_l[il],
                 n_embd_head_k, n_kv, n_head_kv,
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),

@@ -3368,100 +3367,19 @@ ggml_tensor * llama_context_kv_self::build_attn(
                 0);
     //cb(k, "k", il);

-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        GGML_ASSERT(kq_b == nullptr);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                0);
-        //cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh (ctx0, kq);
-            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        if (kq_b) {
-            kq = ggml_add(ctx0, kq, kq_b);
-        }
-
-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self.v_l[il])*n_ctx,
-                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                0);
-        //cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
-
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-        }
-    }
-
-    ggml_build_forward_expand(gf, cur);
-
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
+    ggml_tensor * v = !v_trans ?
+            ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                0) :
+            ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                0);
+
+    struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);

     return cur;
 }

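This hunk is the heart of the change: both the flash and the regular paths of `llama_context_kv_self::build_attn` collapse into one call to the shared `build_attn_mha`, with the V-cache layout communicated through `v_trans` instead of duplicated branches. As a cross-check of what that helper computes per head in its non-flash branch, a toy CPU reference (not llama.cpp code; single head, no mask, no ALiBi):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// out = sum_i softmax(kq_scale * K q)[i] * V[i] for one head;
// K is row-major [n_kv][n_embd_head], V is row-major [n_kv][n_embd_head_v].
std::vector<float> attn_one_head(const std::vector<float> & q,
                                 const std::vector<float> & K,
                                 const std::vector<float> & V,
                                 int n_kv, int n_embd_head, int n_embd_head_v,
                                 float kq_scale) {
    std::vector<float> kq(n_kv);
    for (int i = 0; i < n_kv; ++i) {                // kq = scale * K q
        float dot = 0.0f;
        for (int d = 0; d < n_embd_head; ++d) {
            dot += K[i*n_embd_head + d] * q[d];
        }
        kq[i] = kq_scale * dot;
    }
    const float mx = *std::max_element(kq.begin(), kq.end());
    float sum = 0.0f;                               // numerically stable softmax
    for (float & x : kq) { x = std::exp(x - mx); sum += x; }
    for (float & x : kq) { x /= sum; }
    std::vector<float> out(n_embd_head_v, 0.0f);    // weighted sum of value rows
    for (int i = 0; i < n_kv; ++i) {
        for (int d = 0; d < n_embd_head_v; ++d) {
            out[d] += kq[i] * V[i*n_embd_head_v + d];
        }
    }
    return out;
}
```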
@@ -261,17 +261,25 @@ public:
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;

 protected:
+    virtual ggml_tensor * build_attn_mha(
+            ggml_context * ctx0,
+            ggml_cgraph * gf,
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            bool v_trans,
+            float kq_scale);
+
     virtual ggml_tensor * build_inp_self_k_shift(
             ggml_context * ctx0);

@@ -472,13 +480,10 @@ public:
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;

@@ -7,24 +7,18 @@ llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
 ggml_tensor * llama_graph_i::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     GGML_UNUSED(ctx0);
     GGML_UNUSED(gf);
-    GGML_UNUSED(wo);
-    GGML_UNUSED(wo_b);
     GGML_UNUSED(q_cur);
     GGML_UNUSED(k_cur);
     GGML_UNUSED(v_cur);
     GGML_UNUSED(kq_b);
-    GGML_UNUSED(n_tokens);
     GGML_UNUSED(kq_scale);
     GGML_UNUSED(il);

@@ -107,13 +107,10 @@ public:
     virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il);

@@ -4265,18 +4265,32 @@ struct llm_build_context {
             struct ggml_tensor * q_cur,
             struct ggml_tensor * k_cur,
             struct ggml_tensor * v_cur,
-            int32_t n_tokens,
+            int32_t n_tokens, // TODO: remove
             float kq_scale,
             int il) {
+        GGML_UNUSED(n_tokens);
+
         // these nodes are added to the graph together so that they are not reordered
         // by doing so, the number of splits in the graph is reduced
         ggml_build_forward_expand(gf, q_cur);
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);

-        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, nullptr, n_tokens, kq_scale, il);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il);
         cb(cur, "kqv_out", il);

+        if (wo) {
+            cur = lgf->build_lora_mm(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            //cb(cur, "kqv_wo", il);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
         return cur;
     }

@@ -4288,18 +4302,32 @@ struct llm_build_context {
             struct ggml_tensor * k_cur,
             struct ggml_tensor * v_cur,
             struct ggml_tensor * kq_b,
-            int32_t n_tokens,
+            int32_t n_tokens, // TODO: remove
             float kq_scale,
             int il) {
+        GGML_UNUSED(n_tokens);
+
         // these nodes are added to the graph together so that they are not reordered
         // by doing so, the number of splits in the graph is reduced
         ggml_build_forward_expand(gf, q_cur);
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);

-        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, n_tokens, kq_scale, il);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il);
         cb(cur, "kqv_out", il);

+        if (wo) {
+            cur = lgf->build_lora_mm(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            //cb(cur, "kqv_wo", il);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
         return cur;
     }