convert : use reflinks for faster conversion

Author: Francis Couture-Harpin
Date: 2025-09-01 20:45:57 -04:00
Parent: e582f1ac63
Commit: f7394cdaf4
6 changed files with 266 additions and 60 deletions


@@ -11,6 +11,7 @@ import json
import os
import re
import sys
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
@@ -59,6 +60,14 @@ class ModelType(IntEnum):
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@dataclass
class ModelTensorInfo:
load: Callable[[], Tensor]
src_type: str
src_qtype: gguf.GGMLQuantizationType | None = None
dst_qtype: gguf.GGMLQuantizationType | None = None
class ModelBase:
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
ModelType.TEXT: {},
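The ModelTensorInfo dataclass added above bundles a deferred loader with the source type string and the source/target GGML quantization types. As a rough stand-alone sketch (not part of this commit, and assuming the module's existing gguf import and ModelTensorInfo definition), indexing only records the callable; the tensor bytes are read when load() is called:

    import torch

    def _read_from_disk() -> torch.Tensor:
        # stand-in for e.g. LazyTorchTensor.from_local_tensor(...)
        return torch.zeros(4, 4, dtype=torch.bfloat16)

    info = ModelTensorInfo(
        load=_read_from_disk,
        src_type="torch.bfloat16",
        src_qtype=gguf.GGMLQuantizationType.BF16,
        dst_qtype=gguf.GGMLQuantizationType.BF16,
    )

    tensor = info.load()  # data is only materialized here, not at index time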
@@ -74,7 +83,7 @@ class ModelBase:
lazy: bool
dry_run: bool
hparams: dict[str, Any]
model_tensors: dict[str, Callable[[], Tensor]]
model_tensors: dict[str, ModelTensorInfo]
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
@@ -97,7 +106,8 @@ class ModelBase:
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False):
disable_mistral_community_chat_template: bool = False,
use_reflinks: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -118,22 +128,12 @@ class ModelBase:
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
if first_tensor.dtype == torch.float16:
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_F16
else:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
use_reflinks=use_reflinks)
# Mistral specific
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -152,8 +152,8 @@ class ModelBase:
return None
raise KeyError(f"could not find any of: {keys}")
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
tensors: dict[str, Callable[[], Tensor]] = {}
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, ModelTensorInfo]:
tensors: dict[str, ModelTensorInfo] = {}
if remote_hf_model_id is not None:
is_safetensors = True
@@ -161,7 +161,14 @@ class ModelBase:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
for name, remote_tensor in remote_tensors.items():
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
dtype = LazyTorchTensor._dtype_str_map[remote_tensor.dtype]
qtype = LazyTorchTensor._qtype_map.get(dtype)
tensors[name] = ModelTensorInfo(
load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
src_type=str(dtype),
src_qtype=qtype,
dst_qtype=qtype,
)
return tensors
@@ -205,18 +212,25 @@ class ModelBase:
for name in model_part.keys():
if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
else:
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else:
data_torch: Tensor = model_part[name]
dtype = data_torch.dtype
if self.lazy:
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen
qtype = LazyTorchTensor._qtype_map.get(dtype)
tensors[name] = ModelTensorInfo(
load=data_gen,
src_type=str(dtype),
src_qtype=qtype,
dst_qtype=qtype,
)
# verify tensor name presence and identify potentially missing files
if len(tensor_names_from_index) > 0:
@@ -237,7 +251,7 @@ class ModelBase:
def dequant_model(self):
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
new_tensors: dict[str, ModelTensorInfo] = {}
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
quant_method = quant_config.get("quant_method")
@@ -315,7 +329,12 @@ class ModelBase:
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
self.model_tensors[weight_name] = ModelTensorInfo(
load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
src_type="bitnet",
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.TQ1_0,
)
tensors_to_remove.append(name)
elif quant_method == "fp8":
for name in self.model_tensors.keys():
@@ -323,9 +342,15 @@ class ModelBase:
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
self.model_tensors[weight_name] = ModelTensorInfo(
load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
src_type=w.src_type,
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.BF16, # TODO: change to FP8 once natively supported
)
tensors_to_remove.append(name)
elif quant_method == "gptq":
bits = quant_config["bits"]
for name in self.model_tensors.keys():
if name.endswith(".qweight"):
base_name = name.removesuffix(".qweight")
@@ -333,10 +358,13 @@ class ModelBase:
qweight = self.model_tensors[base_name + ".qweight"]
qzeros = self.model_tensors[base_name + ".qzeros"]
scales = self.model_tensors[base_name + ".scales"]
new_tensors[base_name + ".weight"] = (
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g(), w(), z(), s()
)
new_tensors[base_name + ".weight"] = ModelTensorInfo(
load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g.load(), w.load(), z.load(), s.load()
),
src_type=f"GPTQ-{bits}bit",
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
)
tensors_to_remove += [
base_name + n
@@ -358,8 +386,8 @@ class ModelBase:
self.model_tensors[name] = value
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, gen in self.model_tensors.items():
yield name, gen()
for name, t in self.model_tensors.items():
yield name, t.load()
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -414,10 +442,12 @@ class ModelBase:
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
old_dtype = data_torch.dtype
tensor_info = self.model_tensors.get(name)
old_dtype: str = tensor_info.src_type if tensor_info is not None else str(data_torch.dtype)
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
# TODO: handle pre-quantized tensors for repacking
if data_torch.dtype not in (torch.float16, torch.bfloat16, torch.float32):
data_torch = data_torch.to(torch.float32)
# use the first number-like part of the tensor name as the block id
@@ -428,8 +458,16 @@ class ModelBase:
break
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
# TODO: why do we squeeze here?
# data = data_torch.squeeze().numpy()
old_qtype = LazyTorchTensor._qtype_map[data_torch.dtype]
# workaround BF16 not being supported by Numpy
if data_torch.dtype == torch.bfloat16:
data_torch = data_torch.view(torch.uint8)
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data_torch.shape) == 0:
data_torch = data_torch.reshape(1)
data = data_torch.numpy()
n_dims = len(data.shape)
@@ -500,15 +538,23 @@ class ModelBase:
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
elif self.ftype == gguf.LlamaFileType.GUESSED:
data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")
try:
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
if old_qtype != data_qtype:
if old_qtype not in (
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
):
data = gguf.quants.dequantize(data, old_qtype)
try:
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
@@ -656,8 +702,24 @@ class TextModel(ModelBase):
super().prepare_metadata(vocab_only=vocab_only)
total_params = self.gguf_writer.get_total_parameter_count()[0]
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
# TODO: get type name from `quantization_config` field when present?
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
logger.info(f"first tensor type is {first_tensor.dtype}")
if first_tensor.dtype == torch.float16:
ftype = gguf.LlamaFileType.MOSTLY_F16
elif first_tensor.dtype == torch.bfloat16:
ftype = gguf.LlamaFileType.MOSTLY_BF16
else:
ftype = gguf.LlamaFileType.ALL_F32
else:
ftype = self.ftype
# Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
output_type: str = self.ftype.name.partition("_")[2]
output_type: str = ftype.name.partition("_")[2]
# Filename Output
if self.fname_out.is_dir():
@@ -8840,12 +8902,20 @@ class LazyTorchTensor(gguf.LazyBase):
"F8_E5M2": torch.float8_e5m2,
}
_qtype_map: dict[torch.dtype, gguf.GGMLQuantizationType] = {
torch.float64: gguf.GGMLQuantizationType.F64,
torch.float32: gguf.GGMLQuantizationType.F32,
torch.float16: gguf.GGMLQuantizationType.F16,
torch.bfloat16: gguf.GGMLQuantizationType.BF16,
}
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
args=(self,),
func=(lambda s: s.numpy())
func=(lambda s: s.numpy()),
ranges=self._ranges
)
@classmethod
@@ -8866,7 +8936,7 @@ class LazyTorchTensor(gguf.LazyBase):
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r), ranges=(t.data_range,))
return cast(torch.Tensor, lazy)
@classmethod
@@ -8887,7 +8957,27 @@ class LazyTorchTensor(gguf.LazyBase):
if func is torch.Tensor.numpy:
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)
result = cls._wrap_fn(func)(*args, **kwargs)
def get_dim(index: int, key: str = "dim", default: int = 0, args=args, kwargs=kwargs) -> int:
# TODO: handle negative dim
if len(args) > index:
return args[index]
else:
return kwargs.get(key, default)
# Track file ranges
# TODO: handle tensor splits (with torch.split, torch.chunk, and torch.__getitem__)
if isinstance(result, LazyTorchTensor):
if isinstance(args[0], LazyTorchTensor):
if func is torch.Tensor.to and not isinstance(args[1], torch.dtype):
result._ranges = args[0]._ranges
if func is torch.stack and get_dim(1) == 0:
if all(isinstance(t, LazyTorchTensor) and len(t._ranges) > 0 for t in args[0]):
# collect ranges of all stacked tensors
result._ranges = tuple(r for t in args[0] for r in t._ranges)
return result
def parse_args() -> argparse.Namespace:
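The range tracking added to __torch_function__ above relies on a memory-layout fact: stacking contiguous tensors along dim 0 produces a buffer that is simply the first tensor's bytes followed by the next one's, so the source-file ranges can be concatenated in the same order. A small eager-mode check of that property (illustrative only, using plain torch tensors rather than LazyTorchTensor):

    import torch

    a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    b = torch.arange(6, 12, dtype=torch.float32).reshape(2, 3)

    stacked = torch.stack((a, b), dim=0)  # shape (2, 2, 3)

    # Along dim 0 the stacked buffer is a's bytes followed by b's bytes,
    # which is why the collected ranges can simply be concatenated in order.
    assert stacked.numpy().tobytes() == a.numpy().tobytes() + b.numpy().tobytes()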
@@ -8902,8 +8992,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for mostly unchanged types",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -8922,6 +9012,10 @@ def parse_args() -> argparse.Namespace:
"--no-lazy", action="store_true",
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
)
parser.add_argument(
"--reflink", action="store_true",
help="(Experimental) Use copy-on-write reflinks when possible (e.g. on BTRFS, XFS, ZFS, etc.). File alignment and padding will differ compared to not using this option. Should be very fast when source model layout is compatible enough.",
)
parser.add_argument(
"--model-name", type=str, default=None,
help="name of the model",
@@ -9106,7 +9200,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
use_reflinks=args.reflink,
)
if args.vocab_only: