mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
model : avoid ggml_cont_3d for fused QKV weights (#15662)
* model : avoid ggml_cont_3d for fused QKV weights ggml-ci * kv-cache : make cpy_k and cpy_v implementation more readable ggml-ci * cont : add comments ggml-ci * cont : minor fix [no ci] * cont : one more fix * cont : clarity ggml-ci * kv-cache : require contiguous heads of k_cur and v_cur ggml-ci
This commit is contained in:
@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
|
||||
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * k = layers[ikv].k;
|
||||
ggml_tensor * k = layers[ikv].k;
|
||||
|
||||
const int64_t n_embd_head = k_cur->ne[0];
|
||||
const int64_t n_head = k_cur->ne[1];
|
||||
const int64_t n_tokens = k_cur->ne[2];
|
||||
|
||||
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
|
||||
const int64_t n_embd_gqa = n_embd_head*n_head;
|
||||
|
||||
if (k->ne[2] > 1) {
|
||||
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
|
||||
// we can merge dims 0 and 1
|
||||
// TODO: add ggml helper function for this?
|
||||
GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
|
||||
|
||||
k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
|
||||
|
||||
const int64_t n_stream = k->ne[2];
|
||||
|
||||
if (n_stream > 1) {
|
||||
const int64_t kv_size = get_size();
|
||||
|
||||
assert(n_embd_gqa == k->ne[0]);
|
||||
assert(kv_size == k->ne[1]);
|
||||
|
||||
// merge the buffer across all streams because the idxs are global
|
||||
k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
|
||||
}
|
||||
|
||||
// store the current K values into the cache
|
||||
return ggml_set_rows(ctx, k, k_cur, k_idxs);
|
||||
}
|
||||
|
||||
@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
|
||||
|
||||
auto * v = layers[ikv].v;
|
||||
|
||||
const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
|
||||
const int64_t n_embd_head = v_cur->ne[0];
|
||||
const int64_t n_head = v_cur->ne[1];
|
||||
const int64_t n_tokens = v_cur->ne[2];
|
||||
|
||||
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
|
||||
const int64_t n_embd_gqa = n_embd_head*n_head;
|
||||
|
||||
// we can merge dims 0 and 1
|
||||
GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
|
||||
|
||||
const int64_t n_stream = v->ne[2];
|
||||
|
||||
// take this branch when FA is enabled (the V cache is not transposed)
|
||||
if (!v_trans) {
|
||||
if (v->ne[2] > 1) {
|
||||
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
|
||||
v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
|
||||
|
||||
if (n_stream > 1) {
|
||||
const int64_t kv_size = get_size();
|
||||
|
||||
assert(n_embd_gqa == v->ne[0]);
|
||||
assert(kv_size == v->ne[1]);
|
||||
|
||||
// merge the buffer across all streams because the idxs are global
|
||||
v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
|
||||
}
|
||||
|
||||
return ggml_set_rows(ctx, v, v_cur, v_idxs);
|
||||
}
|
||||
|
||||
// [TAG_V_CACHE_VARIABLE]
|
||||
if (n_embd_v_gqa < v->ne[0]) {
|
||||
v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
|
||||
if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
|
||||
// we can merge dims 0, 1 and 2
|
||||
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
|
||||
} else {
|
||||
// otherwise -> make a copy to get contiguous data
|
||||
v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
|
||||
}
|
||||
|
||||
// the row becomes a single element
|
||||
ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
|
||||
// [TAG_V_CACHE_VARIABLE]
|
||||
if (n_embd_gqa < v->ne[0]) {
|
||||
v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
|
||||
}
|
||||
|
||||
v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
|
||||
// in this branch the v_idxs are constructed in such a way that each row is a single head element
|
||||
ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
|
||||
|
||||
v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
|
||||
|
||||
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
|
||||
}
|
||||
|
||||
@@ -317,9 +317,17 @@ public:
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
|
||||
// - k_cur [n_embd_head_k, n_head_k, n_tokens]
|
||||
// - k_idxs [n_tokens]
|
||||
// - v_cur [n_embd_head_v, n_head_v, n_tokens]
|
||||
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
|
||||
|
||||
// create destination indices for each head of the current batch for where it would be written in the KV cache
|
||||
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
|
||||
// helps understand the implementation logic of cpy_k and cpy_v
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
|
||||
@@ -6927,9 +6927,7 @@ struct llm_build_falcon : public llm_graph_context {
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_ext(
|
||||
@@ -7207,9 +7205,7 @@ struct llm_build_dbrx : public llm_graph_context {
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
@@ -7329,13 +7325,9 @@ struct llm_build_starcoder : public llm_graph_context {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
@@ -7551,14 +7543,16 @@ struct llm_build_bert : public llm_graph_context {
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
@@ -7569,8 +7563,6 @@ struct llm_build_bert : public llm_graph_context {
|
||||
LLM_NORM, il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
} else {
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
}
|
||||
|
||||
if (model.layers[il].attn_k_norm) {
|
||||
@@ -7580,8 +7572,6 @@ struct llm_build_bert : public llm_graph_context {
|
||||
LLM_NORM, il);
|
||||
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
// RoPE
|
||||
@@ -7727,9 +7717,7 @@ struct llm_build_neo_bert : public llm_graph_context {
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
// RoPE
|
||||
Qcur = ggml_rope_ext(
|
||||
@@ -7836,13 +7824,9 @@ struct llm_build_bloom : public llm_graph_context {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
@@ -7958,13 +7942,9 @@ struct llm_build_mpt : public llm_graph_context {
|
||||
cb(cur, "wqkv_clamped", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
// Q/K Layernorm
|
||||
if (model.layers[il].attn_q_norm) {
|
||||
@@ -7972,26 +7952,16 @@ struct llm_build_mpt : public llm_graph_context {
|
||||
model.layers[il].attn_q_norm,
|
||||
model.layers[il].attn_q_norm_b,
|
||||
LLM_NORM, il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = build_norm(Kcur,
|
||||
model.layers[il].attn_k_norm,
|
||||
model.layers[il].attn_k_norm_b,
|
||||
LLM_NORM, il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
@@ -8242,9 +8212,7 @@ struct llm_build_qwen : public llm_graph_context {
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd));
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_ext(
|
||||
@@ -9219,21 +9187,17 @@ struct llm_build_phi2 : public llm_graph_context {
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@@ -9357,21 +9321,17 @@ struct llm_build_phi3 : public llm_graph_context {
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@@ -9621,18 +9581,14 @@ struct llm_build_gpt2 : public llm_graph_context {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
@@ -9727,9 +9683,7 @@ struct llm_build_codeshell : public llm_graph_context {
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
@@ -12601,9 +12555,7 @@ struct llm_build_gptneox : public llm_graph_context {
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
@@ -13736,18 +13688,14 @@ struct llm_build_jais : public llm_graph_context {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il);
|
||||
@@ -13859,8 +13807,7 @@ struct llm_build_chatglm : public llm_graph_context {
|
||||
}
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
}
|
||||
|
||||
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
|
||||
@@ -13993,8 +13940,7 @@ struct llm_build_glm4 : public llm_graph_context {
|
||||
}
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
}
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
@@ -17295,14 +17241,12 @@ private:
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
|
||||
ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user