mirror of https://github.com/ggml-org/llama.cpp.git

	llama : fix KV shift for qwen2vl (#13870)
* llama : fix KV shift for qwen2vl
* add ref to the PR
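For context: Qwen2-VL uses M-RoPE (multimodal rotary position embedding), which splits each head's d/2 rotation pairs into sections and drives each section with a different position component (temporal, height, width, plus a spare channel). A KV shift re-rotates cached keys by a single scalar position delta, so the sectioned rotation cannot be reused for it directly. The following standalone sketch approximates the sectioned rotation just to make the problem concrete; the function name, section layout, and default base are illustrative assumptions, not the ggml kernel.

#include <array>
#include <cmath>
#include <vector>

// Illustrative M-RoPE sketch (NOT the ggml implementation): the d/2
// rotation pairs are split into sections, and each section reads a
// different position component (temporal, height, width, extra).
static void mrope_neox(std::vector<float> & x,
                       const std::array<int, 4> & sections, // number of pairs per section
                       const std::array<float, 4> & pos,    // t, h, w, extra positions
                       float base = 10000.0f) {
    const size_t d = x.size();
    size_t i = 0;
    for (int s = 0; s < 4; ++s) {
        for (int k = 0; k < sections[s]; ++k, ++i) {
            // NEOX-style pairing: dim i rotates together with dim i + d/2
            const float theta = pos[s] * std::pow(base, -2.0f*float(i)/float(d));
            const float c  = std::cos(theta);
            const float sn = std::sin(theta);
            const float x0 = x[i];
            const float x1 = x[i + d/2];
            x[i]       = x0*c  - x1*sn;
            x[i + d/2] = x0*sn + x1*c;
        }
    }
}

int main() {
    std::vector<float> k(8, 1.0f);                     // one 8-dim head vector
    mrope_neox(k, {2, 1, 1, 0}, {3.f, 1.f, 2.f, 0.f}); // 4 pairs total = d/2
    return 0;
}

Because each section is rotated with a different position component, shifting every cached position by one delta calls for rotating the whole vector uniformly, which is what the second hunk below arranges by falling back to a plain NEOX RoPE when building the shift graph.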
src/llama-graph.cpp:

@@ -455,7 +455,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     }
 
 int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
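This hunk generalizes the check: rather than special-casing LLM_ARCH_QWEN2VL, any model whose hyperparameters declare LLAMA_ROPE_TYPE_MROPE now reports four position components per embedding. A minimal sketch of the buffer shape this implies, assuming the four components are stored as consecutive sections (the struct and layout here are hypothetical, not llama.cpp API):

#include <cstdint>
#include <vector>

// Hypothetical sketch of a position buffer sized by n_pos_per_embd():
// M-RoPE tokens carry 4 position components, stored here as four
// back-to-back sections of n_tokens entries each.
struct pos_buffer {
    int64_t n_tokens;
    int64_t n_pos_per_embd;    // 4 for M-RoPE, 1 otherwise
    std::vector<int32_t> data; // n_tokens * n_pos_per_embd entries

    // position component c (0 = temporal) of token i
    int32_t pos(int64_t i, int64_t c) const {
        return data[c * n_tokens + i];
    }
};

int main() {
    pos_buffer p{2, 4, {10, 11,  0, 0,  0, 1,  0, 0}}; // t, h, w, extra sections
    return p.pos(1, 0) == 11 ? 0 : 1;
}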
src/llama-kv-cache.cpp:

@@ -757,11 +757,19 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
     const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
 
     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+                                // @ngxson : this is a workaround
+                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
+                                // a normal RoPE should work, we just need to use the correct ordering
+                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
+                                ? LLAMA_ROPE_TYPE_NEOX
+                                : hparams.rope_type;
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
+                                    : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;
 
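Why a plain NEOX-ordered RoPE is enough for the shift itself: for a fixed frequency, rotating by position p and then by a delta d is identical to rotating once by p + d, so already-roped keys can be shifted by applying RoPE with the delta alone. A small self-contained check of that identity (illustrative code, not the ggml graph):

#include <cmath>
#include <cstdio>
#include <vector>

// NEOX-style RoPE on one head vector: dim i pairs with dim i + d/2,
// with angle theta_i = pos * base^(-2i/d). Illustrative only.
static void rope_neox(std::vector<float> & x, float pos, float base = 10000.0f) {
    const size_t d = x.size();
    for (size_t i = 0; i < d/2; ++i) {
        const float theta = pos * std::pow(base, -2.0f*float(i)/float(d));
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + d/2];
        x[i]       = x0*c - x1*s;
        x[i + d/2] = x0*s + x1*c;
    }
}

int main() {
    std::vector<float> a = {0.3f, -1.2f, 0.7f, 0.05f, 1.1f, -0.4f, 0.9f, 0.2f};
    std::vector<float> b = a;

    rope_neox(a, 100.0f); // rope at the original position
    rope_neox(a, -15.0f); // KV shift: rotate the cached vector by the delta only

    rope_neox(b,  85.0f); // roping once at the shifted position...

    for (size_t i = 0; i < a.size(); ++i) {
        std::printf("%+.6f %+.6f\n", a[i], b[i]); // ...prints matching columns
    }
    return 0;
}

This additivity is why the workaround builds the shift with LLAMA_ROPE_TYPE_NEOX: per the PR comment, the whole vector is rotated uniformly by the delta, and only the dimension-pair ordering has to match.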
Author: Xuan-Son Nguyen