convert : more robust default ftype detection

@@ -63,6 +63,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@dataclass
class ModelTensorInfo:
    load: Callable[[], Tensor]
    size: int  # in elements
    src_type: str
    src_qtype: gguf.GGMLQuantizationType | None = None
    dst_qtype: gguf.GGMLQuantizationType | None = None
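
To make the new fields concrete, here is a minimal, self-contained sketch (not part of the commit) of how a ModelTensorInfo-style record might be filled in: `load` defers reading the tensor data until it is needed, `size` is the element count later used to weight the ftype histogram, and `src_qtype`/`dst_qtype` record the source encoding and the encoding the converter intends to write. `TensorInfoSketch` and the example tensor are illustrative names only.

    from __future__ import annotations
    from dataclasses import dataclass
    from typing import Callable

    import gguf
    import torch
    from torch import Tensor

    @dataclass
    class TensorInfoSketch:
        load: Callable[[], Tensor]   # deferred loader; nothing is read until called
        size: int                    # element count, used to weight the ftype histogram
        src_type: str                # human-readable source dtype, e.g. "torch.bfloat16"
        src_qtype: gguf.GGMLQuantizationType | None = None
        dst_qtype: gguf.GGMLQuantizationType | None = None

    t = torch.zeros(4, 8, dtype=torch.bfloat16)
    info = TensorInfoSketch(
        load=lambda t=t: t,          # default argument captures the tensor at definition time
        size=t.numel(),
        src_type=str(t.dtype),
        src_qtype=gguf.GGMLQuantizationType.BF16,
        dst_qtype=gguf.GGMLQuantizationType.BF16,
    )
    print(info.size, info.dst_qtype.name)  # 32 BF16
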
@@ -76,6 +77,7 @@ class ModelBase:

    dir_model: Path
    ftype: gguf.LlamaFileType
    ftype_guessed: bool
    fname_out: Path
    is_big_endian: bool
    endianess: gguf.GGUFEndian
@@ -116,6 +118,7 @@ class ModelBase:

        self.dir_model = dir_model
        self.ftype = ftype
        self.ftype_guessed = ftype == gguf.LlamaFileType.GUESSED
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
@@ -132,6 +135,34 @@ class ModelBase:

        self.dequant_model()

        if self.ftype == gguf.LlamaFileType.GUESSED:
            # find out the most common type
            hist: dict[gguf.GGMLQuantizationType, int] = {}
            for t in self.model_tensors.values():
                if t.dst_qtype is not None:
                    if t.dst_qtype not in hist:
                        hist[t.dst_qtype] = 0
                    hist[t.dst_qtype] += t.size
            max_qtype = gguf.GGMLQuantizationType.F32
            max_size = 0
            for qtype, size in hist.items():
                if size > max_size:
                    max_qtype = qtype
                    max_size = size
            # TODO: add more types if they're used as dst_qtypes
            if max_qtype == gguf.GGMLQuantizationType.F32:
                self.ftype = gguf.LlamaFileType.ALL_F32
            elif max_qtype == gguf.GGMLQuantizationType.F16:
                self.ftype = gguf.LlamaFileType.MOSTLY_F16
            elif max_qtype == gguf.GGMLQuantizationType.BF16:
                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
            elif max_qtype == gguf.GGMLQuantizationType.Q8_0:
                self.ftype = gguf.LlamaFileType.MOSTLY_Q8_0
            elif max_qtype == gguf.GGMLQuantizationType.Q4_1:
                self.ftype = gguf.LlamaFileType.MOSTLY_Q4_1
            elif max_qtype == gguf.GGMLQuantizationType.TQ1_0:
                self.ftype = gguf.LlamaFileType.MOSTLY_TQ1_0

        # Configure GGUF Writer
        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
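
The core of the new detection is a size-weighted vote over each tensor's dst_qtype. Below is a standalone sketch of that selection (with made-up tensor sizes, not taken from the commit): the qtype covering the most elements wins, so a handful of small F32 norm tensors cannot flip the guess away from the dominant weight encoding.

    import gguf

    # (dst_qtype, size in elements) -- illustrative numbers only
    tensors = [
        (gguf.GGMLQuantizationType.F32, 1_000),        # e.g. small norm weights
        (gguf.GGMLQuantizationType.BF16, 7_000_000),   # e.g. attention weights
        (gguf.GGMLQuantizationType.BF16, 5_000_000),   # e.g. feed-forward weights
    ]

    hist: dict[gguf.GGMLQuantizationType, int] = {}
    for qtype, size in tensors:
        hist[qtype] = hist.get(qtype, 0) + size

    max_qtype = max(hist, key=lambda q: hist[q], default=gguf.GGMLQuantizationType.F32)
    print(max_qtype.name)  # BF16 -> the guessed ftype would be MOSTLY_BF16
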
@@ -167,6 +198,7 @@ class ModelBase:
                qtype = LazyTorchTensor._qtype_map.get(dtype)
                tensors[name] = ModelTensorInfo(
                    load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
                    size=math.prod(remote_tensor.shape),
                    src_type=str(dtype),
                    src_qtype=qtype,
                    dst_qtype=qtype,
@@ -215,12 +247,14 @@ class ModelBase:
                if is_safetensors:
                    data: gguf.utility.LocalTensor = model_part[name]
                    dtype = LazyTorchTensor._dtype_str_map[data.dtype]
                    size = math.prod(data.shape)
                    if self.lazy:
                        data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                    else:
                        data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                else:
                    data_torch: Tensor = model_part[name]
                    size = data_torch.numel()
                    dtype = data_torch.dtype
                    if self.lazy:
                        data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
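
A side note on the `lambda data=data: ...` pattern used for the deferred loaders above: binding the loop variable as a default argument freezes its current value, whereas a plain closure would see only the last value of the loop. A quick illustration (not from the converter):

    loaders_wrong = [lambda: name for name in ("a", "b", "c")]
    loaders_right = [lambda name=name: name for name in ("a", "b", "c")]

    print([f() for f in loaders_wrong])  # ['c', 'c', 'c'] -- late binding, all share the last value
    print([f() for f in loaders_right])  # ['a', 'b', 'c'] -- value captured per iteration

The `# noqa: E731` comments simply silence flake8's "do not assign a lambda" warning, which is deliberate here.
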
@@ -229,6 +263,7 @@ class ModelBase:
                qtype = LazyTorchTensor._qtype_map.get(dtype)
                tensors[name] = ModelTensorInfo(
                    load=data_gen,
                    size=size,
                    src_type=str(dtype),
                    src_qtype=qtype,
                    dst_qtype=qtype,
@@ -333,6 +368,7 @@ class ModelBase:
                s = self.model_tensors[name]
                self.model_tensors[weight_name] = ModelTensorInfo(
                    load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
                    size=w.size,
                    src_type="bitnet",
                    src_qtype=gguf.GGMLQuantizationType.F32,
                    dst_qtype=gguf.GGMLQuantizationType.TQ1_0,
@@ -346,6 +382,7 @@ class ModelBase:
                s = self.model_tensors[name]
                self.model_tensors[weight_name] = ModelTensorInfo(
                    load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
                    size=w.size,
                    src_type=w.src_type,
                    src_qtype=gguf.GGMLQuantizationType.F32,
                    dst_qtype=gguf.GGMLQuantizationType.BF16,  # TODO: change to FP8 once natively supported
@@ -364,6 +401,7 @@ class ModelBase:
                    load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
                        g.load(), w.load(), z.load(), s.load()
                    ),
                    size=qweight.size,  # TODO: use more accurate value
                    src_type=f"GPTQ-{bits}bit",
                    src_qtype=gguf.GGMLQuantizationType.F32,
                    dst_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
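
The dequant entries above all follow the same shape: the replacement tensor's `load` composes the deferred loads of the raw quantized pieces, so nothing is read or dequantized until the output tensor is actually written. A rough standalone analogue is sketched below; `fake_dequant` is a placeholder and does not reproduce the real dequant_bitnet/dequant_simple/dequant_gptq math.

    import torch

    def fake_dequant(w: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        # placeholder math only: scale the raw integer weights back to float
        return w.to(torch.float32) * s

    load_w = lambda: torch.randint(-8, 8, (2, 4), dtype=torch.int8)   # pretend quantized-weight loader
    load_s = lambda: torch.tensor(0.5)                                # pretend per-tensor scale loader
    load_dequant = lambda w=load_w, s=load_s: fake_dequant(w(), s())  # composed loader, still lazy

    out = load_dequant()
    print(out.dtype, tuple(out.shape))  # torch.float32 (2, 4)
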
@@ -530,7 +568,9 @@ class ModelBase:

            # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
            if isinstance(data_qtype, bool):
                if self.ftype == gguf.LlamaFileType.ALL_F32:
                if self.ftype_guessed:
                    data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
                elif self.ftype == gguf.LlamaFileType.ALL_F32:
                    data_qtype = gguf.GGMLQuantizationType.F32
                elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                    data_qtype = gguf.GGMLQuantizationType.F16
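
For a guessed ftype, the per-tensor choice above boils down to: prefer the tensor's own dst_qtype when known, otherwise keep the qtype it already had. A minimal sketch of just that fallback (the variable names mirror the diff, but the helper function itself is illustrative):

    from __future__ import annotations
    import gguf

    def pick_data_qtype(old_qtype: gguf.GGMLQuantizationType,
                        dst_qtype: gguf.GGMLQuantizationType | None) -> gguf.GGMLQuantizationType:
        return old_qtype if dst_qtype is None else dst_qtype

    print(pick_data_qtype(gguf.GGMLQuantizationType.F32, None).name)                            # F32
    print(pick_data_qtype(gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.Q8_0).name)  # Q8_0
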
@@ -542,8 +582,6 @@ class ModelBase:
                    data_qtype = gguf.GGMLQuantizationType.TQ1_0
                elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
                    data_qtype = gguf.GGMLQuantizationType.TQ2_0
                elif self.ftype == gguf.LlamaFileType.GUESSED:
                    data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
                else:
                    raise ValueError(f"Unknown file type: {self.ftype.name}")
@@ -707,23 +745,8 @@ class TextModel(ModelBase):

        total_params = self.gguf_writer.get_total_parameter_count()[0]

        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
        # TODO: get type name from `quantization_config` field when present?
        if self.ftype == gguf.LlamaFileType.GUESSED:
            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
            _, first_tensor = next(self.get_tensors())
            logger.info(f"first tensor type is {first_tensor.dtype}")
            if first_tensor.dtype == torch.float16:
                ftype = gguf.LlamaFileType.MOSTLY_F16
            elif first_tensor.dtype == torch.bfloat16:
                ftype = gguf.LlamaFileType.MOSTLY_BF16
            else:
                ftype = gguf.LlamaFileType.ALL_F32
        else:
            ftype = self.ftype

        # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
        output_type: str = ftype.name.partition("_")[2]
        output_type: str = self.ftype.name.partition("_")[2]

        # Filename Output
        if self.fname_out.is_dir():
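
Since the guessed ftype is now resolved in ModelBase before this point, the filename suffix can be derived directly from self.ftype. For reference, a small sketch of what the partition produces (enum members taken from gguf-py; output shown in the trailing comments):

    import gguf

    for ft in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_BF16, gguf.LlamaFileType.MOSTLY_Q8_0):
        print(ft.name, "->", ft.name.partition("_")[2])
    # ALL_F32 -> F32
    # MOSTLY_BF16 -> BF16
    # MOSTLY_Q8_0 -> Q8_0
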