Mirror of https://github.com/ggml-org/llama.cpp.git
convert : support rope_scaling type and rope_type (#13349)
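
Author: Sigbjørn Skjæret

Hugging Face configs can carry the RoPE scaling type under either the legacy rope_scaling["type"] key or the newer rope_scaling["rope_type"] key; before this change, most model classes in the HF-to-GGUF converter (convert_hf_to_gguf.py) only checked "type". The commit normalizes the lookup: each affected class caches the rope_scaling dict once (defaulting to an empty dict) and reads "rope_type" with a fallback to "type" before writing the GGUF rope-scaling metadata. A minimal, self-contained sketch of the shared pattern, using a hypothetical config dict rather than the converter's own hparams:

    # Sketch of the lookup pattern introduced by this commit (hypothetical
    # config dict; the real code operates on each model's self.hparams).
    hparams = {"rope_scaling": {"rope_type": "linear", "factor": 2.0}}

    rope_scaling = hparams.get("rope_scaling") or {}  # {} when the key is missing or null
    # Prefer the newer "rope_type" key, fall back to the legacy "type" key.
    scaling_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
    if scaling_type == "linear" and "factor" in rope_scaling:
        print("linear rope scaling, factor =", rope_scaling["factor"])

The same pattern is applied below for the linear, yarn, and longrope cases across the BaichuanModel, XverseModel, LlamaModel, DeciModel, MiniCPMModel, Qwen2Model, Qwen2MoeModel, Phi3MiniModel, InternLM2Model, InternLM3Model, DeepseekV2Model, Glm4Model, ExaoneModel, and BailingMoeModel classes.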
@@ -1388,10 +1388,10 @@ class BaichuanModel(TextModel):
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         head_count = self.hparams["num_attention_heads"]
@@ -1512,10 +1512,10 @@ class XverseModel(TextModel):
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
@@ -1828,10 +1828,10 @@ class LlamaModel(TextModel):
             rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
         self.gguf_writer.add_rope_dimension_count(rope_dim)
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
@@ -2206,10 +2206,10 @@ class DeciModel(TextModel):
             rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
         self.gguf_writer.add_rope_dimension_count(rope_dim)
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
@@ -2449,8 +2449,8 @@ class MiniCPMModel(TextModel):
         logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
         self.gguf_writer.add_logit_scale(logit_scale)
         logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
-        if self.hparams.get("rope_scaling") is not None:
-            if self.hparams["rope_scaling"].get("type") == "longrope":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
-                logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "longrope":
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
+            logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
 
@@ -2597,11 +2597,11 @@ class Qwen2Model(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self._try_set_pooling_type()
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "yarn":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if self.hf_arch == "Qwen2Model":
@@ -2763,11 +2763,11 @@ class Qwen2MoeModel(TextModel):
             logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
         # YaRN is not enabled by default
         # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "yarn":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -3035,7 +3035,7 @@ class Phi3MiniModel(TextModel):
 
         scale = max_pos_embds / orig_max_pos_embds
 
-        rope_scaling_type = rope_scaling.get('type', '').lower()
+        rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
         if len(rope_scaling_type) == 0:
             raise KeyError('Missing the required key rope_scaling.type')
 
@@ -3347,10 +3347,10 @@ class InternLM2Model(TextModel):
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
         self.gguf_writer.add_file_type(self.ftype)
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
@@ -3425,10 +3425,10 @@ class InternLM3Model(TextModel):
             rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
         self.gguf_writer.add_rope_dimension_count(rope_dim)
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
@@ -4866,12 +4866,12 @@ class DeepseekV2Model(TextModel):
 
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "yarn":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
-                self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -5363,11 +5363,11 @@ class Glm4Model(TextModel):
         super().set_gguf_parameters()
         rope_dim = self.hparams["head_dim"]
         self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "yarn":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
 
 
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
@@ -5600,10 +5600,10 @@ class ExaoneModel(TextModel):
         rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
         rotary_factor = rotary_factor if rotary_factor is not None else 1.0
         self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
-        if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
-            if hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
@@ -5706,10 +5706,11 @@ class BailingMoeModel(TextModel):
         rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]
 
         self.gguf_writer.add_rope_dimension_count(rope_dim)
-        if (self.hparams.get("rope_scaling") or {}).get("type") == "yarn" and "factor" in self.hparams["rope_scaling"]:
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
         else:
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
         self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])