convert : use reflinks for faster conversion
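The idea: when the source tensors already live in local files (e.g. safetensors shards) on a copy-on-write filesystem, the writer can ask the kernel to share the underlying blocks instead of rewriting them. A rough standalone sketch of the underlying call, not part of this commit (the helper name, paths and offsets are made up; a real caller would loop, since copy_file_range may copy fewer bytes than requested):

    import os

    # Clone `size` bytes from src into dst at the given offsets.
    # On copy-on-write filesystems (BTRFS, XFS, ...) a block-aligned
    # copy_file_range can be satisfied with reflinks, so the bytes are
    # not physically duplicated.
    def clone_range(src_path: str, dst_path: str, offset_src: int, offset_dst: int, size: int) -> None:
        with open(src_path, "rb") as src, open(dst_path, "r+b") as dst:
            os.copy_file_range(src.fileno(), dst.fileno(), size, offset_src, offset_dst)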
@@ -30,6 +30,7 @@ from .constants import (
 )
 
 from .quants import quant_shape_from_byte_shape
+from .utility import LocalTensorRange, best_alignment_offset, copy_tensor_ranges
 
 logger = logging.getLogger(__name__)
 
@@ -84,14 +85,16 @@ class GGUFWriter:
 
     def __init__(
        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        use_reflinks = False,  # opportunistically attempt to use copy-on-write
     ):
         self.fout = None
         self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.use_temp_file = use_temp_file
+        self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
+        self.use_temp_file = use_temp_file if not self.use_reflinks else False
         self.temp_file = None
         self.tensors = [{}]
         self.kv_data = [{}]
@@ -107,6 +110,10 @@ class GGUFWriter:
         if self.small_first_shard:
             self.tensors.append({})
 
+        if self.use_reflinks:
+            # common default block size for COW filesystems
+            self.add_custom_alignment(4096)
+
         self.add_architecture()
 
     def get_total_parameter_count(self) -> tuple[int, int, int, int]:
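Reflinks can only share whole filesystem blocks, which is why the data section is forced to a 4096-byte alignment here. If the alignment were instead derived from the destination filesystem, a sketch could look like this (hypothetical helper, not in the commit; it assumes statvfs' f_bsize reflects the block size):

    import os

    def cow_alignment(path: str, fallback: int = 4096) -> int:
        # block size of the filesystem holding `path`
        try:
            bsize = os.statvfs(path).f_bsize
        except OSError:
            return fallback
        # only trust sane power-of-two block sizes, and never go below 4096
        return bsize if bsize >= 4096 and (bsize & (bsize - 1)) == 0 else fallback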
@@ -257,14 +264,20 @@ class GGUFWriter:
             offset_tensor = 0
 
             for name, ti in tensors.items():
+                align_offset = 0
+                if self.use_reflinks:
+                    ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
+                    if len(ranges) > 0:
+                        align_offset = best_alignment_offset(ranges, self.data_alignment)
+
                 ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
                 n_dims = len(ti.shape)
                 ti_data += self._pack("I", n_dims)
                 for j in range(n_dims):
                     ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
                 ti_data += self._pack("I", ti.dtype)
-                ti_data += self._pack("Q", offset_tensor)
-                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+                ti_data += self._pack("Q", offset_tensor + align_offset)
+                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
 
             fout.write(ti_data)
             fout.flush()
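Worked example of the new offset bookkeeping (made-up numbers, standalone; ggml_pad mirrors GGUFWriter.ggml_pad): a tensor whose source bytes start 1234 bytes past a 4096-byte block boundary keeps that same in-block offset in the output file, and the running offset is padded accordingly.

    alignment = 4096
    align_offset = 1234            # best_alignment_offset(...) result for this tensor
    nbytes = 10 * 1024 * 1024      # tensor size in bytes

    def ggml_pad(x: int, n: int) -> int:
        return ((x + n - 1) // n) * n

    offset_tensor = 0
    stored_offset = offset_tensor + align_offset                   # written into the tensor info
    offset_tensor += ggml_pad(nbytes + align_offset, alignment)    # start of the next tensor

    print(stored_offset, offset_tensor)  # 1234 10489856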
@@ -398,6 +411,7 @@ class GGUFWriter:
         if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
             raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
         assert self.fout is not None
+        assert not self.use_reflinks  # TODO: handle this here too
 
         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
@@ -450,15 +464,21 @@ class GGUFWriter:
                     shard_bar.reset(total=(total if total > 0 else None))
 
                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
+                for name, ti in tensors.items():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
+                    if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
+                        logger.debug(f"using reflinks for {name}")
+                        start_offset = fout.tell()
+                        copy_tensor_ranges(fout, ranges, self.data_alignment)
+                        self.write_padding(fout, fout.tell() - start_offset)
+                    else:
+                        ti.tensor.tofile(fout)
+                        self.write_padding(fout, ti.nbytes)
                     if shard_bar is not None:
                         shard_bar.update(ti.nbytes)
                     if bar is not None:
                         bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
             self.temp_file.seek(0)
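Roughly how the new flag is meant to be used from a conversion script (illustrative only; the tensor name and data are made up, and a plain in-memory array like this one is still written normally, since only lazy tensors that carry source-file ranges can be reflinked):

    import numpy as np
    from gguf import GGUFWriter

    writer = GGUFWriter("model-out.gguf", arch="llama", use_reflinks=True)
    # the flag is only kept when os.copy_file_range exists; on filesystems
    # without copy-on-write support the kernel simply performs a normal copy
    writer.add_tensor("tok_embd.weight", np.zeros((32, 8), dtype=np.float32))
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()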
@@ -6,6 +6,7 @@ from typing import Any, Callable
 
 import numpy as np
 from numpy.typing import DTypeLike
+from .utility import LocalTensorRange
 
 
 logger = logging.getLogger(__name__)
@@ -20,10 +21,11 @@ class LazyMeta(ABCMeta):
                 return type(self)._wrap_fn(
                     (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                     use_self=self,
+                    data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
                 )
             elif isinstance(meta_attr, self._tensor_type):
                 # e.g. self.T with torch.Tensor should still be wrapped
-                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
+                return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
             else:
                 # no need to wrap non-tensor properties,
                 # and they likely don't depend on the actual contents of the tensor
@@ -39,8 +41,9 @@ class LazyMeta(ABCMeta):
            def wrapped_special_op(self, *args, **kwargs):
                return type(self)._wrap_fn(
                    getattr(type(self)._tensor_type, op_name),
+                    use_self=self,
                    meta_noop=meta_noop,
-                )(self, *args, **kwargs)
+                )(*args, **kwargs)
            return wrapped_special_op
 
        # special methods bypass __getattr__, so they need to be added manually
@@ -76,14 +79,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _args: tuple
     _kwargs: dict[str, Any]
     _func: Callable[[Any], Any] | None
+    _ranges: tuple[LocalTensorRange, ...]
 
-    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
         super().__init__()
         self._meta = meta
         self._data = data
         self._args = args
         self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
+        self._ranges = ranges
         assert self._func is not None or self._data is not None
 
     def __init_subclass__(cls) -> None:
@@ -107,7 +112,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
         return o
 
     @classmethod
-    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
+    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
@@ -116,6 +121,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
             # TODO: maybe handle tensors in kwargs too
 
+            ranges = use_self._ranges if use_self is not None and data_noop else ()
+
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
                     res = fn(*meta_args, **kwargs)
@@ -138,7 +145,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                     res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
             elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
                 # share the evaluation between lazy tuple elements
                 shared_args: list = [args, None]
@@ -214,7 +221,8 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
+        ranges = self._ranges if self._meta.dtype == dtype else ()
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
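The lazy-tensor changes above only thread the source ranges through operations that do not touch the bytes. Restated as a toy example that mirrors the rule without using the real LazyNumpyTensor (class name and range tuples are made up): view/reshape/squeeze/unsqueeze keep the ranges, and astype keeps them only when the dtype is unchanged.

    from dataclasses import dataclass

    @dataclass
    class ToyLazy:
        dtype: str
        ranges: tuple = ()              # stand-in for LazyBase._ranges

        def reshape(self, *shape):
            # data_noop: the bytes are untouched, so the source ranges survive
            return ToyLazy(self.dtype, self.ranges)

        def astype(self, dtype):
            # converting dtypes produces new bytes, so nothing can be reflinked
            return ToyLazy(dtype, self.ranges if dtype == self.dtype else ())

    t = ToyLazy("float32", (("model-00001.safetensors", 4096, 1024),))
    assert t.reshape(32, 8).ranges == t.ranges
    assert t.astype("float32").ranges == t.ranges
    assert t.astype("float16").ranges == ()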
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from io import BufferedReader, BufferedWriter
 from pathlib import Path
 from typing import Literal
 
 import os
 import json
+import logging
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -281,6 +285,83 @@ class LocalTensorRange:
     size: int
 
 
+def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
+    hist: dict[int, int] = {}
+
+    for r in ranges:
+        align_offset = r.offset % alignment
+        if align_offset not in hist:
+            hist[align_offset] = 0
+        hist[align_offset] += r.size
+
+    best_offset = 0
+    best_size = 0
+    for offset, size in hist.items():
+        if size > best_size:
+            best_size = size
+            best_offset = offset
+    return best_offset
+
+# (assuming this is only called where os.copy_file_range is present)
+#
+# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
+# to make it more likely that copy-on-write is used where possible.
+# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
+def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
+    assert len(ranges) > 0
+    dst_offset = fout.tell()
+    assert dst_offset % alignment == 0, dst_offset % alignment
+    align_offset = best_alignment_offset(ranges, alignment)
+    if len(ranges) == 1:
+        r = ranges[0]
+        with open(r.filename, "rb") as src:
+            offset_src = r.offset - align_offset
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+            dst_offset += r.size + align_offset
+    else:
+        # All ranges need to have the same alignment offset
+        # Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
+        src_files: dict[Path, BufferedReader] = {}
+        for r in ranges:
+            if r.filename not in src_files:
+                src_files[r.filename] = open(r.filename, "rb")
+
+        for i, r in enumerate(ranges):
+            this_align_offset = r.offset % alignment
+            src = src_files[r.filename]
+            if this_align_offset != align_offset:
+                logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
+            if i > 0 and dst_offset % alignment != 0:
+                # Write the correct data between blocks even when they are non-consecutive
+                extra_size = alignment - (dst_offset % alignment)
+                src.seek(r.offset)
+                buf = src.read(extra_size)
+                fout.seek(dst_offset)
+                fout.write(buf)
+                dst_offset += extra_size
+                assert dst_offset % alignment == 0, dst_offset % alignment
+                offset_src = r.offset + extra_size
+            else:
+                # TODO: is this always correct?
+                offset_src = r.offset - align_offset
+
+            offset_src_end = r.offset + r.size
+            if offset_src_end % alignment != 0:
+                offset_src_end += alignment - (offset_src_end % alignment)
+            size = offset_src_end - offset_src
+            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+            dst_offset += r.size
+
+        for f in src_files.values():
+            f.close()
+
+    fout.seek(dst_offset)
+
+
 @dataclass
 class LocalTensor:
     dtype: str
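A small worked example of what best_alignment_offset computes, re-implemented standalone with made-up (offset, size) pairs instead of LocalTensorRange objects: the in-block offset backed by the most bytes wins, so the two large ranges below can be reflinked as-is and only the misaligned 512-byte one needs the patch-block path in copy_tensor_ranges.

    alignment = 4096
    ranges = [
        (8192, 2 * 1024 * 1024),       # 8192    % 4096 == 0
        (2105344, 2 * 1024 * 1024),    # 2105344 % 4096 == 0
        (4210689, 512),                # 4210689 % 4096 == 1
    ]

    hist: dict[int, int] = {}
    for offset, size in ranges:
        hist[offset % alignment] = hist.get(offset % alignment, 0) + size

    best = max(hist, key=hist.get)     # offset-within-block covering the most bytes
    print(best, hist)                  # 0 {0: 4194304, 1: 512}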