	graph : support iSWA virtual sequences
ggml-ci
Author: Georgi Gerganov
@@ -1001,10 +1001,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
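Every hunk in this commit applies the same transformation: the KQ mask goes from a single 2D tensor of shape n_kv x GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) covering the whole ubatch to a 3D tensor with one padded slice of n_tokens/n_seqs rows per virtual sequence. A minimal sketch of the resulting shape arithmetic; GGML_PAD mirrors the round-up-to-a-multiple macro from ggml.h, while the batch sizes and the pad value standing in for GGML_KQ_MASK_PAD are made up for illustration:

```cpp
#include <cstdint>
#include <cstdio>

// round x up to the next multiple of n (mirrors the GGML_PAD macro in ggml.h)
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int64_t n_kv     = 4096; // assumed KV cache cell count
    const int64_t n_tokens = 96;   // assumed ubatch size
    const int64_t n_seqs   = 4;    // ubatch.n_seqs_unq when n_seq_virt > 1
    const int64_t kq_pad   = 64;   // stand-in value for GGML_KQ_MASK_PAD

    // old layout: one 2D mask covering the whole ubatch
    printf("2D mask: %lld x %lld\n",
           (long long) n_kv, (long long) GGML_PAD(n_tokens, kq_pad));

    // new layout: one padded slice of n_tokens/n_seqs rows per sequence
    printf("3D mask: %lld x %lld x %lld\n",
           (long long) n_kv,
           (long long) GGML_PAD(n_tokens/n_seqs, kq_pad),
           (long long) n_seqs);
    return 0;
}
```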
@@ -1206,14 +1206,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
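With one mask slice per sequence, the host-side fill writes each sequence's causal visibility into its own slice, and rows beyond n_tokens/n_seqs stay fully masked so the padding contributes nothing to the attention scores. A simplified, self-contained sketch of such a fill; the per-cell position/sequence arrays are assumptions standing in for the real llama_kv_cache bookkeeping:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: fill a [n_kv, n_padded, n_seqs] F32 mask, where slice s belongs to
// sequence s. 0.0f = KV cell visible, -INFINITY = KV cell masked out.
static void fill_kq_mask(std::vector<float> & mask,
        int64_t n_kv, int64_t n_padded, int64_t n_seqs, int64_t n_tokens,
        const std::vector<int32_t> & cell_pos,   // assumed: position stored in each KV cell
        const std::vector<int32_t> & cell_seq,   // assumed: sequence that owns each KV cell
        const std::vector<int32_t> & tok_pos) {  // assumed: position of each ubatch token
    const int64_t n_tps = n_tokens/n_seqs; // tokens per sequence in the ubatch
    mask.assign(n_kv*n_padded*n_seqs, -INFINITY);
    for (int64_t s = 0; s < n_seqs; ++s) {
        for (int64_t i = 0; i < n_tps; ++i) { // rows in [n_tps, n_padded) stay masked
            const int64_t t = s*n_tps + i;    // tokens are grouped by sequence
            for (int64_t j = 0; j < n_kv; ++j) {
                // a cell is visible only to its own sequence, and only causally
                if (cell_seq[j] == (int32_t) s && cell_pos[j] <= tok_pos[t]) {
                    mask[(s*n_padded + i)*n_kv + j] = 0.0f;
                }
            }
        }
    }
}
```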
@@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs_swa);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
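The last line of each hunk, unchanged by this commit, is the flash-attention conversion: the FLASH_ATTN_EXT path consumes the mask as F16, so the F32 input tensor is cast when cparams.flash_attn is enabled. A standalone sketch of that allocation pattern using the public ggml API (ggml_new_tensor_3d, ggml_set_input, and ggml_cast are the real calls; the sizes and the flash_attn flag here are assumptions):

```cpp
#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t n_kv = 4096, n_rows = 64, n_seqs = 4; // assumed sizes
    const bool flash_attn = true;                       // stand-in for cparams.flash_attn

    // one n_kv x n_rows mask slice per virtual sequence, filled by the host
    ggml_tensor * kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_rows, n_seqs);
    ggml_set_input(kq_mask);

    // the flash-attention kernels expect the mask in F16, hence the cast
    ggml_tensor * kq_mask_cnv = flash_attn ? ggml_cast(ctx0, kq_mask, GGML_TYPE_F16)
                                           : kq_mask;
    (void) kq_mask_cnv;

    ggml_free(ctx0);
    return 0;
}
```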