mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-09 10:17:06 +00:00
graph : simplify attention api
ggml-ci
This commit is contained in:
@@ -2567,23 +2567,29 @@ void llama_context_kv_self::build_attn_inp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_context_kv_self::build_attn_kv_store(
|
ggml_tensor * llama_context_kv_self::build_attn(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
ggml_tensor * k_cur,
|
ggml_tensor * k_cur,
|
||||||
ggml_tensor * v_cur,
|
ggml_tensor * v_cur,
|
||||||
|
ggml_tensor * q_cur,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int64_t il,
|
float kq_scale,
|
||||||
|
int il,
|
||||||
bool worst_case) {
|
bool worst_case) {
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
const auto & n_ctx = cparams.n_ctx;
|
const auto & n_ctx = cparams.n_ctx;
|
||||||
|
|
||||||
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
|
|
||||||
|
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
|
// store to KV cache
|
||||||
|
{
|
||||||
|
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
|
||||||
|
|
||||||
GGML_ASSERT(kv_self.size == n_ctx);
|
GGML_ASSERT(kv_self.size == n_ctx);
|
||||||
|
|
||||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
|
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
|
||||||
@@ -2609,21 +2615,8 @@ void llama_context_kv_self::build_attn_kv_store(
|
|||||||
//cb(v_cache_view, "v_cache_view", il);
|
//cb(v_cache_view, "v_cache_view", il);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_context_kv_self::build_attn_qkv(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * gf,
|
|
||||||
ggml_tensor * wo,
|
|
||||||
ggml_tensor * wo_b,
|
|
||||||
ggml_tensor * q_cur,
|
|
||||||
int32_t n_tokens,
|
|
||||||
float kq_scale,
|
|
||||||
int il,
|
|
||||||
bool worst_case) {
|
|
||||||
const auto & hparams = model.hparams;
|
|
||||||
|
|
||||||
const auto & n_ctx = cparams.n_ctx;
|
|
||||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||||
const auto & n_embd_head_v = hparams.n_embd_head_v;
|
const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||||
|
|
||||||
@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
|
|||||||
|
|
||||||
const int64_t n_head = hparams.n_head(il);
|
const int64_t n_head = hparams.n_head(il);
|
||||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
||||||
|
|
||||||
struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||||
//cb(q, "q", il);
|
//cb(q, "q", il);
|
||||||
|
|||||||
@@ -376,20 +376,13 @@ public:
|
|||||||
bool swa,
|
bool swa,
|
||||||
bool worst_case) override;
|
bool worst_case) override;
|
||||||
|
|
||||||
virtual void build_attn_kv_store(
|
virtual ggml_tensor * build_attn(
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * gf,
|
|
||||||
ggml_tensor * k_cur,
|
|
||||||
ggml_tensor * v_cur,
|
|
||||||
int32_t n_tokens,
|
|
||||||
int64_t il,
|
|
||||||
bool worst_case) override;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_attn_qkv(
|
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * wo,
|
ggml_tensor * wo,
|
||||||
ggml_tensor * wo_b,
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * k_cur,
|
||||||
|
ggml_tensor * v_cur,
|
||||||
ggml_tensor * q_cur,
|
ggml_tensor * q_cur,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
@@ -443,6 +436,7 @@ protected:
|
|||||||
|
|
||||||
// a recurrent transformer (ie.e RWKV, Mamba)
|
// a recurrent transformer (ie.e RWKV, Mamba)
|
||||||
// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
|
// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
|
||||||
|
//class llama_context_recurrent : public llama_context {
|
||||||
class llama_context_recurrent : public llama_context_kv_self {
|
class llama_context_recurrent : public llama_context_kv_self {
|
||||||
public:
|
public:
|
||||||
llama_context_recurrent(
|
llama_context_recurrent(
|
||||||
|
|||||||
@@ -88,20 +88,13 @@ public:
|
|||||||
bool swa,
|
bool swa,
|
||||||
bool worst_case) = 0;
|
bool worst_case) = 0;
|
||||||
|
|
||||||
virtual void build_attn_kv_store(
|
virtual ggml_tensor * build_attn(
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * gf,
|
|
||||||
ggml_tensor * k_cur,
|
|
||||||
ggml_tensor * v_cur,
|
|
||||||
int32_t n_tokens,
|
|
||||||
int64_t il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_attn_qkv(
|
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * wo,
|
ggml_tensor * wo,
|
||||||
ggml_tensor * wo_b,
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * k_cur,
|
||||||
|
ggml_tensor * v_cur,
|
||||||
ggml_tensor * q_cur,
|
ggml_tensor * q_cur,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
|
|||||||
@@ -4258,13 +4258,7 @@ struct llm_build_context {
|
|||||||
ggml_build_forward_expand(gf, k_cur);
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
ggml_build_forward_expand(gf, v_cur);
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
//build_kv_store(gf, k_cur, v_cur, il);
|
ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case);
|
||||||
lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
|
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
|
||||||
|
|
||||||
//cur = build_kqv(gf, wo, wo_b, q_cur, kq_mask, kq_scale, il);
|
|
||||||
cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
|
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
|
|||||||
Reference in New Issue
Block a user