mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	kv-cache : rework kv_idxs, support seq_cp
ggml-ci
This commit is contained in:
		| @@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { | ||||
| } | ||||
|  | ||||
| void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { | ||||
|     if (self_kv_idxs) { | ||||
|         mctx->set_input_kv_idxs(self_kv_idxs, ubatch); | ||||
|     if (self_k_idxs) { | ||||
|         mctx->set_input_k_idxs(self_k_idxs, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_v_idxs) { | ||||
|         mctx->set_input_v_idxs(self_v_idxs, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_kq_mask) { | ||||
| @@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { | ||||
| } | ||||
|  | ||||
| void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { | ||||
|     if (self_kv_idxs) { | ||||
|         mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch); | ||||
|     if (self_k_idxs) { | ||||
|         mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_kv_idxs_swa) { | ||||
|         mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch); | ||||
|     if (self_v_idxs) { | ||||
|         mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_k_idxs_swa) { | ||||
|         mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_v_idxs_swa) { | ||||
|         mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); | ||||
|     } | ||||
|  | ||||
|     if (self_kq_mask) { | ||||
| @@ -1209,8 +1221,8 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() | ||||
|         const auto n_kv   = mctx_cur->get_n_kv(); | ||||
|         const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1; | ||||
|  | ||||
|         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); | ||||
|         ggml_set_input(inp->self_kv_idxs); | ||||
|         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); | ||||
|         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); | ||||
|  | ||||
|         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); | ||||
|         ggml_set_input(inp->self_kq_mask); | ||||
| @@ -1243,10 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn( | ||||
|  | ||||
|     // store to KV cache | ||||
|     { | ||||
|         const auto & kv_idxs = inp->get_kv_idxs(); | ||||
|         const auto & k_idxs = inp->get_k_idxs(); | ||||
|         const auto & v_idxs = inp->get_v_idxs(); | ||||
|  | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); | ||||
|     } | ||||
|  | ||||
|     const auto & kq_mask = inp->get_kq_mask(); | ||||
| @@ -1299,10 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn( | ||||
|  | ||||
|     // store to KV cache | ||||
|     { | ||||
|         const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs(); | ||||
|         const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); | ||||
|         const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs(); | ||||
|  | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); | ||||
|         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); | ||||
|     } | ||||
|  | ||||
|     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); | ||||
| @@ -1444,8 +1458,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif | ||||
|     { | ||||
|         const auto n_kv = mctx_cur->get_base()->get_n_kv(); | ||||
|  | ||||
|         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); | ||||
|         ggml_set_input(inp->self_kv_idxs); | ||||
|         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); | ||||
|         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); | ||||
|  | ||||
|         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); | ||||
|         ggml_set_input(inp->self_kq_mask); | ||||
| @@ -1458,8 +1472,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif | ||||
|  | ||||
|         const auto n_kv = mctx_cur->get_swa()->get_n_kv(); | ||||
|  | ||||
|         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); | ||||
|         ggml_set_input(inp->self_kv_idxs_swa); | ||||
|         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); | ||||
|         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); | ||||
|  | ||||
|         inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); | ||||
|         ggml_set_input(inp->self_kq_mask_swa); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov