refactor: Remove n_embd_k/v_gqa from recurrent cache

These values are no longer needed now that the recurrent cache and the attention KV cache have separate implementations

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart
2025-06-11 12:56:26 -06:00
parent b42c8b43cf
commit d5d7628b5f
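
For context: this cache's per-layer k/v tensors hold recurrent state (conv and SSM/WKV state), whose widths come from hparams.n_embd_k_s() and hparams.n_embd_v_s(). The n_embd_k_gqa()/n_embd_v_gqa() terms describe attention head widths and only mattered while one cache class served both layer types; with the caches split, they are dead weight here. Below is a rough, standalone sketch of the new per-layer sizing, assuming Mamba-style hyperparameters (the numbers are illustrative, not from a real model, and RWKV-style models compute these widths differently):

#include <cstdint>
#include <cstdio>

int main() {
    // illustrative Mamba-style hyperparameters (made up for this sketch)
    const uint32_t ssm_d_conv  = 4;
    const uint32_t ssm_d_inner = 4096;
    const uint32_t ssm_d_state = 16;
    const uint32_t kv_size     = 8; // number of recurrent cache cells

    // roughly what hparams.n_embd_k_s() / n_embd_v_s() evaluate to for such a model
    const uint32_t n_embd_k_s = (ssm_d_conv - 1) * ssm_d_inner; // conv state width
    const uint32_t n_embd_v_s = ssm_d_state * ssm_d_inner;      // ssm state width

    // per-layer element counts of the cache tensors after this change
    std::printf("k elements per layer: %u\n", n_embd_k_s * kv_size);
    std::printf("v elements per layer: %u\n", n_embd_v_s * kv_size);
    return 0;
}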


@@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
         const char * dev_name = "CPU";

         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
@@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }

-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l[i] = k;
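
A back-of-the-envelope footprint check for the allocation above, assuming ggml.h is available; print_layer_footprint is just a name for this sketch. The recurrent cache typically stores state as F32, but the calculation only depends on whatever type_k/type_v the cache was created with:

#include "ggml.h"
#include <cstdio>

// bytes used by one layer's recurrent k/v cache tensors under the new sizing
static void print_layer_footprint(ggml_type type_k, ggml_type type_v,
                                  int64_t n_embd_k_s, int64_t n_embd_v_s, int64_t kv_size) {
    const size_t k_bytes = ggml_row_size(type_k, n_embd_k_s) * kv_size;
    const size_t v_bytes = ggml_row_size(type_v, n_embd_v_s) * kv_size;
    std::printf("k: %zu bytes, v: %zu bytes per layer\n", k_bytes, v_bytes);
}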
@@ -754,14 +751,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
         io.write(&k_type_i, sizeof(k_type_i));

         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         io.write(&k_size_row, sizeof(k_size_row));

         // Read each range of cells of k_size length each into tmp_buf and write out
@@ -774,14 +770,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
             io.write(&v_type_i, sizeof(v_type_i));

             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             io.write(&v_size_row, sizeof(v_size_row));

             // Read each range of cells of v_size length each into tmp_buf and write out
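
The write path above keeps the same record layout as before; only the row width changes. Per layer it emits the tensor type, the byte size of one cell's row (now derived from the state width), and then the raw rows for every cell range. A simplified sketch of that record, using a plain byte vector in place of llama_io_write_i (append and write_layer_rows are illustrative names, not part of the codebase):

#include <cstdint>
#include <vector>

static void append(std::vector<uint8_t> & out, const void * p, size_t n) {
    const uint8_t * b = static_cast<const uint8_t *>(p);
    out.insert(out.end(), b, b + n);
}

// one layer's key (or non-transposed value) record: type id, row size, then rows
static void write_layer_rows(std::vector<uint8_t> & out,
                             int32_t type_id, uint64_t row_size,
                             const uint8_t * rows, uint32_t n_cells) {
    append(out, &type_id,  sizeof(type_id));       // tensor type
    append(out, &row_size, sizeof(row_size));      // bytes per cell row (state width x element size)
    append(out, rows, size_t(row_size) * n_cells); // contiguous rows, one per cell
}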
@@ -795,7 +790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -806,10 +801,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
             io.write(&v_size_el, sizeof(v_size_el));

             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_v_s, sizeof(n_embd_v_s));

             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
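
In the transposed case above, the value tensor is laid out as n_embd_v_s rows of kv_size cells each, so the writer records the element size and the row count and then copies cell ranges row by row, mirroring the dst_offset = (head + j * size) * v_size_el computation visible in the read path further down. A tiny sketch of that arithmetic (trans_offset is an illustrative name):

#include <cstddef>
#include <cstdint>

// byte offset of state row j, cell index `cell`, in a transposed value tensor
// that is kv_size cells wide per row; v_size_el is the bytes per element
static size_t trans_offset(uint32_t cell, uint32_t j, uint32_t kv_size, size_t v_size_el) {
    return (size_t(cell) + size_t(j) * kv_size) * v_size_el;
}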
@@ -942,7 +937,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Read type of key
         int32_t k_type_i_ref;
@@ -956,7 +950,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -970,7 +964,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -984,7 +977,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -998,7 +991,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -1018,17 +1011,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
                 return false;
             }

-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_v_s_ref;
+            io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
+            if (n_embd_v_s != n_embd_v_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
                 return false;
             }

             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                     const size_t dst_offset = (head + j * size) * v_size_el;
                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
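
The read path is the mirror image of the write path: every field written out is read back into a *_ref variable and compared against what the current model would produce, so a state file saved with different hyperparameters is rejected with a log message instead of silently corrupting the cache. A minimal sketch of that check pattern for the row size, reading from a plain buffer rather than llama_io_read_i (check_row_size is an illustrative name):

#include <cstdint>
#include <cstdio>
#include <cstring>

// read a stored row size and compare it against the locally computed one
static bool check_row_size(const uint8_t * buf, size_t & pos, uint64_t expected) {
    uint64_t ref = 0;
    std::memcpy(&ref, buf + pos, sizeof(ref));
    pos += sizeof(ref);
    if (ref != expected) {
        std::fprintf(stderr, "mismatched row size (%llu != %llu)\n",
                     (unsigned long long) expected, (unsigned long long) ref);
        return false;
    }
    return true;
}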