diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 8556c26825..e2a8d1f56b 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -63,6 +63,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
 @dataclass
 class ModelTensorInfo:
     load: Callable[[], Tensor]
+    size: int  # in elements
     src_type: str
     src_qtype: gguf.GGMLQuantizationType | None = None
     dst_qtype: gguf.GGMLQuantizationType | None = None
@@ -76,6 +77,7 @@ class ModelBase:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
+    ftype_guessed: bool
     fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
@@ -116,6 +118,7 @@ class ModelBase:
 
         self.dir_model = dir_model
         self.ftype = ftype
+        self.ftype_guessed = ftype == gguf.LlamaFileType.GUESSED
        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
@@ -132,6 +135,34 @@ class ModelBase:
 
         self.dequant_model()
 
+        if self.ftype == gguf.LlamaFileType.GUESSED:
+            # find out the most common type
+            hist: dict[gguf.GGMLQuantizationType, int] = {}
+            for t in self.model_tensors.values():
+                if t.dst_qtype is not None:
+                    if t.dst_qtype not in hist:
+                        hist[t.dst_qtype] = 0
+                    hist[t.dst_qtype] += t.size
+            max_qtype = gguf.GGMLQuantizationType.F32
+            max_size = 0
+            for qtype, size in hist.items():
+                if size > max_size:
+                    max_qtype = qtype
+                    max_size = size
+            # TODO: add more type if they're used as dst_qtypes
+            if max_qtype == gguf.GGMLQuantizationType.F32:
+                self.ftype = gguf.LlamaFileType.ALL_F32
+            elif max_qtype == gguf.GGMLQuantizationType.F16:
+                self.ftype = gguf.LlamaFileType.MOSTLY_F16
+            elif max_qtype == gguf.GGMLQuantizationType.BF16:
+                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+            elif max_qtype == gguf.GGMLQuantizationType.Q8_0:
+                self.ftype = gguf.LlamaFileType.MOSTLY_Q8_0
+            elif max_qtype == gguf.GGMLQuantizationType.Q4_1:
+                self.ftype = gguf.LlamaFileType.MOSTLY_Q4_1
+            elif max_qtype == gguf.GGMLQuantizationType.TQ1_0:
+                self.ftype = gguf.LlamaFileType.MOSTLY_TQ1_0
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
@@ -167,6 +198,7 @@ class ModelBase:
                 qtype = LazyTorchTensor._qtype_map.get(dtype)
                 tensors[name] = ModelTensorInfo(
                     load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
+                    size=math.prod(remote_tensor.shape),
                     src_type=str(dtype),
                     src_qtype=qtype,
                     dst_qtype=qtype,
@@ -215,12 +247,14 @@ class ModelBase:
                     if is_safetensors:
                         data: gguf.utility.LocalTensor = model_part[name]
                         dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                        size = math.prod(data.shape)
                         if self.lazy:
                             data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                         else:
                             data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                     else:
                         data_torch: Tensor = model_part[name]
+                        size = data_torch.numel()
                         dtype = data_torch.dtype
                         if self.lazy:
                             data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
@@ -229,6 +263,7 @@ class ModelBase:
                     qtype = LazyTorchTensor._qtype_map.get(dtype)
                     tensors[name] = ModelTensorInfo(
                         load=data_gen,
+                        size=size,
                         src_type=str(dtype),
                         src_qtype=qtype,
                         dst_qtype=qtype,
@@ -333,6 +368,7 @@ class ModelBase:
                         s = self.model_tensors[name]
                         self.model_tensors[weight_name] = ModelTensorInfo(
                             load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
+                            size=w.size,
                             src_type="bitnet",
                             src_qtype=gguf.GGMLQuantizationType.F32,
                             dst_qtype=gguf.GGMLQuantizationType.TQ1_0,
@@ -346,6 +382,7 @@ class ModelBase:
                         s = self.model_tensors[name]
                         self.model_tensors[weight_name] = ModelTensorInfo(
                             load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
+                            size=w.size,
                             src_type=w.src_type,
                             src_qtype=gguf.GGMLQuantizationType.F32,
                             dst_qtype=gguf.GGMLQuantizationType.BF16,  # TODO: change to FP8 once natively supported
@@ -364,6 +401,7 @@ class ModelBase:
                             load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
                                 g.load(), w.load(), z.load(), s.load()
                             ),
+                            size=qweight.size,  # TODO: use more accurate value
                             src_type=f"GPTQ-{bits}bit",
                             src_qtype=gguf.GGMLQuantizationType.F32,
                             dst_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
@@ -530,7 +568,9 @@ class ModelBase:
 
                 # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
                 if isinstance(data_qtype, bool):
-                    if self.ftype == gguf.LlamaFileType.ALL_F32:
+                    if self.ftype_guessed:
+                        data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
+                    elif self.ftype == gguf.LlamaFileType.ALL_F32:
                         data_qtype = gguf.GGMLQuantizationType.F32
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                         data_qtype = gguf.GGMLQuantizationType.F16
@@ -542,8 +582,6 @@ class ModelBase:
                         data_qtype = gguf.GGMLQuantizationType.TQ1_0
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
                         data_qtype = gguf.GGMLQuantizationType.TQ2_0
-                    elif self.ftype == gguf.LlamaFileType.GUESSED:
-                        data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
                     else:
                         raise ValueError(f"Unknown file type: {self.ftype.name}")
 
@@ -707,23 +745,8 @@ class TextModel(ModelBase):
 
         total_params = self.gguf_writer.get_total_parameter_count()[0]
 
-        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
-        # TODO: get type name from `quantization_config` field when present?
-        if self.ftype == gguf.LlamaFileType.GUESSED:
-            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
-            _, first_tensor = next(self.get_tensors())
-            logger.info(f"first tensor type is {first_tensor.dtype}")
-            if first_tensor.dtype == torch.float16:
-                ftype = gguf.LlamaFileType.MOSTLY_F16
-            elif first_tensor.dtype == torch.bfloat16:
-                ftype = gguf.LlamaFileType.MOSTLY_BF16
-            else:
-                ftype = gguf.LlamaFileType.ALL_F32
-        else:
-            ftype = self.ftype
-
         # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
-        output_type: str = ftype.name.partition("_")[2]
+        output_type: str = self.ftype.name.partition("_")[2]
 
         # Filename Output
         if self.fname_out.is_dir():
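
For review context, the new GUESSED handling in `__init__` amounts to a size-weighted majority vote over each tensor's `dst_qtype`: whichever quantization type covers the most elements decides the output file type. Below is a minimal standalone sketch of that idea; the `TensorInfo` dataclass and `guess_file_type` helper are illustrative stand-ins, not part of the patch.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass


@dataclass
class TensorInfo:
    """Simplified stand-in for ModelTensorInfo (illustrative only)."""
    size: int              # number of elements in the tensor
    dst_qtype: str | None  # e.g. "F32", "F16", "BF16", "Q8_0"


def guess_file_type(tensors: list[TensorInfo], default: str = "F32") -> str:
    """Pick the dst_qtype covering the most elements, like the GUESSED branch in the diff."""
    hist: Counter[str] = Counter()
    for t in tensors:
        if t.dst_qtype is not None:
            hist[t.dst_qtype] += t.size
    return hist.most_common(1)[0][0] if hist else default


# A few large BF16 weights outweigh small F32 norm tensors,
# so the guessed type is BF16 (mapped to MOSTLY_BF16 in the real code).
print(guess_file_type([TensorInfo(4096 * 4096, "BF16"), TensorInfo(4096, "F32")]))  # -> BF16
```

In the patch itself the winning `GGMLQuantizationType` is then mapped onto the corresponding `gguf.LlamaFileType` (e.g. BF16 becomes MOSTLY_BF16), while the `ftype_guessed` flag is kept so the per-tensor code path can still fall back to each tensor's own `dst_qtype` (or `old_qtype`) instead of forcing a single file-wide type.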