mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	Merge branch 'master' into compilade/refactor-kv-cache
@@ -818,6 +818,21 @@ class TextModel(ModelBase):
        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
            res = "hunyuan"
        if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
            # ref: https://huggingface.co/skt/A.X-4.0
            res = "a.x-4.0"
        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
            res = "falcon-h1"
        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
            res = "falcon-h1"
        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
            res = "falcon-h1"
        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
            res = "falcon-h1"

        if res is None:
            logger.warning("\n")
@@ -4899,17 +4914,19 @@ class Mamba2Model(TextModel):
    def set_gguf_parameters(self):
        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
        d_conv  = self.find_hparam(["conv_kernel",       "d_conv"],  optional=True) or 4
        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
        d_state = self.find_hparam(["state_size",        "d_state"], optional=True) or 128
        head_dim = self.find_hparam(["head_dim"],                    optional=True) or 64
        head_dim = self.find_hparam(["mamba_d_head", "head_dim"],    optional=True) or 64
        n_group = self.find_hparam(["n_groups"],                     optional=True) or 1

        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        # TODO: does this really matter?
        assert d_inner == 2 * d_model
        assert d_inner % head_dim == 0
        # skip the assertion for FalconH1 Model
        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
            assert d_inner == 2 * d_model
            assert d_inner % head_dim == 0

        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
        self.gguf_writer.add_embedding_length(d_model)
@@ -4946,7 +4963,7 @@ class Mamba2Model(TextModel):
            data_torch = data_torch.reshape((*data_torch.shape, 1))
        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
            n_group = self.hparams.get("n_groups", 1)
            data_torch = data_torch.reshape((n_group, d_inner // n_group))

@@ -6656,6 +6673,113 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
    model_arch = gguf.MODEL_ARCH.FALCON_H1

    def __init__(self, *args, **kwargs):
        # Set the hparam prefixes for Falcon Mamba2
        self.hparam_prefixes = ["mamba"]

        # Initialize the base Mamba2Model
        super().__init__(*args, **kwargs)

        # Use Llama conversion for attention
        self._transformer_model_class = LlamaModel

        # n_group and d_inner are used when reshaping tensors for mamba2
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["mamba_d_ssm"])
        self.d_head = self.find_hparam(["d_head"])

        # Initialize any Falcon Mamba2 specific attributes
        self.has_attention = True  # Falcon Mamba2 has attention components

        # Load Falcon-H1 multipliers from hyperparameters
        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
        self.intermediate_size = self.find_hparam(["intermediate_size"])
        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
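        # the multipliers above are folded directly into the converted weights in modify_tensors() below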

    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        prefixed = []
        for pfx in self.hparam_prefixes:
            prefixed.extend(
                "_".join([pfx, k])
                for k in keys
            )
        keys = list(keys) + prefixed
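        # e.g. with hparam_prefixes = ["mamba"], find_hparam(["d_state"]) also tries "mamba_d_state"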
        return super().find_hparam(keys, *args, **kwargs)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        tensors = list(super().modify_tensors(data_torch, name, bid))
        tensor = tensors[0][1]

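        # apply the Falcon-H1 per-tensor multipliers: mlp_multipliers = [gate, down], plus scalars for the attention and SSM projections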
        if "down_proj" in name:
            tensor = tensor * self.mlp_multipliers[1]
        elif "gate_proj" in name:
            tensor = tensor * self.mlp_multipliers[0]
        elif "k_proj" in name:
            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
        elif "q_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "v_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "o_proj" in name:
            tensor = tensor * self.attention_out_multiplier
        elif "out_proj" in name:
            tensor = tensor * self.ssm_out_multiplier
        elif "in_proj" in name:
            tensor = tensor * self.ssm_in_multiplier
            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
            intermediate_size = self.hparams["mamba_d_ssm"]
            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
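            # the fused in_proj rows are laid out as [z | x | B | C | dt] (hence "zxbcdt"); scale each segment by its own multiplier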
            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
        elif "lm_head" in name:
            tensor = tensor * self.hparams["lm_head_multiplier"]
        elif "embed_tokens" in name:
            tensor = tensor * self.hparams["embedding_multiplier"]
        elif "mamba.norm" in name:
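            # grouped SSM norm weight is stored as (n_group, d_inner // n_group), matching the Mamba2 SSM_NORM layout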
            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)

        tensors = [(tensors[0][0], tensor)]
        return tensors

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        ## General Params ##
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        # Override some Mamba2 defaults
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        ## Attention params ##
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # Override value 0 from Mamba2
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_key_length(self.hparams["head_dim"])
        self.gguf_writer.add_value_length(self.hparams["head_dim"])

        ## Validation ##
        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"

        # Add any other Falcon Mamba2 specific configuration
        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
@@ -6809,6 +6933,16 @@ class HunYuanMoEModel(TextModel):
class SmolLM3Model(LlamaModel):
    model_arch = gguf.MODEL_ARCH.SMOLLM3

    def set_vocab(self):
        super().set_vocab()
        # remove unsupported array slicing in chat template
        # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
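        # "[:]" is a full-copy slice, so dropping it leaves the template's meaning unchanged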
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        if tokenizer.chat_template is not None:
            chat_template = tokenizer.chat_template.replace("[:]", "")
            self.gguf_writer.add_chat_template(chat_template)

###### CONVERSION LOGIC ######
