convert : use reflinks for faster conversion

Author: Francis Couture-Harpin
Date:   2025-09-01 20:45:57 -04:00
parent e582f1ac63
commit f7394cdaf4
6 changed files with 266 additions and 60 deletions

View File

@@ -11,6 +11,7 @@ import json
import os
import re
import sys
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
@@ -59,6 +60,14 @@ class ModelType(IntEnum):
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@dataclass
class ModelTensorInfo:
load: Callable[[], Tensor]
src_type: str
src_qtype: gguf.GGMLQuantizationType | None = None
dst_qtype: gguf.GGMLQuantizationType | None = None
class ModelBase:
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
ModelType.TEXT: {},
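
The new ModelTensorInfo dataclass pairs a deferred loader with source-type metadata, so tensor types can be inspected without reading the weights. For illustration only, a minimal standalone sketch of the pattern (the class and values below are placeholders, not taken from the patch):

    from dataclasses import dataclass
    from typing import Callable

    import torch

    @dataclass
    class TensorInfoSketch:
        load: Callable[[], torch.Tensor]  # invoked only when the data is actually needed
        src_type: str                     # human-readable source encoding, e.g. "torch.bfloat16"

    info = TensorInfoSketch(load=lambda: torch.zeros(4, 4, dtype=torch.float16),
                            src_type="torch.float16")
    print(info.src_type)   # metadata is available without touching the weights
    weights = info.load()  # the tensor is materialized only here
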
@@ -74,7 +83,7 @@ class ModelBase:
lazy: bool
dry_run: bool
hparams: dict[str, Any]
model_tensors: dict[str, Callable[[], Tensor]]
model_tensors: dict[str, ModelTensorInfo]
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
@@ -97,7 +106,8 @@ class ModelBase:
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False):
disable_mistral_community_chat_template: bool = False,
use_reflinks: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -118,22 +128,12 @@ class ModelBase:
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
if first_tensor.dtype == torch.float16:
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_F16
else:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
use_reflinks=use_reflinks)
# Mistral specific
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -152,8 +152,8 @@ class ModelBase:
return None
raise KeyError(f"could not find any of: {keys}")
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
tensors: dict[str, Callable[[], Tensor]] = {}
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, ModelTensorInfo]:
tensors: dict[str, ModelTensorInfo] = {}
if remote_hf_model_id is not None:
is_safetensors = True
@@ -161,7 +161,14 @@ class ModelBase:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
for name, remote_tensor in remote_tensors.items():
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
dtype = LazyTorchTensor._dtype_str_map[remote_tensor.dtype]
qtype = LazyTorchTensor._qtype_map.get(dtype)
tensors[name] = ModelTensorInfo(
load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
src_type=str(dtype),
src_qtype=qtype,
dst_qtype=qtype,
)
return tensors
@@ -205,18 +212,25 @@ class ModelBase:
for name in model_part.keys():
if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
else:
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else:
data_torch: Tensor = model_part[name]
dtype = data_torch.dtype
if self.lazy:
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen
qtype = LazyTorchTensor._qtype_map.get(dtype)
tensors[name] = ModelTensorInfo(
load=data_gen,
src_type=str(dtype),
src_qtype=qtype,
dst_qtype=qtype,
)
# verify tensor name presence and identify potentially missing files
if len(tensor_names_from_index) > 0:
@@ -237,7 +251,7 @@ class ModelBase:
def dequant_model(self):
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
new_tensors: dict[str, ModelTensorInfo] = {}
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
quant_method = quant_config.get("quant_method")
@@ -315,7 +329,12 @@ class ModelBase:
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
self.model_tensors[weight_name] = ModelTensorInfo(
load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
src_type="bitnet",
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.TQ1_0,
)
tensors_to_remove.append(name)
elif quant_method == "fp8":
for name in self.model_tensors.keys():
@@ -323,9 +342,15 @@ class ModelBase:
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
self.model_tensors[weight_name] = ModelTensorInfo(
load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
src_type=w.src_type,
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.BF16, # TODO: change to FP8 once natively supported
)
tensors_to_remove.append(name)
elif quant_method == "gptq":
bits = quant_config["bits"]
for name in self.model_tensors.keys():
if name.endswith(".qweight"):
base_name = name.removesuffix(".qweight")
@@ -333,10 +358,13 @@ class ModelBase:
qweight = self.model_tensors[base_name + ".qweight"]
qzeros = self.model_tensors[base_name + ".qzeros"]
scales = self.model_tensors[base_name + ".scales"]
new_tensors[base_name + ".weight"] = (
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g(), w(), z(), s()
)
new_tensors[base_name + ".weight"] = ModelTensorInfo(
load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g.load(), w.load(), z.load(), s.load()
),
src_type=f"GPTQ-{bits}bit",
src_qtype=gguf.GGMLQuantizationType.F32,
dst_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
)
tensors_to_remove += [
base_name + n
@@ -358,8 +386,8 @@ class ModelBase:
self.model_tensors[name] = value
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, gen in self.model_tensors.items():
yield name, gen()
for name, t in self.model_tensors.items():
yield name, t.load()
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -414,10 +442,12 @@ class ModelBase:
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
old_dtype = data_torch.dtype
tensor_info = self.model_tensors.get(name)
old_dtype: str = tensor_info.src_type if tensor_info is not None else str(data_torch.dtype)
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
# TODO: handle pre-quantized tensors for repacking
if data_torch.dtype not in (torch.float16, torch.bfloat16, torch.float32):
data_torch = data_torch.to(torch.float32)
# use the first number-like part of the tensor name as the block id
@@ -428,8 +458,16 @@ class ModelBase:
break
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
# TODO: why do we squeeze here?
# data = data_torch.squeeze().numpy()
old_qtype = LazyTorchTensor._qtype_map[data_torch.dtype]
# workaround BF16 not being supported by Numpy
if data_torch.dtype == torch.bfloat16:
data_torch = data_torch.view(torch.uint8)
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data_torch.shape) == 0:
data_torch = data_torch.reshape(1)
data = data_torch.numpy()
n_dims = len(data.shape)
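
The BF16 workaround above reinterprets the tensor as raw bytes because NumPy has no bfloat16 dtype. A short illustration of what that view does to shape and storage (not from the patch):

    import torch

    t = torch.randn(2, 3, dtype=torch.bfloat16)
    raw = t.view(torch.uint8)   # same storage, reinterpreted as bytes
    arr = raw.numpy()           # NumPy can hold plain uint8
    assert arr.shape == (2, 6)  # last dim doubles: each bf16 value is 2 bytes
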
@@ -500,15 +538,23 @@ class ModelBase:
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
elif self.ftype == gguf.LlamaFileType.GUESSED:
data_qtype = old_qtype if tensor_info is None or tensor_info.dst_qtype is None else tensor_info.dst_qtype
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")
try:
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
if old_qtype != data_qtype:
if old_qtype not in (
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
):
data = gguf.quants.dequantize(data, old_qtype)
try:
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
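
The requantization path above dequantizes data whose stored encoding is neither F32 nor F16 before quantizing it to the requested output type. A hedged sketch of the round trip using gguf-py's quants module (assumes the gguf Python package is importable; shapes and tolerances are illustrative):

    import numpy as np
    import gguf

    data = np.random.rand(4, 64).astype(np.float32)   # last dim is a multiple of the Q8_0 block size
    q8 = gguf.quants.quantize(data, gguf.GGMLQuantizationType.Q8_0)
    restored = gguf.quants.dequantize(q8, gguf.GGMLQuantizationType.Q8_0)
    assert np.allclose(data, restored, atol=1e-1)     # lossy, but close
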
@@ -656,8 +702,24 @@ class TextModel(ModelBase):
super().prepare_metadata(vocab_only=vocab_only)
total_params = self.gguf_writer.get_total_parameter_count()[0]
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
# TODO: get type name from `quantization_config` field when present?
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
logger.info(f"first tensor type is {first_tensor.dtype}")
if first_tensor.dtype == torch.float16:
ftype = gguf.LlamaFileType.MOSTLY_F16
elif first_tensor.dtype == torch.bfloat16:
ftype = gguf.LlamaFileType.MOSTLY_BF16
else:
ftype = gguf.LlamaFileType.ALL_F32
else:
ftype = self.ftype
# Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
output_type: str = self.ftype.name.partition("_")[2]
output_type: str = ftype.name.partition("_")[2]
# Filename Output
if self.fname_out.is_dir():
@@ -8840,12 +8902,20 @@ class LazyTorchTensor(gguf.LazyBase):
"F8_E5M2": torch.float8_e5m2,
}
_qtype_map: dict[torch.dtype, gguf.GGMLQuantizationType] = {
torch.float64: gguf.GGMLQuantizationType.F64,
torch.float32: gguf.GGMLQuantizationType.F32,
torch.float16: gguf.GGMLQuantizationType.F16,
torch.bfloat16: gguf.GGMLQuantizationType.BF16,
}
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
args=(self,),
func=(lambda s: s.numpy())
func=(lambda s: s.numpy()),
ranges=self._ranges
)
@classmethod
@@ -8866,7 +8936,7 @@ class LazyTorchTensor(gguf.LazyBase):
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r), ranges=(t.data_range,))
return cast(torch.Tensor, lazy)
@classmethod
@@ -8887,7 +8957,27 @@ class LazyTorchTensor(gguf.LazyBase):
if func is torch.Tensor.numpy:
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)
result = cls._wrap_fn(func)(*args, **kwargs)
def get_dim(index: int, key: str = "dim", default: int = 0, args=args, kwargs=kwargs) -> int:
# TODO: handle negative dim
if len(args) > index:
return args[index]
else:
return kwargs.get(key, default)
# Track file ranges
# TODO: handle tensor splits (with torch.split, torch.chunk, and torch.__getitem__)
if isinstance(result, LazyTorchTensor):
if isinstance(args[0], LazyTorchTensor):
if func is torch.Tensor.to and not isinstance(args[1], torch.dtype):
result._ranges = args[0]._ranges
if func is torch.stack and get_dim(1) == 0:
if all(isinstance(t, LazyTorchTensor) and len(t._ranges) > 0 for t in args[0]):
# collect ranges of all stacked tensors
result._ranges = tuple(r for t in args[0] for r in t._ranges)
return result
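
The range tracking above relies on stacking along dim 0 being a plain concatenation of the underlying bytes, so the source file ranges of the parts can simply be concatenated as well. A standalone sketch of that bookkeeping (Range is a stand-in for gguf's LocalTensorRange; file names are placeholders):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Range:          # stand-in for gguf.utility.LocalTensorRange
        filename: str
        offset: int
        size: int

    # per-tensor ranges of two parts that get stacked along dim 0
    part_ranges = [
        (Range("model-00001-of-00002.safetensors", 4096, 1 << 20),),
        (Range("model-00001-of-00002.safetensors", 4096 + (1 << 20), 1 << 20),),
    ]
    stacked = tuple(r for ranges in part_ranges for r in ranges)
    # the stacked tensor is still described by the original on-disk byte ranges
    assert len(stacked) == 2
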
def parse_args() -> argparse.Namespace:
@@ -8902,8 +8992,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for mostly unchanged types",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -8922,6 +9012,10 @@ def parse_args() -> argparse.Namespace:
"--no-lazy", action="store_true",
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
)
parser.add_argument(
"--reflink", action="store_true",
help="(Experimental) Use copy-on-write reflinks when possible (e.g. on BTRFS, XFS, ZFS, etc.). File alignment and padding will differ compared to not using this option. Should be very fast when source model layout is compatible enough.",
)
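
For context, a possible invocation with the new flag (script name and paths are illustrative, not taken from the patch):

    python convert_hf_to_gguf.py /path/to/hf-model --outtype auto --reflink

On filesystems without copy-on-write support, os.copy_file_range should still fall back to a regular in-kernel copy, so the flag mainly changes alignment and padding there.
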
parser.add_argument(
"--model-name", type=str, default=None,
help="name of the model",
@@ -9106,7 +9200,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
use_reflinks=args.reflink,
)
if args.vocab_only:

View File

@@ -42,8 +42,8 @@ void ggml_print_backtrace(void);
# define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
// required for mmap as gguf only guarantees 32-byte alignment
#define TENSOR_ALIGNMENT 32
// required for mmap as gguf converted with reflinks from safetensors only guarantees 8-byte alignment
#define TENSOR_ALIGNMENT 8
// static_assert should be a #define, but if it's not,
// fall back to the _Static_assert C11 keyword.

View File

@@ -624,6 +624,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ctx->size = 0;
for (size_t i = 0; i < ctx->info.size(); ++i) {
const gguf_tensor_info & ti = ctx->info[i];
// HACK: bypass the continuity check
ctx->size = ti.offset;
if (ti.offset != ctx->size) {
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
__func__, ti.t.name, ti.offset, ctx->size);

View File

@@ -30,6 +30,7 @@ from .constants import (
)
from .quants import quant_shape_from_byte_shape
from .utility import LocalTensorRange, best_alignment_offset, copy_tensor_ranges
logger = logging.getLogger(__name__)
@@ -84,14 +85,16 @@ class GGUFWriter:
def __init__(
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
use_reflinks = False, # opportunistically attempt to use copy-on-write
):
self.fout = None
self.path = Path(path) if path else None
self.arch = arch
self.endianess = endianess
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.use_temp_file = use_temp_file
self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
self.use_temp_file = use_temp_file if not self.use_reflinks else False
self.temp_file = None
self.tensors = [{}]
self.kv_data = [{}]
@@ -107,6 +110,10 @@ class GGUFWriter:
if self.small_first_shard:
self.tensors.append({})
if self.use_reflinks:
# common default block size for COW filesystems
self.add_custom_alignment(4096)
self.add_architecture()
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
@@ -257,14 +264,20 @@ class GGUFWriter:
offset_tensor = 0
for name, ti in tensors.items():
align_offset = 0
if self.use_reflinks:
ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
if len(ranges) > 0:
align_offset = best_alignment_offset(ranges, self.data_alignment)
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
n_dims = len(ti.shape)
ti_data += self._pack("I", n_dims)
for j in range(n_dims):
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
ti_data += self._pack("I", ti.dtype)
ti_data += self._pack("Q", offset_tensor)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
ti_data += self._pack("Q", offset_tensor + align_offset)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
fout.write(ti_data)
fout.flush()
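
The align_offset added to each tensor offset keeps the destination position congruent with the source position modulo the data alignment; that congruence is what lets whole filesystem blocks be shared. A small worked example, assuming the 4096-byte alignment set above (values are made up):

    alignment = 4096
    src_offset = 3 * alignment + 64        # tensor data starts 64 bytes into a source block
    align_offset = src_offset % alignment  # 64
    dst_base = 20 * alignment              # next aligned write position in the output file
    dst_offset = dst_base + align_offset   # shift the destination by the same 64 bytes
    assert src_offset % alignment == dst_offset % alignment  # blocks line up for reflinking
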
@@ -398,6 +411,7 @@ class GGUFWriter:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
assert not self.use_reflinks # TODO: handle this here too
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
@@ -450,15 +464,21 @@ class GGUFWriter:
shard_bar.reset(total=(total if total > 0 else None))
# relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in tensors.values():
for name, ti in tensors.items():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
ti.tensor.tofile(fout)
if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
logger.debug(f"using reflinks for {name}")
start_offset = fout.tell()
copy_tensor_ranges(fout, ranges, self.data_alignment)
self.write_padding(fout, fout.tell() - start_offset)
else:
ti.tensor.tofile(fout)
self.write_padding(fout, ti.nbytes)
if shard_bar is not None:
shard_bar.update(ti.nbytes)
if bar is not None:
bar.update(ti.nbytes)
self.write_padding(fout, ti.nbytes)
ti.tensor = None
else:
self.temp_file.seek(0)

View File

@@ -6,6 +6,7 @@ from typing import Any, Callable
import numpy as np
from numpy.typing import DTypeLike
from .utility import LocalTensorRange
logger = logging.getLogger(__name__)
@@ -20,10 +21,11 @@ class LazyMeta(ABCMeta):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
@@ -39,8 +41,9 @@ class LazyMeta(ABCMeta):
def wrapped_special_op(self, *args, **kwargs):
return type(self)._wrap_fn(
getattr(type(self)._tensor_type, op_name),
use_self=self,
meta_noop=meta_noop,
)(self, *args, **kwargs)
)(*args, **kwargs)
return wrapped_special_op
# special methods bypass __getattr__, so they need to be added manually
@@ -76,14 +79,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
_args: tuple
_kwargs: dict[str, Any]
_func: Callable[[Any], Any] | None
_ranges: tuple[LocalTensorRange, ...]
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
super().__init__()
self._meta = meta
self._data = data
self._args = args
self._kwargs = kwargs if kwargs is not None else {}
self._func = func
self._ranges = ranges
assert self._func is not None or self._data is not None
def __init_subclass__(cls) -> None:
@@ -107,7 +112,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
return o
@classmethod
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
@@ -116,6 +121,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
# TODO: maybe handle tensors in kwargs too
ranges = use_self._ranges if use_self is not None and data_noop else ()
if isinstance(meta_noop, bool) and not meta_noop:
try:
res = fn(*meta_args, **kwargs)
@@ -138,7 +145,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
# share the evaluation between lazy tuple elements
shared_args: list = [args, None]
@@ -214,7 +221,8 @@ class LazyNumpyTensor(LazyBase):
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
ranges = self._ranges if self._meta.dtype == dtype else ()
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
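
Ranges survive astype only when the dtype is unchanged, because a reflink copies the source bytes verbatim and any conversion produces different bytes. A short illustration of that invariant (not from the patch):

    import numpy as np

    a = np.arange(4, dtype=np.float16)
    same = a.astype(np.float16)   # bytes identical -> source ranges would remain valid
    other = a.astype(np.float32)  # bytes rewritten -> ranges must be dropped
    assert same.tobytes() == a.tobytes()
    assert other.tobytes() != a.tobytes()
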

View File

@@ -1,13 +1,17 @@
from __future__ import annotations
from dataclasses import dataclass
from io import BufferedReader, BufferedWriter
from pathlib import Path
from typing import Literal
import os
import json
import logging
import numpy as np
logger = logging.getLogger(__name__)
def fill_templated_filename(filename: str, output_type: str | None) -> str:
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -281,6 +285,83 @@ class LocalTensorRange:
size: int
def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
hist: dict[int, int] = {}
for r in ranges:
align_offset = r.offset % alignment
if align_offset not in hist:
hist[align_offset] = 0
hist[align_offset] += r.size
best_offset = 0
best_size = 0
for offset, size in hist.items():
if size > best_size:
best_size = size
best_offset = offset
return best_offset
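
A worked example of the histogram above: the alignment-offset class covering the most bytes wins, so a single differently-aligned range cannot spoil the choice for the larger ones (R is a stand-in for LocalTensorRange; values are made up):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class R:              # stand-in for LocalTensorRange (offset and size only)
        offset: int
        size: int

    ranges = (R(4096 + 64, 8 << 20), R(12288 + 64, 4 << 20), R(20480, 1 << 20))
    hist: dict[int, int] = {}
    for r in ranges:
        hist[r.offset % 4096] = hist.get(r.offset % 4096, 0) + r.size
    best_offset = max(hist, key=lambda k: hist[k])
    assert best_offset == 64   # 12 MiB of data shares the 64-byte offset, only 1 MiB is at 0
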
# (assuming this is only called where os.copy_file_range is present)
#
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
# to make it more likely that copy-on-write is used where possible.
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
assert len(ranges) > 0
dst_offset = fout.tell()
assert dst_offset % alignment == 0, dst_offset % alignment
align_offset = best_alignment_offset(ranges, alignment)
if len(ranges) == 1:
r = ranges[0]
with open(r.filename, "rb") as src:
offset_src = r.offset - align_offset
offset_src_end = r.offset + r.size
if offset_src_end % alignment != 0:
offset_src_end += alignment - (offset_src_end % alignment)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size + align_offset
else:
# All ranges need to have the same alignment offset
# Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
src_files: dict[Path, BufferedReader] = {}
for r in ranges:
if r.filename not in src_files:
src_files[r.filename] = open(r.filename, "rb")
for i, r in enumerate(ranges):
this_align_offset = r.offset % alignment
src = src_files[r.filename]
if this_align_offset != align_offset:
logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
if i > 0 and dst_offset % alignment != 0:
# Write the correct data between blocks even when they are non-consecutive
extra_size = alignment - (dst_offset % alignment)
src.seek(r.offset)
buf = src.read(extra_size)
fout.seek(dst_offset)
fout.write(buf)
dst_offset += extra_size
assert dst_offset % alignment == 0, dst_offset % alignment
offset_src = r.offset + extra_size
else:
# TODO: is this always correct?
offset_src = r.offset - align_offset
offset_src_end = r.offset + r.size
if offset_src_end % alignment != 0:
offset_src_end += alignment - (offset_src_end % alignment)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size
for f in src_files.values():
f.close()
fout.seek(dst_offset)
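
Everything above rests on os.copy_file_range, which lets the kernel copy, or on CoW filesystems share, byte ranges between files without routing them through user space. A minimal sketch of the primitive (available mainly on Linux; file names are placeholders):

    import os

    # copy 4096 bytes from offset 0 of the source into offset 0 of the destination;
    # when both offsets are block-aligned, a CoW filesystem can share the block
    # instead of duplicating it
    with open("src.safetensors", "rb") as src, open("dst.gguf", "wb") as dst:
        os.copy_file_range(src.fileno(), dst.fileno(), 4096, 0, 0)
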
@dataclass
class LocalTensor:
dtype: str