convert : use reflinks for faster conversion

Francis Couture-Harpin
2025-09-01 20:45:57 -04:00
parent e582f1ac63
commit f7394cdaf4
6 changed files with 266 additions and 60 deletions
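At its core, the change relies on os.copy_file_range(), which on copy-on-write filesystems (BTRFS, XFS) can share the underlying blocks between source and destination instead of duplicating them. A minimal sketch of that primitive, separate from the diff below (the helper name and loop are illustrative, not from the commit):

import os

def reflink_copy(src_path: str, dst_path: str, count: int, src_off: int = 0, dst_off: int = 0) -> None:
    # Copy (or, on CoW filesystems, share) `count` bytes between two files.
    # os.copy_file_range may copy fewer bytes than requested, so loop.
    with open(src_path, "rb") as src, open(dst_path, "r+b") as dst:
        while count > 0:
            n = os.copy_file_range(src.fileno(), dst.fileno(), count, src_off, dst_off)
            if n == 0:
                break
            src_off += n
            dst_off += n
            count -= n

Whether the copy is actually reflinked depends on the filesystem and on both offsets having the same alignment within a block, which is what most of this commit arranges.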

gguf-py/gguf/gguf_writer.py

@@ -30,6 +30,7 @@ from .constants import (
 )

 from .quants import quant_shape_from_byte_shape
+from .utility import LocalTensorRange, best_alignment_offset, copy_tensor_ranges

 logger = logging.getLogger(__name__)
@@ -84,14 +85,16 @@ class GGUFWriter:
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        use_reflinks: bool = False,  # opportunistically attempt to use copy-on-write
     ):
         self.fout = None
         self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.use_temp_file = use_temp_file
+        self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
+        self.use_temp_file = use_temp_file if not self.use_reflinks else False
         self.temp_file = None
         self.tensors = [{}]
         self.kv_data = [{}]
@@ -107,6 +110,10 @@ class GGUFWriter:
         if self.small_first_shard:
             self.tensors.append({})

+        if self.use_reflinks:
+            # common default block size for COW filesystems
+            self.add_custom_alignment(4096)
+
         self.add_architecture()

     def get_total_parameter_count(self) -> tuple[int, int, int, int]:
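The 4096-byte alignment matters because block sharing only works on whole filesystem blocks: the destination offset of each tensor must sit at the same position within a block as its source. 4096 bytes is a common default block size; if one wanted to query it instead (hypothetical helper, not part of the commit):

import os

def fs_block_size(path: str) -> int:
    # Preferred I/O block size of the filesystem containing `path` (Unix only).
    return os.statvfs(path).f_bsize

Note also that use_temp_file is forced off in the constructor above: reflinks only make sense when tensor data is written directly to its final location.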
@@ -257,14 +264,20 @@ class GGUFWriter:
             offset_tensor = 0

             for name, ti in tensors.items():
+                align_offset = 0
+                if self.use_reflinks:
+                    ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
+                    if len(ranges) > 0:
+                        align_offset = best_alignment_offset(ranges, self.data_alignment)
                 ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
                 n_dims = len(ti.shape)
                 ti_data += self._pack("I", n_dims)
                 for j in range(n_dims):
                     ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
                 ti_data += self._pack("I", ti.dtype)
-                ti_data += self._pack("Q", offset_tensor)
-                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+                ti_data += self._pack("Q", offset_tensor + align_offset)
+                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)

             fout.write(ti_data)
             fout.flush()
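With illustrative numbers (not from the diff): if a tensor's bytes start at offset 1234567 in the source file, then 1234567 % 4096 == 1671, so 1671 is added to the tensor's recorded offset. The destination position then shares the source's phase within a 4096-byte block, and os.copy_file_range can share whole blocks instead of copying them:

alignment = 4096
src_offset = 1234567                   # made-up source offset of the tensor data
align_offset = src_offset % alignment  # 1671
# The tensor info records offset_tensor + align_offset, and the next tensor
# starts after padding (nbytes + align_offset) up to the next block boundary.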
@@ -398,6 +411,7 @@ class GGUFWriter:
         if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
             raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
         assert self.fout is not None
+        assert not self.use_reflinks  # TODO: handle this here too

         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
@@ -450,15 +464,21 @@ class GGUFWriter:
                 if shard_bar is not None:
                     shard_bar.reset(total=(total if total > 0 else None))

                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
+                for name, ti in tensors.items():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
+                    if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
+                        logger.debug(f"using reflinks for {name}")
+                        start_offset = fout.tell()
+                        copy_tensor_ranges(fout, ranges, self.data_alignment)
+                        self.write_padding(fout, fout.tell() - start_offset)
+                    else:
+                        ti.tensor.tofile(fout)
+                        self.write_padding(fout, ti.nbytes)
                     if shard_bar is not None:
                         shard_bar.update(ti.nbytes)
                     if bar is not None:
                         bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
             self.temp_file.seek(0)
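A hedged usage sketch of the new flag (assuming the output lands on the same CoW filesystem as the source tensors; where the filesystem cannot share blocks, copy_file_range degrades to a regular kernel-side copy):

from gguf import GGUFWriter

writer = GGUFWriter("model.gguf", arch="llama", use_reflinks=True)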

gguf-py/gguf/lazy.py

@@ -6,6 +6,7 @@ from typing import Any, Callable
 import numpy as np
 from numpy.typing import DTypeLike

+from .utility import LocalTensorRange

 logger = logging.getLogger(__name__)
@@ -20,10 +21,11 @@ class LazyMeta(ABCMeta):
                 return type(self)._wrap_fn(
                     (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                     use_self=self,
+                    data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
                 )
             elif isinstance(meta_attr, self._tensor_type):
                 # e.g. self.T with torch.Tensor should still be wrapped
-                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
+                return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
             else:
                 # no need to wrap non-tensor properties,
                 # and they likely don't depend on the actual contents of the tensor
@@ -39,8 +41,9 @@ class LazyMeta(ABCMeta):
             def wrapped_special_op(self, *args, **kwargs):
                 return type(self)._wrap_fn(
                     getattr(type(self)._tensor_type, op_name),
+                    use_self=self,
                     meta_noop=meta_noop,
-                )(self, *args, **kwargs)
+                )(*args, **kwargs)
             return wrapped_special_op

         # special methods bypass __getattr__, so they need to be added manually
@@ -76,14 +79,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _args: tuple
     _kwargs: dict[str, Any]
     _func: Callable[[Any], Any] | None
+    _ranges: tuple[LocalTensorRange, ...]

-    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
         super().__init__()
         self._meta = meta
         self._data = data
         self._args = args
         self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
+        self._ranges = ranges
         assert self._func is not None or self._data is not None

     def __init_subclass__(cls) -> None:
@@ -107,7 +112,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
         return o

     @classmethod
-    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
+    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
@@ -116,6 +121,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
             # TODO: maybe handle tensors in kwargs too

+            ranges = use_self._ranges if use_self is not None and data_noop else ()
+
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
                     res = fn(*meta_args, **kwargs)
@@ -138,7 +145,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                     res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

             if isinstance(res, cls._tensor_type):
-                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
             elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
                 # share the evaluation between lazy tuple elements
                 shared_args: list = [args, None]
@@ -214,7 +221,8 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
+        ranges = self._ranges if self._meta.dtype == dtype else ()
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)

     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
gguf-py/gguf/utility.py

@@ -1,13 +1,17 @@
 from __future__ import annotations

 from dataclasses import dataclass
+from io import BufferedReader, BufferedWriter
+from pathlib import Path
 from typing import Literal

 import os
 import json
+import logging

 import numpy as np

+logger = logging.getLogger(__name__)


 def fill_templated_filename(filename: str, output_type: str | None) -> str:
     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -281,6 +285,83 @@ class LocalTensorRange:
     size: int

+
+def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
+    hist: dict[int, int] = {}
+    for r in ranges:
+        align_offset = r.offset % alignment
+        if align_offset not in hist:
+            hist[align_offset] = 0
+        hist[align_offset] += r.size
+
+    best_offset = 0
+    best_size = 0
+    for offset, size in hist.items():
+        if size > best_size:
+            best_size = size
+            best_offset = offset
+
+    return best_offset
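best_alignment_offset picks the intra-block offset covering the most bytes: when a tensor is assembled from several source ranges that disagree on offset % alignment, only the majority (by size) can be block-shared. An illustrative check (file names and numbers made up, again assuming LocalTensorRange(filename, offset, size)):

from pathlib import Path

from gguf.utility import LocalTensorRange, best_alignment_offset

ranges = (
    LocalTensorRange(Path("a.safetensors"), 4096 + 512, 1 << 20),   # phase 512, 1 MiB
    LocalTensorRange(Path("a.safetensors"), 65536 + 512, 1 << 20),  # phase 512, 1 MiB
    LocalTensorRange(Path("b.safetensors"), 8192, 4096),            # phase 0, 4 KiB
)
assert best_alignment_offset(ranges, 4096) == 512  # 2 MiB outweighs 4 KiB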
+
+# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
+# to make it more likely that copy-on-write is used where possible.
+# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
+# (assuming this is only called where os.copy_file_range is present)
+def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
+    assert len(ranges) > 0
+    dst_offset = fout.tell()
+    assert dst_offset % alignment == 0, dst_offset % alignment
+    align_offset = best_alignment_offset(ranges, alignment)
+    if len(ranges) == 1:
+        r = ranges[0]
+        with open(r.filename, "rb") as src:
+            offset_src = r.offset - align_offset
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+        dst_offset += r.size + align_offset
+    else:
+        # All ranges need to have the same alignment offset.
+        # Non-consecutive ranges need a patch block in between when the alignment offset is non-zero.
+        src_files: dict[Path, BufferedReader] = {}
+        for r in ranges:
+            if r.filename not in src_files:
+                src_files[r.filename] = open(r.filename, "rb")
+
+        for i, r in enumerate(ranges):
+            this_align_offset = r.offset % alignment
+            src = src_files[r.filename]
+            if this_align_offset != align_offset:
+                logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
+
+            if i > 0 and dst_offset % alignment != 0:
+                # Write the correct data between blocks even when they are non-consecutive
+                extra_size = alignment - (dst_offset % alignment)
+                src.seek(r.offset)
+                buf = src.read(extra_size)
+                fout.seek(dst_offset)
+                fout.write(buf)
+                dst_offset += extra_size
+                assert dst_offset % alignment == 0, dst_offset % alignment
+                offset_src = r.offset + extra_size
+            else:
+                # TODO: is this always correct?
+                offset_src = r.offset - align_offset
+
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+            dst_offset += r.size
+
+        for f in src_files.values():
+            f.close()
+
+    fout.seek(dst_offset)
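A hedged sketch of the call site, mirroring write_tensors_to_file above: the writer positions fout at a block-aligned offset, lets copy_tensor_ranges splice in the source blocks, then pads up to the next boundary (file name and offset made up):

with open("model.gguf", "r+b") as fout:
    fout.seek(24576)  # must be a multiple of the alignment (here 6 * 4096)
    copy_tensor_ranges(fout, ranges, alignment=4096)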

 @dataclass
 class LocalTensor:
     dtype: str