mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	refactor: Remove n_embd_k/v_gqa from recurrent cache
This is no longer needed now that there are separate implementations https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
		| @@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( | |||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); |  | ||||||
|         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); |  | ||||||
|  |  | ||||||
|         const char * dev_name = "CPU"; |         const char * dev_name = "CPU"; | ||||||
|  |  | ||||||
|         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); |         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); | ||||||
| @@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( | |||||||
|             throw std::runtime_error("failed to create ggml context for kv cache"); |             throw std::runtime_error("failed to create ggml context for kv cache"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); |         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size); | ||||||
|         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); |         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size); | ||||||
|         ggml_format_name(k, "cache_k_l%d", i); |         ggml_format_name(k, "cache_k_l%d", i); | ||||||
|         ggml_format_name(v, "cache_v_l%d", i); |         ggml_format_name(v, "cache_v_l%d", i); | ||||||
|         k_l[i] = k; |         k_l[i] = k; | ||||||
| @@ -754,14 +751,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std | |||||||
|     // Iterate and write all the keys first, each row is a cell |     // Iterate and write all the keys first, each row is a cell | ||||||
|     // Get whole range at a time |     // Get whole range at a time | ||||||
|     for (uint32_t il = 0; il < n_layer; ++il) { |     for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); |  | ||||||
|  |  | ||||||
|         // Write key type |         // Write key type | ||||||
|         const int32_t k_type_i = (int32_t)k_l[il]->type; |         const int32_t k_type_i = (int32_t)k_l[il]->type; | ||||||
|         io.write(&k_type_i, sizeof(k_type_i)); |         io.write(&k_type_i, sizeof(k_type_i)); | ||||||
|  |  | ||||||
|         // Write row size of key |         // Write row size of key | ||||||
|         const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); |         const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); | ||||||
|         io.write(&k_size_row, sizeof(k_size_row)); |         io.write(&k_size_row, sizeof(k_size_row)); | ||||||
|  |  | ||||||
|         // Read each range of cells of k_size length each into tmp_buf and write out |         // Read each range of cells of k_size length each into tmp_buf and write out | ||||||
| @@ -774,14 +770,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std | |||||||
|  |  | ||||||
|     if (!v_trans) { |     if (!v_trans) { | ||||||
|         for (uint32_t il = 0; il < n_layer; ++il) { |         for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); |  | ||||||
|  |  | ||||||
|             // Write value type |             // Write value type | ||||||
|             const int32_t v_type_i = (int32_t)v_l[il]->type; |             const int32_t v_type_i = (int32_t)v_l[il]->type; | ||||||
|             io.write(&v_type_i, sizeof(v_type_i)); |             io.write(&v_type_i, sizeof(v_type_i)); | ||||||
|  |  | ||||||
|             // Write row size of value |             // Write row size of value | ||||||
|             const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); |             const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); | ||||||
|             io.write(&v_size_row, sizeof(v_size_row)); |             io.write(&v_size_row, sizeof(v_size_row)); | ||||||
|  |  | ||||||
|             // Read each range of cells of v_size length each into tmp_buf and write out |             // Read each range of cells of v_size length each into tmp_buf and write out | ||||||
| @@ -795,7 +790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std | |||||||
|         // When v is transposed, we also need the element size and get the element ranges from each row |         // When v is transposed, we also need the element size and get the element ranges from each row | ||||||
|         const uint32_t kv_size = size; |         const uint32_t kv_size = size; | ||||||
|         for (uint32_t il = 0; il < n_layer; ++il) { |         for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); |             const uint32_t n_embd_v_s = hparams.n_embd_v_s(); | ||||||
|  |  | ||||||
|             // Write value type |             // Write value type | ||||||
|             const int32_t v_type_i = (int32_t)v_l[il]->type; |             const int32_t v_type_i = (int32_t)v_l[il]->type; | ||||||
| @@ -806,10 +801,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std | |||||||
|             io.write(&v_size_el, sizeof(v_size_el)); |             io.write(&v_size_el, sizeof(v_size_el)); | ||||||
|  |  | ||||||
|             // Write GQA embedding size |             // Write GQA embedding size | ||||||
|             io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); |             io.write(&n_embd_v_s, sizeof(n_embd_v_s)); | ||||||
|  |  | ||||||
|             // For each row, we get the element values of each cell |             // For each row, we get the element values of each cell | ||||||
|             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { |             for (uint32_t j = 0; j < n_embd_v_s; ++j) { | ||||||
|                 // Read each range of cells of v_size_el length each into tmp_buf and write out |                 // Read each range of cells of v_size_el length each into tmp_buf and write out | ||||||
|                 for (const auto & range : cell_ranges) { |                 for (const auto & range : cell_ranges) { | ||||||
|                     const size_t range_size = range.second - range.first; |                     const size_t range_size = range.second - range.first; | ||||||
| @@ -942,7 +937,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|  |  | ||||||
|     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block |     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block | ||||||
|     for (uint32_t il = 0; il < n_layer; ++il) { |     for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); |  | ||||||
|  |  | ||||||
|         // Read type of key |         // Read type of key | ||||||
|         int32_t k_type_i_ref; |         int32_t k_type_i_ref; | ||||||
| @@ -956,7 +950,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|         // Read row size of key |         // Read row size of key | ||||||
|         uint64_t k_size_row_ref; |         uint64_t k_size_row_ref; | ||||||
|         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); |         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); | ||||||
|         const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); |         const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); | ||||||
|         if (k_size_row != k_size_row_ref) { |         if (k_size_row != k_size_row_ref) { | ||||||
|             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); |             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); | ||||||
|             return false; |             return false; | ||||||
| @@ -970,7 +964,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|  |  | ||||||
|     if (!v_trans) { |     if (!v_trans) { | ||||||
|         for (uint32_t il = 0; il < n_layer; ++il) { |         for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); |  | ||||||
|  |  | ||||||
|             // Read type of value |             // Read type of value | ||||||
|             int32_t v_type_i_ref; |             int32_t v_type_i_ref; | ||||||
| @@ -984,7 +977,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|             // Read row size of value |             // Read row size of value | ||||||
|             uint64_t v_size_row_ref; |             uint64_t v_size_row_ref; | ||||||
|             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); |             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); | ||||||
|             const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); |             const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); | ||||||
|             if (v_size_row != v_size_row_ref) { |             if (v_size_row != v_size_row_ref) { | ||||||
|                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); |                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); | ||||||
|                 return false; |                 return false; | ||||||
| @@ -998,7 +991,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|     } else { |     } else { | ||||||
|         // For each layer, read the values for each cell (transposed) |         // For each layer, read the values for each cell (transposed) | ||||||
|         for (uint32_t il = 0; il < n_layer; ++il) { |         for (uint32_t il = 0; il < n_layer; ++il) { | ||||||
|             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); |             const uint32_t n_embd_v_s = hparams.n_embd_v_s(); | ||||||
|  |  | ||||||
|             // Read type of value |             // Read type of value | ||||||
|             int32_t v_type_i_ref; |             int32_t v_type_i_ref; | ||||||
| @@ -1018,17 +1011,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce | |||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // Read GQA embedding size |             // Read state embedding size | ||||||
|             uint32_t n_embd_v_gqa_ref; |             uint32_t n_embd_v_s_ref; | ||||||
|             io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); |             io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref)); | ||||||
|             if (n_embd_v_gqa != n_embd_v_gqa_ref) { |             if (n_embd_v_s != n_embd_v_s_ref) { | ||||||
|                 LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); |                 LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il); | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (cell_count) { |             if (cell_count) { | ||||||
|                 // For each row in the transposed matrix, read the values for the whole cell range |                 // For each row in the transposed matrix, read the values for the whole cell range | ||||||
|                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { |                 for (uint32_t j = 0; j < n_embd_v_s; ++j) { | ||||||
|                     const size_t dst_offset = (head + j * size) * v_size_el; |                     const size_t dst_offset = (head + j * size) * v_size_el; | ||||||
|                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); |                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); | ||||||
|                 } |                 } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Gabe Goodhart
					Gabe Goodhart