From 6792f66a9329e50c730559f616aae71b034e0087 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Thu, 4 Sep 2025 17:40:11 -0400
Subject: [PATCH] convert : detect filesystem block size for reflinks

* convert : use direct copies when possible

Using os.copy_file_range where available, and falling back to
shutil.copyfileobj otherwise.

* gguf : handle misaligned offset more cleanly
---
 convert_hf_to_gguf.py       |   6 +-
 ggml/src/gguf.cpp           |  10 +--
 gguf-py/gguf/gguf_writer.py |  40 ++++-----
 gguf-py/gguf/lazy.py        |  12 ++-
 gguf-py/gguf/utility.py     | 168 ++++++++++++++++++++----------
 5 files changed, 132 insertions(+), 104 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 66ef7b591b..f14eef1452 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -80,6 +80,7 @@ class ModelBase:
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
+    use_reflinks: bool
     lazy: bool
     dry_run: bool
     hparams: dict[str, Any]
@@ -119,6 +120,7 @@ class ModelBase:
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
+        self.use_reflinks = use_reflinks
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
@@ -133,7 +135,7 @@ class ModelBase:
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
-                                           use_reflinks=use_reflinks)
+                                           use_reflinks=self.use_reflinks)
 
         # Mistral specific
         self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -202,7 +204,7 @@ class ModelBase:
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name, reflink=self.use_reflinks))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index 9673bf78ba..167dce3f2a 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -624,16 +624,16 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
     ctx->size = 0;
     for (size_t i = 0; i < ctx->info.size(); ++i) {
         const gguf_tensor_info & ti = ctx->info[i];
-        // HACK: bypass the continuity check
-        ctx->size = ti.offset;
-        if (ti.offset != ctx->size) {
+        // alignment offset is only necessary for GGUF converted with reflinks
+        const size_t align_offset = ti.offset % ctx->alignment;
+        if (ti.offset - align_offset != ctx->size) {
             GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
-                __func__, ti.t.name, ti.offset, ctx->size);
+                __func__, ti.t.name, ti.offset, ctx->size + align_offset);
             GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
             gguf_free(ctx);
             return nullptr;
         }
-        size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+        size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t) + align_offset, ctx->alignment);
         if (SIZE_MAX - ctx->size < padded_size) {
             GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
                 __func__, ti.t.name, ctx->size, padded_size);
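Note on the ggml/src/gguf.cpp hunk: instead of bypassing the continuity check, the loader now tolerates tensors that start align_offset bytes past a block-aligned position, which is exactly what reflink-converted files contain. An illustrative Python restatement of the loop (the field names mirror the C++ struct; this is not a real gguf-py API):

```python
def check_tensor_data_section(tensor_infos, alignment: int) -> int:
    # A tensor may sit align_offset bytes past an aligned boundary; both the
    # continuity check and the padded size account for that slack.
    size = 0  # running size of the tensor data section (ctx->size)
    for ti in tensor_infos:  # items with .name, .offset, .nbytes
        align_offset = ti.offset % alignment
        if ti.offset - align_offset != size:
            raise ValueError(f"tensor {ti.name!r} has offset {ti.offset}, expected {size + align_offset}")
        # GGML_PAD(nbytes + align_offset, alignment)
        size += (ti.nbytes + align_offset + alignment - 1) // alignment * alignment
    return size
```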
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 075b381c59..5258fa868c 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -30,7 +30,7 @@ from .constants import (
 )
 
 from .quants import quant_shape_from_byte_shape
-from .utility import LocalTensorRange, best_alignment_offset, reflink_tensor_ranges
+from .utility import LocalTensorRange, best_extra_offset
 
 logger = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ class GGUFWriter:
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
         self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
-        self.use_temp_file = use_temp_file if not self.use_reflinks else False
+        self.use_temp_file = False if self.use_reflinks else use_temp_file
         self.temp_file = None
         self.tensors = [{}]
         self.kv_data = [{}]
@@ -110,10 +110,6 @@ class GGUFWriter:
         if self.small_first_shard:
             self.tensors.append({})
 
-        if self.use_reflinks:
-            # common default block size for COW filesystems
-            self.add_custom_alignment(4096)
-
         self.add_architecture()
 
     def get_total_parameter_count(self) -> tuple[int, int, int, int]:
@@ -185,6 +181,15 @@ class GGUFWriter:
         self.fout = [open(filename, "wb") for filename in filenames]
         self.state = WriterState.EMPTY
 
+        if self.use_reflinks:
+            # reflinks require alignment to the filesystem blocks
+            block_size = os.stat(self.path.parent).st_blksize
+            # necessary to get an appropriate data start offset
+            # when padding for reflinks;
+            # using the real alignment (8 bytes, from safetensors)
+            # would result in an unusable base data offset
+            self.add_custom_alignment(block_size)
+
     def print_plan(self) -> list[Path]:
         logger.info("Writing the following files:")
         assert self.path is not None
@@ -264,11 +269,11 @@ class GGUFWriter:
             offset_tensor = 0
 
             for name, ti in tensors.items():
-                align_offset = 0
+                extra_offset = 0
                 if self.use_reflinks:
                     ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
                     if len(ranges) > 0:
-                        align_offset = best_alignment_offset(ranges, self.data_alignment)
+                        extra_offset = best_extra_offset(ranges, offset_tensor)
 
                 ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
                 n_dims = len(ti.shape)
@@ -276,8 +281,8 @@ class GGUFWriter:
                 for j in range(n_dims):
                     ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
                 ti_data += self._pack("I", ti.dtype)
-                ti_data += self._pack("Q", offset_tensor + align_offset)
-                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
+                ti_data += self._pack("Q", offset_tensor + extra_offset)
+                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + extra_offset, self.data_alignment)
 
             fout.write(ti_data)
             fout.flush()
@@ -405,13 +410,12 @@ class GGUFWriter:
     def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
-            fp.write(bytes([0] * pad))
+            fp.write(b"\x00" * pad)
 
     def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
         if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
             raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
         assert self.fout is not None
-        assert not self.use_reflinks  # TODO: handle this here too
 
         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
@@ -432,7 +436,7 @@ class GGUFWriter:
         self.write_padding(fout, fout.tell())
         tensor.tofile(fout)
-        self.write_padding(fout, tensor.nbytes)
+        self.write_padding(fout, fout.tell())
 
         self.state = WriterState.WEIGHTS
@@ -467,18 +471,14 @@ class GGUFWriter:
                 for name, ti in tensors.items():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
+                    if self.use_reflinks and len(getattr(ti.tensor, "_ranges", ())) > 0:
                         logger.debug(f"using reflinks for {name}")
-                        start_offset = fout.tell()
-                        reflink_tensor_ranges(fout, ranges, self.data_alignment)
-                        self.write_padding(fout, fout.tell() - start_offset)
-                    else:
-                        ti.tensor.tofile(fout)
-                        self.write_padding(fout, ti.nbytes)
+                    ti.tensor.tofile(fout)
                     if shard_bar is not None:
                         shard_bar.update(ti.nbytes)
                     if bar is not None:
                         bar.update(ti.nbytes)
+                    self.write_padding(fout, fout.tell())
                     ti.tensor = None
         else:
             self.temp_file.seek(0)
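On the writer side, the output filesystem's preferred block size becomes the GGUF alignment, and each reflinked tensor is recorded extra_offset bytes past its block-aligned slot. A small worked example of the bookkeeping above, with hypothetical numbers (4096-byte blocks, a 1000-byte tensor whose source data sits 8 bytes past a block boundary):

```python
import os

# what the writer now queries when opening its output files (POSIX-only attribute)
block_size = os.stat(".").st_blksize

def ggml_pad(x: int, n: int) -> int:
    # same rounding as GGUFWriter.ggml_pad
    return (x + n - 1) // n * n

offset_tensor = 0  # running offset inside the tensor data section
nbytes = 1000      # hypothetical tensor size
extra_offset = 8   # hypothetical result of best_extra_offset for its source ranges

stored_offset = offset_tensor + extra_offset            # packed into the tensor info: 8
offset_tensor += ggml_pad(nbytes + extra_offset, 4096)  # next tensor slot starts at 4096
```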
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index 70ffb8d3b7..91214191a8 100644
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 from abc import ABC, ABCMeta, abstractmethod
+from io import BufferedWriter
 import logging
 from typing import Any, Callable
 
 import numpy as np
 from numpy.typing import DTypeLike
 
-from .utility import LocalTensorRange
+from .utility import LocalTensorRange, copy_tensor_ranges
 
 logger = logging.getLogger(__name__)
@@ -224,8 +225,11 @@ class LazyNumpyTensor(LazyBase):
         ranges = self._ranges if self._meta.dtype == dtype else ()
         return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
 
-    def tofile(self, *args, **kwargs):
-        eager = LazyNumpyTensor.to_eager(self)
-        return eager.tofile(*args, **kwargs)
+    def tofile(self, fid, *args, **kwargs):
+        if isinstance(fid, BufferedWriter) and len(self._ranges) > 0:
+            return copy_tensor_ranges(fid, self._ranges)
+        else:
+            eager = LazyNumpyTensor.to_eager(self)
+            return eager.tofile(fid, *args, **kwargs)
 
     # TODO: __array_function__
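Two details in the lazy.py hunk are worth spelling out: astype only propagates _ranges when the dtype is unchanged (converted bytes no longer match the source file), and the isinstance(fid, BufferedWriter) test is what routes a write through copy_tensor_ranges instead of materializing the tensor. A sketch, where `lazy` stands for a LazyNumpyTensor produced by SafetensorsLocal with a float16 source:

```python
import numpy as np

same = lazy.astype(np.float16)  # dtype unchanged: _ranges survive
conv = lazy.astype(np.float32)  # real conversion: _ranges == ()

with open("out.bin", "wb") as f:  # BufferedWriter: zero-copy path
    same.tofile(f)

conv.tofile("conv.bin")  # plain path argument: eager fallback either way
```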
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index 90bd4d48b7..78a621e976 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -7,6 +7,7 @@ from typing import Literal
 
 import os
 import json
+import shutil
 import logging
 
 import numpy as np
@@ -281,99 +282,116 @@ class SafetensorRemote:
 
 @dataclass
 class LocalTensorRange:
     filename: Path
+    block_size: int
     offset: int
     size: int
 
 
-def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
+def best_extra_offset(ranges: tuple[LocalTensorRange, ...], current_offset: int) -> int:
     hist: dict[int, int] = {}
+    max_block_size = 0
     for r in ranges:
-        align_offset = r.offset % alignment
-        if align_offset not in hist:
-            hist[align_offset] = 0
-        hist[align_offset] += r.size
+        # Ensure minimal alignment is 8 bytes (common with safetensors)
+        # and that the block size is valid
+        if r.offset % 8 == 0 and r.block_size > 0:
+            align_offset = r.offset % r.block_size
+            if align_offset not in hist:
+                hist[align_offset] = 0
+            hist[align_offset] += r.size
+            if r.block_size > max_block_size:
+                max_block_size = r.block_size
 
     best_offset = 0
     best_size = 0
     for offset, size in hist.items():
-        # Ensure minimal alignment is 8-bytes (common with safetensors)
-        if size > best_size and offset % 8 == 0:
+        if size > best_size:
             best_size = size
             best_offset = offset
+
+    if max_block_size > 0:
+        # the offset needs to be aligned properly
+        # or else there's probably a block size mismatch
+        assert current_offset % max_block_size == 0, current_offset % max_block_size
+
     return best_offset
 
 
-# (assuming this is only called where os.copy_file_range is present)
-#
 # Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
 # to make it more likely that copy-on-write is used where possible.
 # Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
-def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
+#
+# Falls back to shutil.copyfileobj when os.copy_file_range is not present.
+def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...]):
     assert len(ranges) > 0
     dst_offset = fout.tell()
-    assert dst_offset % alignment == 0, dst_offset % alignment
-    align_offset = best_alignment_offset(ranges, alignment)
 
-    if len(ranges) == 1:
-        r = ranges[0]
-        with open(r.filename, "rb") as src:
-            offset_src = r.offset - align_offset
-            offset_src_end = r.offset + r.size
-            if offset_src_end % alignment != 0:
-                offset_src_end += alignment - (offset_src_end % alignment)
-            size = offset_src_end - offset_src
-            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
-            dst_offset += r.size + align_offset
-    else:
-        # All ranges need to have the same alignment offset
-        # Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
-        src_files: dict[Path, BufferedReader] = {}
-        for r in ranges:
-            if r.filename not in src_files:
-                src_files[r.filename] = open(r.filename, "rb")
+    extra_offset = best_extra_offset(ranges, dst_offset)
 
-        for i, r in enumerate(ranges):
-            this_align_offset = r.offset % alignment
-            src = src_files[r.filename]
-            if this_align_offset != align_offset:
-                logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
-                # relying on os.copy_file_range to fallback to a non-aligned copy
+    if extra_offset > 0:
+        # initial padding
+        fout.write(b"\x00" * extra_offset)
 
-            # Block 0, 1, 2, 3, 4,
-            # |___0000|0000000|0001111|1111111|111____|
-            #
-            # 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
-            # 2. block 2 is partially overwritten with contents from range[1]
-            # 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
-            #
-            # (2 and 3 are repeated with further blocks if there are more ranges)
-            if i == 0:
-                extra_size = -align_offset
-            elif dst_offset % alignment == 0:
-                extra_size = 0
-            else:
-                extra_size = alignment - (dst_offset % alignment)
-                extra_size = min(extra_size, r.size)
-                src.seek(r.offset)
-                buf = src.read(extra_size)
-                fout.seek(dst_offset)
-                fout.write(buf)
-                dst_offset += extra_size
-                if extra_size == r.size:
-                    continue
+    dst_offset += extra_offset
+    start_offset = dst_offset
 
-            assert dst_offset % alignment == 0, dst_offset % alignment
+    src_files: dict[Path, BufferedReader] = {}
+    for r in ranges:
+        if r.filename not in src_files:
+            src_files[r.filename] = open(r.filename, "rb")
 
-            offset_src = r.offset + extra_size
-            offset_src_end = r.offset + r.size
-            if offset_src_end % alignment != 0:
-                offset_src_end += alignment - (offset_src_end % alignment)
-            size = offset_src_end - offset_src
-            os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
-            dst_offset += r.size - extra_size
+    has_copy_file_range = hasattr(os, "copy_file_range")
 
-        for f in src_files.values():
-            f.close()
+    for i, r in enumerate(ranges):
+        src = src_files[r.filename]
+        if has_copy_file_range:
+            if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
+                # Attempting to align copies for reflinking
+
+                # Block 0, 1, 2, 3, 4,
+                # |___0000|0000000|0001111|1111111|111____|
+                #
+                # 1. block 0 is partially overwritten with contents from range[0]
+                # 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
+                # 3. block 2 is partially overwritten with contents from range[1]
+                # 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
+                #    (repeated for further ranges)
+                if dst_offset % r.block_size == 0:
+                    extra_size = 0
+                else:
+                    extra_size = r.block_size - (dst_offset % r.block_size)
+                    extra_size = min(extra_size, r.size)
+                src.seek(r.offset)
+                buf = src.read(extra_size)
+                fout.seek(dst_offset)
+                fout.write(buf)
+                dst_offset += extra_size
+                if extra_size == r.size:
+                    continue
+
+                assert dst_offset % r.block_size == 0, dst_offset % r.block_size
+
+                offset_src = r.offset + extra_size
+                offset_src_end = r.offset + r.size
+                if offset_src_end % r.block_size != 0:
+                    offset_src_end += r.block_size - (offset_src_end % r.block_size)
+                size = offset_src_end - offset_src
+                os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
+                dst_offset += r.size - extra_size
+            else:
+                if r.block_size > 0:
+                    logger.debug(f"misaligned for reflinking, falling back to copy ({i}/{len(ranges)})")
+                # not trying to use reflinks, but still using os.copy_file_range for speed
+                os.copy_file_range(src.fileno(), fout.fileno(), r.size, r.offset, dst_offset)
+                dst_offset += r.size
+        else:
+            # not using reflinks, fallback when os.copy_file_range is not supported
+            src.seek(r.offset)
+            fout.seek(dst_offset)
+            shutil.copyfileobj(src, fout, r.size)
+            dst_offset += r.size
 
+    for f in src_files.values():
+        f.close()
 
     fout.seek(dst_offset)
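A worked example of best_extra_offset: source offsets are bucketed by their remainder modulo the block size, weighted by range size, and the heaviest remainder wins. With one aligned 1 MiB range and one 4 KiB range sitting 8 bytes past a block boundary (hypothetical file, 4096-byte blocks):

```python
from pathlib import Path
from gguf.utility import LocalTensorRange, best_extra_offset

ranges = (
    LocalTensorRange(filename=Path("a.safetensors"), block_size=4096, offset=16384, size=1 << 20),
    LocalTensorRange(filename=Path("a.safetensors"), block_size=4096, offset=4104, size=4096),
)
# hist == {0: 1048576, 8: 4096}, so the aligned majority wins
assert best_extra_offset(ranges, current_offset=0) == 0
```

The losing range is the one that later takes the misaligned copy_file_range fallback in copy_tensor_ranges.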
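Reduced to a self-contained sketch (not the gguf-py code; the multi-range bookkeeping and initial padding are omitted), the per-range pattern is: patch the unaligned head with an ordinary buffered write, then hand the block-aligned remainder to os.copy_file_range. Reflinking additionally requires the source offset to share the destination's block remainder, which is what best_extra_offset arranges:

```python
import os

def aligned_copy(src_path, dst, src_off: int, size: int, block_size: int) -> None:
    # Copy `size` bytes from src_path[src_off:] to dst's current position,
    # block-aligning the bulk of the copy so the kernel may share extents.
    with open(src_path, "rb") as src:
        dst_off = dst.tell()
        head = min(-dst_off % block_size, size)  # bytes up to the next block boundary
        if head:
            src.seek(src_off)
            dst.write(src.read(head))  # unaligned head: plain buffered copy
        remaining = size - head
        if remaining > 0:
            if hasattr(os, "copy_file_range"):
                dst.flush()  # the raw-fd copy below bypasses Python's buffering
                # round the span up to whole blocks, as copy_tensor_ranges does;
                # bytes copied past `size` get overwritten by whatever comes next
                span = (remaining + block_size - 1) // block_size * block_size
                os.copy_file_range(src.fileno(), dst.fileno(), span, src_off + head, dst_off + head)
                dst.seek(dst_off + size)
            else:
                # bounded fallback loop (shutil.copyfileobj's third argument is
                # a buffer size, not a byte limit, so it is not used here)
                src.seek(src_off + head)
                left = remaining
                while left > 0:
                    chunk = src.read(min(left, 1 << 20))
                    if not chunk:
                        break
                    dst.write(chunk)
                    left -= len(chunk)
```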
@@ -399,10 +417,13 @@ class SafetensorsLocal:
 
     tensors: dict[str, LocalTensor]
 
-    def __init__(self, filename: Path):
+    def __init__(self, filename: Path, *, reflink: bool = False):
+        stat = os.stat(filename)
+        # using the preferred block size to signal whether reflinks are desired when copying
+        block_size = stat.st_blksize if reflink else -1
         with open(filename, "rb") as f:
             metadata_length = int.from_bytes(f.read(8), byteorder='little')
-            file_size = os.stat(filename).st_size
+            file_size = stat.st_size
             if file_size < 8 + metadata_length:
                 raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
@@ -427,9 +448,10 @@ class SafetensorsLocal:
                     dtype=meta["dtype"],
                     shape=tuple(meta["shape"]),
                     data_range=LocalTensorRange(
-                        filename,
-                        data_start_offset + meta["data_offsets"][0],
-                        meta["data_offsets"][1] - meta["data_offsets"][0],
+                        filename=filename,
+                        block_size=block_size,
+                        offset=data_start_offset + meta["data_offsets"][0],
+                        size=meta["data_offsets"][1] - meta["data_offsets"][0],
                     ),
                 )
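For context on the SafetensorsLocal changes: a safetensors file is an 8-byte little-endian header length, the JSON header itself, then raw tensor data, and each tensor's data_offsets are relative to the end of the header. A minimal sketch of the layout the parser relies on:

```python
import json

def read_safetensors_header(path):
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), byteorder="little")
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form entry, not a tensor
    data_start = 8 + header_len       # tensor data_offsets are relative to this
    return data_start, header
```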