	refactor: Remove n_embd_k/v_gqa from recurrent cache
This is no longer needed now that there are separate implementations:
https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
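For readers skimming the diff below: the n_embd_k/v_gqa terms were carried over from the formerly shared KV-cache implementation, where the same tensors had to cover attention K/V rows as well as recurrent state. In a cache that only ever stores recurrent state, the per-layer tensors can be sized by hparams.n_embd_k_s()/n_embd_v_s() alone. The following is a minimal, self-contained sketch of that sizing, using made-up Mamba-style numbers as stand-ins for the hparams accessors (illustration only, not the upstream code):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical Mamba-style layer dimensions (stand-ins, not real model values).
        const uint32_t d_conv  = 4;     // conv kernel width
        const uint32_t d_inner = 2048;  // inner (expanded) embedding size
        const uint32_t d_state = 16;    // SSM state size
        const uint32_t kv_size = 8;     // number of recurrent cells (sequences)

        // Stand-ins for hparams.n_embd_k_s() / hparams.n_embd_v_s():
        // the K slot holds the rolling conv state, the V slot holds the SSM state.
        const uint32_t n_embd_k_s = (d_conv - 1) * d_inner;
        const uint32_t n_embd_v_s = d_state * d_inner;

        // After this change, each layer's 1-D cache tensors are sized by the
        // recurrent state widths alone -- no n_embd_k/v_gqa attention term.
        printf("k tensor elements per layer: %u\n", n_embd_k_s * kv_size);
        printf("v tensor elements per layer: %u\n", n_embd_v_s * kv_size);
        return 0;
    }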
@@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
         const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
@@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l[i] = k;
@@ -754,14 +751,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
         io.write(&k_type_i, sizeof(k_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         io.write(&k_size_row, sizeof(k_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
@@ -774,14 +770,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             io.write(&v_size_row, sizeof(v_size_row));
 
             // Read each range of cells of v_size length each into tmp_buf and write out
@@ -795,7 +790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -806,10 +801,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
             io.write(&v_size_el, sizeof(v_size_el));
 
             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_v_s, sizeof(n_embd_v_s));
 
             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
@@ -942,7 +937,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -956,7 +950,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -970,7 +964,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -984,7 +977,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -998,7 +991,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -1018,17 +1011,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
                 return false;
             }
 
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_v_s_ref;
+            io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
+            if (n_embd_v_s != n_embd_v_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
                 return false;
             }
 
             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                     const size_t dst_offset = (head + j * size) * v_size_el;
                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
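The serialization hunks follow the same pattern: the row size written to a session file, and re-derived and checked when reading it back, now depends on the recurrent state width alone. Below is a small, self-contained sketch of that write/read handshake; the names and numbers are stand-ins, and row_size_f32() is a hypothetical simplification of ggml_row_size(), which in llama.cpp also accounts for quantized tensor types.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Stand-in for ggml_row_size() for plain f32 rows (illustration only).
    static uint64_t row_size_f32(uint64_t n_elements) {
        return n_elements * sizeof(float);
    }

    int main() {
        std::vector<uint8_t> session;  // stand-in for the llama_io stream

        // Writer side (state_write_data-style): record the per-cell row size.
        const uint32_t n_embd_v_s_w = 32768;  // hypothetical hparams.n_embd_v_s()
        const uint64_t v_size_row_w = row_size_f32(n_embd_v_s_w);
        session.resize(sizeof(v_size_row_w));
        memcpy(session.data(), &v_size_row_w, sizeof(v_size_row_w));

        // Reader side (state_read_data-style): recompute from the loaded
        // model's hparams and compare against the stored value.
        uint64_t v_size_row_ref = 0;
        memcpy(&v_size_row_ref, session.data(), sizeof(v_size_row_ref));

        const uint32_t n_embd_v_s_r = 32768;  // recomputed on load
        const uint64_t v_size_row_r = row_size_f32(n_embd_v_s_r);
        if (v_size_row_r != v_size_row_ref) {
            fprintf(stderr, "mismatched value row size (%llu != %llu)\n",
                    (unsigned long long) v_size_row_r,
                    (unsigned long long) v_size_row_ref);
            return 1;
        }
        printf("row size check passed: %llu bytes per cell\n",
               (unsigned long long) v_size_row_ref);
        return 0;
    }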