diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b991848df9..51420f612a 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -11,6 +11,7 @@ import json
 import os
 import re
 import sys
+from dataclasses import dataclass
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
@@ -59,6 +60,14 @@ class ModelType(IntEnum):
 
 AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
 
 
+@dataclass
+class ModelTensorInfo:
+    load: Callable[[], Tensor]
+    src_type: str
+    src_qtype: gguf.GGMLQuantizationType | None = None
+    dst_qtype: gguf.GGMLQuantizationType | None = None
+
+
 class ModelBase:
     _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
         ModelType.TEXT: {},
@@ -74,7 +83,7 @@ class ModelBase:
     lazy: bool
     dry_run: bool
     hparams: dict[str, Any]
-    model_tensors: dict[str, Callable[[], Tensor]]
+    model_tensors: dict[str, ModelTensorInfo]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
    metadata_override: Path | None
@@ -97,7 +106,8 @@ class ModelBase:
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
                  hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 use_reflinks: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -118,22 +128,12 @@ class ModelBase:
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
 
-        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
-        if self.ftype == gguf.LlamaFileType.GUESSED:
-            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
-            _, first_tensor = next(self.get_tensors())
-            if first_tensor.dtype == torch.float16:
-                logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
-                self.ftype = gguf.LlamaFileType.MOSTLY_F16
-            else:
-                logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
-                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
-
         self.dequant_model()
 
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
-                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
+                                           use_reflinks=use_reflinks)
 
         # Mistral specific
         self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -152,8 +152,8 @@ class ModelBase:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
-        tensors: dict[str, Callable[[], Tensor]] = {}
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, ModelTensorInfo]:
+        tensors: dict[str, ModelTensorInfo] = {}
 
         if remote_hf_model_id is not None:
             is_safetensors = True
@@ -161,7 +161,14 @@ class ModelBase:
             logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
             remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
             for name, remote_tensor in remote_tensors.items():
-                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+                dtype = LazyTorchTensor._dtype_str_map[remote_tensor.dtype]
+                qtype = LazyTorchTensor._qtype_map.get(dtype)
+                tensors[name] = ModelTensorInfo(
+                    load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
+                    src_type=str(dtype),
+                    src_qtype=qtype,
+                    dst_qtype=qtype,
+                )
 
             return tensors
 
@@ -205,18 +212,25 @@ class ModelBase:
             for name in model_part.keys():
                 if is_safetensors:
                     data: gguf.utility.LocalTensor = model_part[name]
+                    dtype = LazyTorchTensor._dtype_str_map[data.dtype]
                     if self.lazy:
                         data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                     else:
-                        dtype = LazyTorchTensor._dtype_str_map[data.dtype]
                         data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                 else:
                     data_torch: Tensor = model_part[name]
+                    dtype = data_torch.dtype
                     if self.lazy:
                         data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                     else:
                         data_gen = lambda data=data_torch: data  # noqa: E731
-                tensors[name] = data_gen
+                qtype = LazyTorchTensor._qtype_map.get(dtype)
+                tensors[name] = ModelTensorInfo(
+                    load=data_gen,
+                    src_type=str(dtype),
+                    src_qtype=qtype,
+                    dst_qtype=qtype,
+                )
 
         # verify tensor name presence and identify potentially missing files
         if len(tensor_names_from_index) > 0:
@@ -237,7 +251,7 @@ class ModelBase:
 
     def dequant_model(self):
         tensors_to_remove: list[str] = []
-        new_tensors: dict[str, Callable[[], Tensor]] = {}
+        new_tensors: dict[str, ModelTensorInfo] = {}
 
         if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
             quant_method = quant_config.get("quant_method")
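A side note on the `ModelTensorInfo` entries above: loading stays deferred behind a `load` callable while the source type and the intended GGML destination type ride along as plain metadata. The following is a minimal, self-contained sketch of the same pattern (NumPy stand-ins and made-up tensor names, not the real `LazyTorchTensor` machinery); it also shows why the lambdas in the diff bind their arguments as defaults (`lambda data=data: ...`): without the default, every closure would capture the loop variable and load the last tensor.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class TensorRecord:
    # simplified stand-in for ModelTensorInfo: loading stays deferred behind a callable
    load: Callable[[], np.ndarray]
    src_type: str
    dst_type: str | None = None


def index_dummy_tensors() -> dict[str, TensorRecord]:
    # pretend these arrays come from safetensors shards; nothing is "read" until .load()
    raw = {
        "blk.0.attn_q.weight": np.zeros((8, 8), dtype=np.float16),
        "blk.0.attn_k.weight": np.zeros((8, 8), dtype=np.float16),
    }
    return {
        # `a=a` binds the current array as a default, avoiding Python's late-binding closure pitfall
        name: TensorRecord(load=lambda a=a: a, src_type=str(a.dtype), dst_type=str(a.dtype))
        for name, a in raw.items()
    }


if __name__ == "__main__":
    for name, rec in index_dummy_tensors().items():
        print(name, rec.src_type, rec.load().shape)
```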
@@ -315,7 +329,12 @@ class ModelBase:
                         weight_name = name.removesuffix("_scale")
                         w = self.model_tensors[weight_name]
                         s = self.model_tensors[name]
-                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        self.model_tensors[weight_name] = ModelTensorInfo(
+                            load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
+                            src_type="bitnet",
+                            src_qtype=gguf.GGMLQuantizationType.F32,
+                            dst_qtype=gguf.GGMLQuantizationType.TQ1_0,
+                        )
                         tensors_to_remove.append(name)
             elif quant_method == "fp8":
                 for name in self.model_tensors.keys():
@@ -323,9 +342,15 @@ class ModelBase:
                         weight_name = name.removesuffix("_scale_inv")
                         w = self.model_tensors[weight_name]
                         s = self.model_tensors[name]
-                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        self.model_tensors[weight_name] = ModelTensorInfo(
+                            load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
+                            src_type=w.src_type,
+                            src_qtype=gguf.GGMLQuantizationType.F32,
+                            dst_qtype=gguf.GGMLQuantizationType.BF16,  # TODO: change to FP8 once natively supported
+                        )
                         tensors_to_remove.append(name)
             elif quant_method == "gptq":
+                bits = quant_config["bits"]
                 for name in self.model_tensors.keys():
                     if name.endswith(".qweight"):
                         base_name = name.removesuffix(".qweight")
@@ -333,10 +358,13 @@ class ModelBase:
                         qweight = self.model_tensors[base_name + ".qweight"]
                         qzeros = self.model_tensors[base_name + ".qzeros"]
                         scales = self.model_tensors[base_name + ".scales"]
-                        new_tensors[base_name + ".weight"] = (
-                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
-                                g(), w(), z(), s()
-                            )
+                        new_tensors[base_name + ".weight"] = ModelTensorInfo(
+                            load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g.load(), w.load(), z.load(), s.load()
+                            ),
+                            src_type=f"GPTQ-{bits}bit",
+                            src_qtype=gguf.GGMLQuantizationType.F32,
+                            dst_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
                         )
                         tensors_to_remove += [
                             base_name + n
@@ -358,8 +386,8 @@ class ModelBase:
             self.model_tensors[name] = value
 
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, gen in self.model_tensors.items():
-            yield name, gen()
+        for name, t in self.model_tensors.items():
+            yield name, t.load()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -414,10 +442,12 @@ class ModelBase:
             if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                 continue
 
-            old_dtype = data_torch.dtype
+            tensor_info = self.model_tensors.get(name)
+            old_dtype: str = tensor_info.src_type if tensor_info is not None else str(data_torch.dtype)
 
             # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
+            # TODO: handle pre-quantized tensors for repacking
+            if data_torch.dtype not in (torch.float16, torch.bfloat16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             # use the first number-like part of the tensor name as the block id
@@ -428,8 +458,16 @@ class ModelBase:
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                # TODO: why do we squeeze here?
-                # data = data_torch.squeeze().numpy()
+                old_qtype = LazyTorchTensor._qtype_map[data_torch.dtype]
+
+                # workaround BF16 not being supported by Numpy
+                if data_torch.dtype == torch.bfloat16:
+                    data_torch = data_torch.view(torch.uint8)
+
+                # if data ends up empty, it means data_torch was a scalar tensor -> restore
+                if len(data_torch.shape) == 0:
+                    data_torch = data_torch.reshape(1)
+
                 data = data_torch.numpy()
 
                 n_dims = len(data.shape)
@@ -500,15 +538,23 @@ class ModelBase:
                         data_qtype = gguf.GGMLQuantizationType.TQ1_0
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
                         data_qtype = gguf.GGMLQuantizationType.TQ2_0
+                    elif self.ftype == gguf.LlamaFileType.GUESSED:
+                        data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
                     else:
                         raise ValueError(f"Unknown file type: {self.ftype.name}")
 
-                try:
-                    data = gguf.quants.quantize(data, data_qtype)
-                except gguf.QuantError as e:
-                    logger.warning("%s, %s", e, "falling back to F16")
-                    data_qtype = gguf.GGMLQuantizationType.F16
-                    data = gguf.quants.quantize(data, data_qtype)
+                if old_qtype != data_qtype:
+                    if old_qtype not in (
+                        gguf.GGMLQuantizationType.F32,
+                        gguf.GGMLQuantizationType.F16,
+                    ):
+                        data = gguf.quants.dequantize(data, old_qtype)
+                    try:
+                        data = gguf.quants.quantize(data, data_qtype)
+                    except gguf.QuantError as e:
+                        logger.warning("%s, %s", e, "falling back to F16")
+                        data_qtype = gguf.GGMLQuantizationType.F16
+                        data = gguf.quants.quantize(data, data_qtype)
 
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
@@ -656,8 +702,24 @@ class TextModel(ModelBase):
         super().prepare_metadata(vocab_only=vocab_only)
 
         total_params = self.gguf_writer.get_total_parameter_count()[0]
+
+        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
+        # TODO: get type name from `quantization_config` field when present?
+        if self.ftype == gguf.LlamaFileType.GUESSED:
+            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
+            _, first_tensor = next(self.get_tensors())
+            logger.info(f"first tensor type is {first_tensor.dtype}")
+            if first_tensor.dtype == torch.float16:
+                ftype = gguf.LlamaFileType.MOSTLY_F16
+            elif first_tensor.dtype == torch.bfloat16:
+                ftype = gguf.LlamaFileType.MOSTLY_BF16
+            else:
+                ftype = gguf.LlamaFileType.ALL_F32
+        else:
+            ftype = self.ftype
+
         # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
-        output_type: str = self.ftype.name.partition("_")[2]
+        output_type: str = ftype.name.partition("_")[2]
 
         # Filename Output
         if self.fname_out.is_dir():
@@ -8840,12 +8902,20 @@ class LazyTorchTensor(gguf.LazyBase):
         "F8_E5M2": torch.float8_e5m2,
     }
 
+    _qtype_map: dict[torch.dtype, gguf.GGMLQuantizationType] = {
+        torch.float64: gguf.GGMLQuantizationType.F64,
+        torch.float32: gguf.GGMLQuantizationType.F32,
+        torch.float16: gguf.GGMLQuantizationType.F16,
+        torch.bfloat16: gguf.GGMLQuantizationType.BF16,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
             meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
             args=(self,),
-            func=(lambda s: s.numpy())
+            func=(lambda s: s.numpy()),
+            ranges=self._ranges
         )
 
     @classmethod
@@ -8866,7 +8936,7 @@ class LazyTorchTensor(gguf.LazyBase):
             return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
         dtype = cls._dtype_str_map[t.dtype]
         shape = t.shape
-        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r), ranges=(t.data_range,))
         return cast(torch.Tensor, lazy)
 
     @classmethod
@@ -8887,7 +8957,27 @@ class LazyTorchTensor(gguf.LazyBase):
         if func is torch.Tensor.numpy:
             return args[0].numpy()
 
-        return cls._wrap_fn(func)(*args, **kwargs)
+        result = cls._wrap_fn(func)(*args, **kwargs)
+
+        def get_dim(index: int, key: str = "dim", default: int = 0, args=args, kwargs=kwargs) -> int:
+            # TODO: handle negative dim
+            if len(args) > index:
+                return args[index]
+            else:
+                return kwargs.get(key, default)
+
+        # Track file ranges
+        # TODO: handle tensor splits (with torch.split, torch.chunk, and torch.__getitem__)
+        if isinstance(result, LazyTorchTensor):
+            if isinstance(args[0], LazyTorchTensor):
+                if func is torch.Tensor.to and not isinstance(args[1], torch.dtype):
+                    result._ranges = args[0]._ranges
+            if func is torch.stack and get_dim(1) == 0:
+                if all(isinstance(t, LazyTorchTensor) and len(t._ranges) > 0 for t in args[0]):
+                    # collect ranges of all stacked tensors
+                    result._ranges = tuple(r for t in args[0] for r in t._ranges)
+
+        return result
 
 
 def parse_args() -> argparse.Namespace:
@@ -8902,8 +8992,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for mostly unchanged types",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -8922,6 +9012,10 @@ def parse_args() -> argparse.Namespace:
         "--no-lazy", action="store_true",
         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
     )
+    parser.add_argument(
+        "--reflink", action="store_true",
+        help="(Experimental) Use copy-on-write reflinks when possible (e.g. on BTRFS, XFS, ZFS, etc.). File alignment and padding will differ compared to not using this option. Should be very fast when source model layout is compatible enough.",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
     )
@@ -9106,7 +9200,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
            split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
            small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            use_reflinks=args.reflink,
         )
 
         if args.vocab_only:
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 19a7adb2d1..7a50675e2d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -42,8 +42,8 @@ void ggml_print_backtrace(void);
 #    define MAX(a, b) ((a) > (b) ? (a) : (b))
 #endif
 
-// required for mmap as gguf only guarantees 32-byte alignment
-#define TENSOR_ALIGNMENT 32
+// required for mmap as gguf converted with reflinks from safetensors only guarantees 8-byte alignment
+#define TENSOR_ALIGNMENT 8
 
 // static_assert should be a #define, but if it's not,
 // fall back to the _Static_assert C11 keyword.
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index 8cc4ef1cf4..9673bf78ba 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -624,6 +624,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
     ctx->size = 0;
     for (size_t i = 0; i < ctx->info.size(); ++i) {
         const gguf_tensor_info & ti = ctx->info[i];
+        // HACK: bypass the continuity check
+        ctx->size = ti.offset;
         if (ti.offset != ctx->size) {
             GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
                 __func__, ti.t.name, ti.offset, ctx->size);
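The writer changes below gate reflink support on `hasattr(os, "copy_file_range")`. For readers unfamiliar with the call, here is a minimal Linux-only sketch (the temporary files and sizes are arbitrary, not part of the patch) of a block-aligned `os.copy_file_range`: on copy-on-write filesystems such an aligned copy can share extents with the source instead of duplicating bytes, while elsewhere it degrades to an ordinary in-kernel copy.

```python
import os
import tempfile

BLOCK = 4096  # common filesystem block size, matching the alignment chosen by the writer

with tempfile.NamedTemporaryFile() as src, tempfile.NamedTemporaryFile() as dst:
    src.write(os.urandom(4 * BLOCK))
    src.flush()  # make sure the bytes are visible to the kernel-side copy

    # copy blocks 1..2 of src to the start of dst; offsets and size are all block-aligned
    copied = os.copy_file_range(src.fileno(), dst.fileno(), 2 * BLOCK, 1 * BLOCK, 0)
    print(f"copied {copied} bytes")
```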
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index a6cc8a931e..03e7ba930b 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -30,6 +30,7 @@ from .constants import (
 )
 
 from .quants import quant_shape_from_byte_shape
+from .utility import LocalTensorRange, best_alignment_offset, copy_tensor_ranges
 
 logger = logging.getLogger(__name__)
 
@@ -84,14 +85,16 @@ class GGUFWriter:
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        use_reflinks = False,  # opportunistically attempt to use copy-on-write
     ):
         self.fout = None
         self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.use_temp_file = use_temp_file
+        self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
+        self.use_temp_file = use_temp_file if not self.use_reflinks else False
         self.temp_file = None
         self.tensors = [{}]
         self.kv_data = [{}]
@@ -107,6 +110,10 @@ class GGUFWriter:
         if self.small_first_shard:
             self.tensors.append({})
 
+        if self.use_reflinks:
+            # common default block size for COW filesystems
+            self.add_custom_alignment(4096)
+
         self.add_architecture()
 
     def get_total_parameter_count(self) -> tuple[int, int, int, int]:
@@ -257,14 +264,20 @@ class GGUFWriter:
 
         offset_tensor = 0
 
         for name, ti in tensors.items():
+            align_offset = 0
+            if self.use_reflinks:
+                ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
+                if len(ranges) > 0:
+                    align_offset = best_alignment_offset(ranges, self.data_alignment)
+
             ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
             n_dims = len(ti.shape)
             ti_data += self._pack("I", n_dims)
             for j in range(n_dims):
                 ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
             ti_data += self._pack("I", ti.dtype)
-            ti_data += self._pack("Q", offset_tensor)
-            offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+            ti_data += self._pack("Q", offset_tensor + align_offset)
+            offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
 
         fout.write(ti_data)
         fout.flush()
@@ -398,6 +411,7 @@ class GGUFWriter:
         if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
             raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
         assert self.fout is not None
+        assert not self.use_reflinks  # TODO: handle this here too
 
         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
@@ -450,15 +464,21 @@ class GGUFWriter:
                        shard_bar.reset(total=(total if total > 0 else None))
 
                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
+                for name, ti in tensors.items():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
+                    if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
+                        logger.debug(f"using reflinks for {name}")
+                        start_offset = fout.tell()
+                        copy_tensor_ranges(fout, ranges, self.data_alignment)
+                        self.write_padding(fout, fout.tell() - start_offset)
+                    else:
+                        ti.tensor.tofile(fout)
+                        self.write_padding(fout, ti.nbytes)
                     if shard_bar is not None:
                         shard_bar.update(ti.nbytes)
                     if bar is not None:
                         bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
             self.temp_file.seek(0)
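The `lazy.py` changes that follow thread a `_ranges` tuple through the lazy-tensor wrappers so that shape-only operations (`view`, `reshape`, `squeeze`, `unsqueeze`) keep pointing at the source bytes, while anything that rewrites data drops the ranges. A rough, self-contained illustration of that rule (a toy class with a made-up file name, not the real `LazyBase`):

```python
import numpy as np


class RangedArray:
    """Toy stand-in for the lazy tensors: carry (filename, offset, size) ranges and
    keep them only across operations that leave the underlying bytes untouched."""

    def __init__(self, data: np.ndarray, ranges: tuple = ()):
        self.data = data
        self.ranges = ranges

    def reshape(self, *shape) -> "RangedArray":
        # pure metadata change: the on-disk bytes can still be reflinked as-is
        return RangedArray(self.data.reshape(*shape), self.ranges)

    def astype(self, dtype) -> "RangedArray":
        # a dtype change rewrites the bytes, so the source ranges no longer apply
        same = np.dtype(dtype) == self.data.dtype
        return RangedArray(self.data.astype(dtype), self.ranges if same else ())


a = RangedArray(np.zeros(8, dtype=np.float32), ranges=(("model-00001-of-00002.safetensors", 1024, 32),))
print(a.reshape(2, 4).ranges)       # ranges survive the reshape
print(a.astype(np.float16).ranges)  # ranges dropped: ()
```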
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index f9bcadae02..c4e5400639 100644
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -6,6 +6,7 @@ from typing import Any, Callable
 
 import numpy as np
 from numpy.typing import DTypeLike
+from .utility import LocalTensorRange
 
 logger = logging.getLogger(__name__)
 
@@ -20,10 +21,11 @@ class LazyMeta(ABCMeta):
                 return type(self)._wrap_fn(
                     (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                     use_self=self,
+                    data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
                 )
             elif isinstance(meta_attr, self._tensor_type):
                 # e.g. self.T with torch.Tensor should still be wrapped
-                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
+                return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
             else:
                 # no need to wrap non-tensor properties,
                 # and they likely don't depend on the actual contents of the tensor
@@ -39,8 +41,9 @@ class LazyMeta(ABCMeta):
             def wrapped_special_op(self, *args, **kwargs):
                 return type(self)._wrap_fn(
                     getattr(type(self)._tensor_type, op_name),
+                    use_self=self,
                     meta_noop=meta_noop,
-                )(self, *args, **kwargs)
+                )(*args, **kwargs)
             return wrapped_special_op
 
         # special methods bypass __getattr__, so they need to be added manually
@@ -76,14 +79,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _args: tuple
     _kwargs: dict[str, Any]
     _func: Callable[[Any], Any] | None
+    _ranges: tuple[LocalTensorRange, ...]
 
-    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
         super().__init__()
         self._meta = meta
         self._data = data
         self._args = args
         self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
+        self._ranges = ranges
         assert self._func is not None or self._data is not None
 
     def __init_subclass__(cls) -> None:
@@ -107,7 +112,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
         return o
 
     @classmethod
-    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
+    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
@@ -116,6 +121,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
             # TODO: maybe handle tensors in kwargs too
 
+            ranges = use_self._ranges if use_self is not None and data_noop else ()
+
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
                     res = fn(*meta_args, **kwargs)
@@ -138,7 +145,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
             elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
                 # share the evaluation between lazy tuple elements
                 shared_args: list = [args, None]
@@ -214,7 +221,8 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
+        ranges = self._ranges if self._meta.dtype == dtype else ()
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index c9401a1c0a..63c7cc7cae 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from io import BufferedReader, BufferedWriter
 from pathlib import Path
 from typing import Literal
 import os
 import json
+import logging
 
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -281,6 +285,83 @@ class LocalTensorRange:
     size: int
 
 
+def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
+    hist: dict[int, int] = {}
+
+    for r in ranges:
+        align_offset = r.offset % alignment
+        if align_offset not in hist:
+            hist[align_offset] = 0
+        hist[align_offset] += r.size
+
+    best_offset = 0
+    best_size = 0
+    for offset, size in hist.items():
+        if size > best_size:
+            best_size = size
+            best_offset = offset
+    return best_offset
+
+
+# (assuming this is only called where os.copy_file_range is present)
+#
+# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
+# to make it more likely that copy-on-write is used where possible.
+# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
+def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
+    assert len(ranges) > 0
+    dst_offset = fout.tell()
+    assert dst_offset % alignment == 0, dst_offset % alignment
+    align_offset = best_alignment_offset(ranges, alignment)
+    if len(ranges) == 1:
+        r = ranges[0]
+        with open(r.filename, "rb") as src:
+            offset_src = r.offset - align_offset
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+            dst_offset += r.size + align_offset
+    else:
+        # All ranges need to have the same alignment offset
+        # Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
+        src_files: dict[Path, BufferedReader] = {}
+        for r in ranges:
+            if r.filename not in src_files:
+                src_files[r.filename] = open(r.filename, "rb")
+
+        for i, r in enumerate(ranges):
+            this_align_offset = r.offset % alignment
+            src = src_files[r.filename]
+            if this_align_offset != align_offset:
+                logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
+            if i > 0 and dst_offset % alignment != 0:
+                # Write the correct data between blocks even when they are non-consecutive
+                extra_size = alignment - (dst_offset % alignment)
+                src.seek(r.offset)
+                buf = src.read(extra_size)
+                fout.seek(dst_offset)
+                fout.write(buf)
+                dst_offset += extra_size
+                assert dst_offset % alignment == 0, dst_offset % alignment
+                offset_src = r.offset + extra_size
+            else:
+                # TODO: is this always correct?
+                offset_src = r.offset - align_offset
+
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+            dst_offset += r.size
+
+        for f in src_files.values():
+            f.close()
+
+    fout.seek(dst_offset)
+
+
 @dataclass
 class LocalTensor:
     dtype: str
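For intuition, here is a rough worked example of the alignment-offset heuristic added in `utility.py`, using plain `(offset, size)` pairs instead of `LocalTensorRange` so it runs standalone (the function name and the sample offsets are illustrative, not part of the patch): the writer pads each tensor's GGUF offset so that it lands at the same position within a 4096-byte block as the bulk of its source data, which is what lets `copy_file_range` share whole blocks.

```python
ALIGN = 4096


def pick_alignment_offset(ranges: list[tuple[int, int]], alignment: int = ALIGN) -> int:
    # same histogram idea as best_alignment_offset: weight each offset-within-block by bytes covered
    hist: dict[int, int] = {}
    for offset, size in ranges:
        hist[offset % alignment] = hist.get(offset % alignment, 0) + size
    return max(hist.items(), key=lambda kv: kv[1])[0]


# two large source ranges start 576 bytes into a block, one small range is block-aligned
ranges = [(123456, 1 << 20), (2 * (1 << 20) + 576, 1 << 20), (4096, 512)]
print(pick_alignment_offset(ranges))  # -> 576: shifting the destination by 576 bytes keeps
                                      #    the two large ranges block-aligned for reflinking
```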