model : avoid ggml_cont_3d for fused QKV weights (#15662)

* model : avoid ggml_cont_3d for fused QKV weights

ggml-ci

* kv-cache : make cpy_k and cpy_v implementation more readable

ggml-ci

* cont : add comments

ggml-ci

* cont : minor fix [no ci]

* cont : one more fix

* cont : clarity

ggml-ci

* kv-cache : require contiguous heads of k_cur and v_cur

ggml-ci
Author: Georgi Gerganov
Date:   2025-09-08 10:25:33 +03:00 (committed by GitHub)
parent d413dca003
commit cf0e3ba150
3 changed files with 100 additions and 108 deletions
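For context, the model-side half of the change (the third changed file) is what the title refers to: with fused QKV weights, the K slice of the projection output is a strided 3D view, and before this commit it had to be made contiguous with ggml_cont_3d before entering the KV cache. A rough sketch of that pattern follows; model.wqkv, n_head_kv and the other names are illustrative assumptions, not the exact llama.cpp build code:

    // sketch with illustrative names - not the actual llama.cpp model code
    ggml_tensor * qkv = ggml_mul_mat(ctx, model.wqkv, cur); // one fused [Q|K|V] row per token

    ggml_tensor * Kcur = ggml_view_3d(ctx, qkv, n_embd_head, n_head_kv, n_tokens,
            ggml_row_size(qkv->type, n_embd_head),         // nb1: heads packed back to back
            qkv->nb[1],                                    // nb2: stride of one fused row
            ggml_row_size(qkv->type, n_embd_head*n_head)); // offset: skip the Q slice

    // before: Kcur = ggml_cont_3d(ctx, Kcur, n_embd_head, n_head_kv, n_tokens);
    // after:  the strided view is passed on directly

The property that makes this safe is that the heads within each token are packed back to back (nb1 equals the head row size), which is exactly the contract the kv-cache hunks below now assert and rely on.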


@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
 
-    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
+
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
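In short: the old cpy_k required k_cur to be fully contiguous (hence the ggml_cont_3d upstream) so that it could be reshaped to 2D, while the new code only asserts that the heads within each token are packed back to back and merges dims 0 and 1 with a strided view. A minimal standalone sketch of the trick; the helper name merge_heads_2d is hypothetical:

    // hypothetical helper mirroring the assert + view above - not part of ggml
    static ggml_tensor * merge_heads_2d(ggml_context * ctx, ggml_tensor * cur) {
        const int64_t n_embd_head = cur->ne[0];
        const int64_t n_head      = cur->ne[1];
        const int64_t n_tokens    = cur->ne[2];

        // dims 0 and 1 can only be fused when there is no padding between heads
        GGML_ASSERT(ggml_row_size(cur->type, n_embd_head) == cur->nb[1]);

        // re-interpret the data as [n_embd_head*n_head, n_tokens] in place;
        // unlike ggml_cont_3d, this adds no copy node to the graph
        return ggml_view_2d(ctx, cur, n_embd_head*n_head, n_tokens, cur->nb[2], 0);
    }

The resulting view may still be strided across tokens (its row stride stays cur->nb[2]), which is exactly what the K slice of a fused QKV projection looks like.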
@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    const int64_t n_embd_gqa = n_embd_head*n_head;
+
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
+    const int64_t n_stream = v->ne[2];
+
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
-    // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
-    }
-
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
-
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
+    }
+
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
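The transposed-V branch (FA disabled) flattens everything down to single elements, because each row of v_idxs addresses one head element. Whether that flattening needs a copy now hinges on one extra stride check. A sketch of just that decision, assuming the head-contiguity assert above has already passed; the helper name flatten_tokens_2d is hypothetical:

    // hypothetical helper mirroring the view-or-copy branch above
    static ggml_tensor * flatten_tokens_2d(ggml_context * ctx, ggml_tensor * cur) {
        const int64_t n_embd_gqa = cur->ne[0]*cur->ne[1]; // dims 0 and 1 are known contiguous
        const int64_t n_tokens   = cur->ne[2];

        if (ggml_row_size(cur->type, n_embd_gqa) == cur->nb[2]) {
            // tokens are also packed back to back -> the tensor is fully
            // contiguous and dims 0, 1 and 2 merge without a copy
            return ggml_reshape_2d(ctx, cur, n_embd_gqa, n_tokens);
        }

        // otherwise a copy is unavoidable to obtain contiguous data
        return ggml_cont_2d(ctx, cur, n_embd_gqa, n_tokens);
    }

So the worst case is a conditional 2D copy of the current batch's V values; in the common case where the tokens are contiguous too, no copy is added to the graph.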