	llama : fix Gemma3 SWA KV cache shift (#12373)
* llama : fix Gemma3 SWA KV cache shift (ggml-ci)
* hparams : add comment [no ci]
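The refactor below replaces the per-architecture switch in build_attn with a single n_swa_pattern hyperparameter: within every group of n_swa_pattern consecutive layers, the first n_swa_pattern - 1 use sliding-window attention (SWA) and the last one uses global attention. A minimal standalone sketch of that modulo test (illustrative only, not the library code; the pattern values 2/4/6 are the ones this commit assigns to Gemma2, Cohere2 and Gemma3, and the window size below is just an example):

#include <cstdint>
#include <cstdio>

// same modulo test as llama_hparams::is_sliding, reproduced here for illustration
static bool is_sliding_layer(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
    return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}

int main() {
    // Gemma3: n_swa_pattern = 6 -> layers 0..4 sliding, layer 5 global, repeating
    for (uint32_t il = 0; il < 12; ++il) {
        printf("layer %2u: %s\n", il,
               is_sliding_layer(il, /*n_swa=*/1024, /*n_swa_pattern=*/6) ? "sliding" : "global");
    }
    return 0;
}

Note that with the default n_swa_pattern = 1, il % 1 is always 0 and 0 < 0 is false, so every layer uses global attention unless an architecture overrides the pattern in load_hparams.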
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
 
         ggml_build_forward_expand(gf, cur);
     }
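For context on the hunk above: Gemma3's sliding-window layers apply RoPE with a base frequency of 10000.0 while its global layers use the model's configured rope_freq_base, so shifting every layer's cached keys with the single model-wide base would rotate the sliding layers by the wrong angles. A small illustrative calculation (not llama.cpp code; the head dimension, frequency index and the 1e6 global base are example values) showing how the two shift angles diverge:

#include <cmath>
#include <cstdio>

// RoPE shifts a key at frequency index i by an angle delta * theta_i,
// where theta_i = freq_base^(-2*i/d) (ignoring freq_scale and YaRN here).
// If the cache was written with one base but shifted with another, the
// angles disagree and the shifted keys no longer match their positions.
int main() {
    const int   d     = 128;   // head dimension (example value)
    const int   i     = 16;    // frequency index (example value)
    const float delta = 32.0f; // how far positions are shifted

    const float theta_swa    = std::pow(10000.0f,   -2.0f * i / d); // Gemma3 sliding layers
    const float theta_global = std::pow(1000000.0f, -2.0f * i / d); // example global base

    printf("shift angle with SWA base    : %.6f rad\n", delta * theta_swa);
    printf("shift angle with global base : %.6f rad\n", delta * theta_global);
    return 0;
}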
@@ -168,6 +168,8 @@ private:
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const;
 
     llm_graph_result_ptr build_kv_self_shift(
@@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    // TODO: improve
-    bool is_sliding = false;
-
-    switch (arch) {
-        case LLM_ARCH_COHERE2:
-            {
-                const int32_t sliding_window_pattern = 4;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                const int32_t sliding_window_pattern = 2;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA3:
-            {
-                const int32_t sliding_window_pattern = 6;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                is_sliding = hparams.n_swa > 0;
-            } break;
-        default:
-            {
-                is_sliding = false;
-            }
-    };
+    const bool is_sliding = hparams.is_sliding(il);
 
     const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
+
+bool llama_hparams::is_sliding(uint32_t il) const {
+    if (il < n_layer) {
+        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+    }
+
+    GGML_ABORT("fatal error");
+}
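As a sanity check on the helper above: for the architectures whose hard-coded cases were removed from build_attn (Gemma2 with pattern 2, Cohere2 with pattern 4, Gemma3 with pattern 6), the n_swa_pattern test selects the same layers as the old per-architecture rule whenever n_swa is non-zero. A hypothetical equivalence check, not part of the repository:

#include <cassert>
#include <cstdint>

// old rule: hard-coded per-architecture pattern in llm_graph_context::build_attn
static bool old_rule(uint32_t il, uint32_t pattern) {
    return il % pattern < (pattern - 1);
}

// new rule: the expression used by llama_hparams::is_sliding
static bool new_rule(uint32_t il, uint32_t n_swa, uint32_t pattern) {
    return n_swa > 0 && pattern > 0 && il % pattern < (pattern - 1);
}

int main() {
    const uint32_t patterns[] = {2, 4, 6}; // Gemma2, Cohere2, Gemma3
    for (uint32_t p : patterns) {
        for (uint32_t il = 0; il < 64; ++il) {
            // equivalent as long as a sliding window is actually configured
            assert(old_rule(il, p) == new_rule(il, /*n_swa=*/4096, p));
        }
    }
    return 0;
}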
@@ -36,6 +36,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -133,6 +134,8 @@ struct llama_hparams {
 
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
+
+    bool is_sliding(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@@ -858,11 +858,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA2:
             {
                 hparams.n_swa = 4096; // default value of gemma 2
+                hparams.n_swa_pattern = 2;
+                hparams.attn_soft_cap = true;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_2B; break;
@@ -873,6 +875,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.n_swa_pattern = 6;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -952,6 +956,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
+                hparams.n_swa_pattern = 4;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
@@ -7374,12 +7380,8 @@ struct llm_build_gemma3 : public llm_graph_context {
         // TODO: is causal == true correct? might need some changes
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // "5-to-1 interleaved attention"
-        // 5 layers of local attention followed by 1 layer of global attention
-        static const int sliding_window_pattern = 6;
-
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
             const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
@@ -7970,13 +7972,8 @@ struct llm_build_cohere2 : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
         for (int il = 0; il < n_layer; ++il) {
-            // three layers sliding window attention (window size 4096) and ROPE
-            // fourth layer uses global attention without positional embeddings
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);