Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-05 09:36:52 +00:00)
model : avoid ggml_cont_3d for fused QKV weights (#15662)
* model : avoid ggml_cont_3d for fused QKV weights
* kv-cache : make cpy_k and cpy_v implementation more readable
* cont : add comments
* cont : minor fix [no ci]
* cont : one more fix
* cont : clarity
* kv-cache : require contiguous heads of k_cur and v_cur
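For context on the first bullet: with a fused QKV projection, Kcur and Vcur are strided 3D views into the output of a single matmul rather than standalone contiguous tensors. Before this change the graph had to make them contiguous (ggml_cont_3d) so the KV-cache copy could reshape them; now cpy_k and cpy_v only require that the heads of one token lie back to back. The following sketch is illustrative only: the dimensions, the qkv stand-in tensor and the Kcur view are invented here, not taken from the commit. It just demonstrates the layout property that the new GGML_ASSERT checks.

#include "ggml.h"

int main() {
    // hypothetical dimensions, chosen only for illustration
    const int64_t n_embd_head = 64;
    const int64_t n_head      = 8;  // Q heads
    const int64_t n_head_kv   = 2;  // K/V heads (GQA)
    const int64_t n_tokens    = 4;

    const int64_t n_embd_q = n_embd_head*n_head;
    const int64_t n_embd_k = n_embd_head*n_head_kv;
    const int64_t n_embd_v = n_embd_head*n_head_kv;

    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // stand-in for the output of a fused QKV matmul: one fused row per token
    ggml_tensor * qkv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_q + n_embd_k + n_embd_v, n_tokens);

    // K as a strided 3D view into the fused result: no ggml_cont_3d copy
    ggml_tensor * Kcur = ggml_view_3d(ctx, qkv,
            n_embd_head, n_head_kv, n_tokens,
            ggml_row_size(qkv->type, n_embd_head), // nb1: heads of one token are back to back
            qkv->nb[1],                            // nb2: stride from one token to the next
            ggml_row_size(qkv->type, n_embd_q));   // offset: skip the Q block

    // the property that cpy_k/cpy_v now assert before merging dims 0 and 1
    GGML_ASSERT(ggml_row_size(Kcur->type, n_embd_head) == Kcur->nb[1]);

    ggml_free(ctx);
    return 0;
}

Within one token the K block of the fused row is a contiguous slab of n_embd_head*n_head_kv values, so dims 0 and 1 of the view can be merged without a copy; only the stride from token to token (nb[2]) differs from a packed layout.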
@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
     const int32_t ikv = map_layer_ids.at(il);

-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;

-    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];

-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    const int64_t n_embd_gqa = n_embd_head*n_head;

-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
+
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }

+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
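With that layout guaranteed, the new cpy_k no longer needs a fully contiguous k_cur: instead of ggml_reshape_2d (which asserts contiguity, and was what forced the ggml_cont_3d upstream), it folds dims 0 and 1 with a 2D view that keeps the original token stride, then scatters the rows into the cache with ggml_set_rows. A minimal continuation of the sketch above (ctx, qkv, Kcur and the dimension constants are reused; k_cache, k_idxs and kv_size are made-up stand-ins for the real cache tensors, with n_embd_k playing the role of n_embd_gqa):

    // hypothetical non-transposed K cache: one row of n_embd_k values per cell
    const int64_t kv_size = 32;
    ggml_tensor * k_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_k, kv_size);
    ggml_tensor * k_idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // fold the head dims of the strided view into one row per token;
    // ggml_reshape_2d would insist on a fully contiguous Kcur, the 2D view does not
    ggml_tensor * k_rows = ggml_view_2d(ctx, Kcur, n_embd_k, n_tokens, Kcur->nb[2], 0);

    // scatter the token rows into the cache cells selected by k_idxs
    ggml_tensor * k_new = ggml_set_rows(ctx, k_cache, k_rows, k_idxs);
    (void) k_new; // in the real code this node is what ends up in the graph

In the actual cpy_k, when n_stream > 1 the cache is additionally flattened from [n_embd_gqa, kv_size, n_stream] to [n_embd_gqa, kv_size*n_stream], so that the global row indices in k_idxs can address the cells of every stream.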
@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     auto * v = layers[ikv].v;

-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];

-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    const int64_t n_embd_gqa = n_embd_head*n_head;
+
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
+
+    const int64_t n_stream = v->ne[2];

+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }

         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }

-    // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
     }

-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
+    }

-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
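The second half of cpy_v handles the transposed V cache (FA disabled), where every destination row is a single element and v_cur therefore has to be flattened completely. Merging dims 0, 1 and 2 in place is only possible when consecutive tokens are also laid out back to back; for a view into a fused QKV result they are not, so the new ggml_cont_2d branch makes the one remaining copy. Continuing the same illustrative sketch (Vcur and v_rows are invented here, with n_embd_v playing the role of n_embd_gqa):

    // V slice of the fused result, laid out like Kcur but after the K block
    ggml_tensor * Vcur = ggml_view_3d(ctx, qkv,
            n_embd_head, n_head_kv, n_tokens,
            ggml_row_size(qkv->type, n_embd_head),
            qkv->nb[1],
            ggml_row_size(qkv->type, n_embd_q + n_embd_k));

    ggml_tensor * v_rows;
    if (ggml_row_size(Vcur->type, n_embd_v) == Vcur->nb[2]) {
        // consecutive tokens are also back to back -> dims 0, 1 and 2 merge in place
        v_rows = ggml_reshape_2d(ctx, Vcur, n_embd_v, n_tokens);
    } else {
        // here the Q and K blocks sit between consecutive tokens -> copy to get contiguous rows
        v_rows = ggml_cont_2d(ctx, Vcur, n_embd_v, n_tokens);
    }
    (void) v_rows;

So the fused-QKV fast path (FA enabled, V cache not transposed) stores the views directly, and the contiguous copy survives only as a fallback for the transposed layout.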