mirror of https://github.com/ggml-org/llama.cpp.git
gemma : more consistent attention scaling for v2 and v3 (#13951)

* gemma : fix attn scale for 27B
* cont : apply scale before attn
* cont : consistent attention scaling
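The change matters only for Gemma-2 27B, where the reference query pre-attention scalar is n_embd / n_head rather than the per-head K width n_embd_head_k. Below is a minimal standalone sketch of the two formulas used in the diff, plugging in the 27B shape from the referenced gemma_pytorch config (4608-wide embeddings, 32 heads, 128-wide K heads; these constants are assumptions for illustration, not values read from a GGUF):

// Sketch only: evaluates the two scale formulas from the diff with an assumed
// Gemma-2 27B shape. Not part of llama.cpp; plain C++ for illustration.
#include <cmath>
#include <cstdio>

int main() {
    const int n_embd        = 4608; // assumed 27B embedding width
    const int n_head        = 32;   // assumed 27B attention head count
    const int n_embd_head_k = 128;  // assumed 27B per-head K width

    // 27B path: 1/sqrt(n_embd / n_head) = 1/sqrt(144)
    const float scale_27b   = 1.0f / std::sqrt(float(n_embd / n_head));
    // all other Gemma-2 sizes: 1/sqrt(head_dim) = 1/sqrt(128)
    const float scale_other = 1.0f / std::sqrt(float(n_embd_head_k));

    std::printf("27B scale    : %.6f (1/sqrt(%d))\n", scale_27b,   n_embd / n_head);
    std::printf("default scale: %.6f (1/sqrt(%d))\n", scale_other, n_embd_head_k);
    return 0;
}

Before this commit the gemma2 graph scaled Q by 1/sqrt(n_embd_head) for every size, so 27B got 1/sqrt(128) instead of 1/sqrt(144) (assuming the 128-wide heads above), which is the inconsistency the first bullet refers to.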
@@ -956,6 +956,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -976,6 +981,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -8484,14 +8490,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8632,9 +8631,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             cur = build_norm(cur,
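For Gemma-3 the edit is a reordering rather than a new value: the query is pre-scaled by hparams.f_attention_scale and build_attn now receives a kq_scale of 1.0f, which produces the same logits because the scale is a scalar factor on the Q.K dot product. A small self-contained check of that identity (plain arithmetic, no ggml; the vectors and the 1/sqrt(144) value are made up for illustration):

// Illustration only: scaling Q before the dot product equals scaling the logit.
#include <cmath>
#include <cstdio>

int main() {
    const float q[4]  = {0.5f, -1.0f, 2.0f, 0.25f};
    const float k[4]  = {1.5f,  0.5f, -0.75f, 2.0f};
    const float scale = 1.0f / std::sqrt(144.0f); // example pre-attention scale

    float logit_old = 0.0f; // old order: dot product first, then kq_scale
    float logit_new = 0.0f; // new order: Q pre-scaled, kq_scale = 1.0f
    for (int i = 0; i < 4; ++i) {
        logit_old += q[i] * k[i];
        logit_new += (q[i] * scale) * k[i];
    }
    logit_old *= scale;

    std::printf("scale after dot : %.6f\n", logit_old);
    std::printf("scale before dot: %.6f\n", logit_new);
    return 0;
}

The two prints agree up to floating-point rounding; the gemma3 hunk moves the multiplication for consistency with the gemma2 path rather than to change the math.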