kv-cache : fix non-FA path with virtual sequences
ggml-ci
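For context, the diff below addresses each virtual sequence's slice of the unified KV cache through a per-sequence byte stride (size_virt) and a view offset, instead of one flat 2D reshape over all sequences. The following is a minimal sketch of only that addressing arithmetic, not the llama.cpp code itself; the sizes and the F32 cell type are assumptions chosen for illustration.

// Sketch (assumed sizes, F32 cells) of the per-virtual-sequence addressing:
// size_virt is the byte distance between consecutive sequences' K (or V) slices,
// and a slot covering sequences s0..s1 views ns = s1 - s0 + 1 of them starting
// at byte offset size_virt*s0.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd_k_gqa = 128;   // assumed K width per layer
    const uint32_t kv_size      = 1024;  // assumed number of cells per virtual sequence
    const uint32_t s0 = 1;               // first virtual sequence touched by the ubatch
    const uint32_t s1 = 2;               // last virtual sequence touched by the ubatch
    const uint32_t ns = s1 - s0 + 1;     // number of sequences in the slot

    const uint64_t row_size  = sizeof(float)*n_embd_k_gqa;  // bytes per cache cell
    const uint64_t size_virt = row_size*kv_size;            // bytes per sequence slice
    const uint64_t view_offs = size_virt*s0;                // where the slot's view begins

    // byte address of cell i in sequence s0 + s, relative to the start of the cache tensor
    const uint32_t i = 7; // example cell index, local to its sequence
    for (uint32_t s = 0; s < ns; ++s) {
        printf("seq %u, cell %u -> byte offset %llu\n",
               s0 + s, i, (unsigned long long) (view_offs + s*size_virt + i*row_size));
    }

    return 0;
}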
@@ -803,6 +803,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
+    assert(res.s1 >= res.s0);
+
     return res;
 }
 
@@ -908,13 +910,8 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    assert(sinfo.s1 >= sinfo.s0);
-
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
 
     return ggml_view_4d(ctx, k,
@@ -932,9 +929,6 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
 
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
 
     if (!v_trans) {
@@ -967,9 +961,20 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (kv_idxs && supports_set_rows) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+        const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
+
+        ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
+                ggml_row_size(k->type, k->ne[0]),
+                size_virt,
+                size_virt*sinfo.s0);
+
+        k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
+
+        kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+
+        return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -995,27 +1000,46 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (kv_idxs && supports_set_rows) {
-        if (!v_trans) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
+
+        if (!v_trans) {
+            ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
+                    ggml_row_size(v->type, v->ne[0]),
+                    size_virt,
+                    size_virt*sinfo.s0);
+
+            v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
+
+            kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
+        ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
+                ggml_row_size(v->type, 1),
+                ggml_row_size(v->type, v->ne[1]),
+                size_virt,
+                size_virt*sinfo.s0);
 
         // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
 
         // note: we can be more explicit here at the cost of extra cont
         //       however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+        //v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        //v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1,        n_kv,     n_embd_v_gqa]
-        // v_cur   [1,        n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1,        1]
+        // v       [1,           n_kv,        n_embd_v_gqa, ns]
+        // v_cur   [1,           n_tokens/ns, n_embd_v_gqa, ns]
+        // kv_idxs [n_tokens/ns, 1,           ns]
+
+        kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
+
         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }
 
@@ -1053,10 +1077,8 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
-
         for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            data[s*sinfo.size() + i] = sinfo.idxs[s][i];
        }
     }
 }
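The cpy_k/cpy_v hunks above regroup the ubatch before calling ggml_set_rows: a flat batch of n_tokens rows is reshaped into ns blocks of n_tokens/ns rows, one block per virtual sequence, and the KV indices are reshaped the same way. The sketch below mirrors that regrouping with plain arrays; it assumes the ubatch rows are ordered sequence by sequence and that n_tokens is divisible by ns, and the values are made up for illustration.

// Sketch (assumed sizes and data) of splitting an ubatch across ns virtual sequences:
// conceptually the same regrouping as ggml_reshape_2d(kv_idxs, n_tokens/ns, ns).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_tokens = 8;            // assumed ubatch size
    const uint32_t ns       = 2;            // assumed number of virtual sequences in the slot
    const uint32_t per_seq  = n_tokens/ns;  // rows written into each sequence's view

    // one destination cell index per token, grouped by sequence and local to it
    const std::vector<int64_t> kv_idxs = {10, 11, 12, 13, 40, 41, 42, 43};

    // row j of block s holds the destination cell for token s*per_seq + j
    for (uint32_t s = 0; s < ns; ++s) {
        for (uint32_t j = 0; j < per_seq; ++j) {
            printf("seq block %u, row %u -> token %u writes cell %lld\n",
                   s, j, s*per_seq + j, (long long) kv_idxs[s*per_seq + j]);
        }
    }

    return 0;
}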
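The last hunk drops the per-sequence offset from the KV indices, since ggml_set_rows() now writes into a per-sequence view of the cache rather than one flat destination. A small before/after sketch of the index values, with made-up slot contents and simplified sinfo-style names, shows the difference; it is not the actual helper, just the same arithmetic.

// Sketch contrasting the old globalized indices with the new local ones
// (assumed slot data; names simplified from the llama.cpp code).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t kv_size = 1024;                     // assumed cells per virtual sequence
    const std::vector<uint32_t> seq_id_virt = {0, 1};  // assumed virtual sequence ids
    const std::vector<std::vector<uint32_t>> idxs = {  // assumed cells picked by find_slot
        {10, 11, 12, 13},
        {40, 41, 42, 43},
    };

    for (size_t s = 0; s < idxs.size(); ++s) {
        for (size_t i = 0; i < idxs[s].size(); ++i) {
            // old: index into one flat [kv_size*n_seq_virt] destination
            const int64_t idx_old = (int64_t) seq_id_virt[s]*kv_size + idxs[s][i];
            // new: index local to sequence s, because the destination is a per-sequence view
            const int64_t idx_new = idxs[s][i];
            printf("s=%zu i=%zu  old=%lld  new=%lld\n",
                   s, i, (long long) idx_old, (long long) idx_new);
        }
    }

    return 0;
}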