convert : detect filesystem block size for reflinks

* convert : use direct copies when possible

Using os.copy_file_range where available,
and falling back to shutil.copyfileobj otherwise.

* gguf : handle misaligned offset more cleanly
This commit is contained in:
Francis Couture-Harpin
2025-09-04 17:40:11 -04:00
parent 34bd024267
commit 6792f66a93
5 changed files with 132 additions and 104 deletions

View File

@@ -80,6 +80,7 @@ class ModelBase:
is_big_endian: bool
endianess: gguf.GGUFEndian
use_temp_file: bool
use_reflinks: bool
lazy: bool
dry_run: bool
hparams: dict[str, Any]
@@ -119,6 +120,7 @@ class ModelBase:
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.use_reflinks = use_reflinks
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
@@ -133,7 +135,7 @@ class ModelBase:
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
use_reflinks=use_reflinks)
use_reflinks=self.use_reflinks)
# Mistral specific
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -202,7 +204,7 @@ class ModelBase:
logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if is_safetensors:
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name, reflink=self.use_reflinks))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

View File

@@ -624,16 +624,16 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ctx->size = 0;
for (size_t i = 0; i < ctx->info.size(); ++i) {
const gguf_tensor_info & ti = ctx->info[i];
// HACK: bypass the continuity check
ctx->size = ti.offset;
if (ti.offset != ctx->size) {
// alignment offset is only necessary for GGUF converted with reflinks
const size_t align_offset = ti.offset % ctx->alignment;
if (ti.offset - align_offset != ctx->size) {
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
__func__, ti.t.name, ti.offset, ctx->size);
__func__, ti.t.name, ti.offset, ctx->size + align_offset);
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
gguf_free(ctx);
return nullptr;
}
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t) + align_offset, ctx->alignment);
if (SIZE_MAX - ctx->size < padded_size) {
GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
__func__, ti.t.name, ctx->size, padded_size);

View File

@@ -30,7 +30,7 @@ from .constants import (
)
from .quants import quant_shape_from_byte_shape
from .utility import LocalTensorRange, best_alignment_offset, reflink_tensor_ranges
from .utility import LocalTensorRange, best_extra_offset
logger = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ class GGUFWriter:
self.endianess = endianess
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
self.use_temp_file = use_temp_file if not self.use_reflinks else False
self.use_temp_file = False if self.use_reflinks else use_temp_file
self.temp_file = None
self.tensors = [{}]
self.kv_data = [{}]
@@ -110,10 +110,6 @@ class GGUFWriter:
if self.small_first_shard:
self.tensors.append({})
if self.use_reflinks:
# common default block size for COW filesystems
self.add_custom_alignment(4096)
self.add_architecture()
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
@@ -185,6 +181,15 @@ class GGUFWriter:
self.fout = [open(filename, "wb") for filename in filenames]
self.state = WriterState.EMPTY
if self.use_reflinks:
# reflinks require alignment to the filesystem blocks
block_size = os.stat(self.path.parent).st_blksize
# necessary to get an appropriate data start offset
# when padding for reflinks;
# using the real alignment (8 bytes, from safetensors)
# would result in a unusable base data offset
self.add_custom_alignment(block_size)
def print_plan(self) -> list[Path]:
logger.info("Writing the following files:")
assert self.path is not None
@@ -264,11 +269,11 @@ class GGUFWriter:
offset_tensor = 0
for name, ti in tensors.items():
align_offset = 0
extra_offset = 0
if self.use_reflinks:
ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
if len(ranges) > 0:
align_offset = best_alignment_offset(ranges, self.data_alignment)
extra_offset = best_extra_offset(ranges, offset_tensor)
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
n_dims = len(ti.shape)
@@ -276,8 +281,8 @@ class GGUFWriter:
for j in range(n_dims):
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
ti_data += self._pack("I", ti.dtype)
ti_data += self._pack("Q", offset_tensor + align_offset)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
ti_data += self._pack("Q", offset_tensor + extra_offset)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + extra_offset, self.data_alignment)
fout.write(ti_data)
fout.flush()
@@ -405,13 +410,12 @@ class GGUFWriter:
def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0:
fp.write(bytes([0] * pad))
fp.write(b"\x00" * pad)
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
assert not self.use_reflinks # TODO: handle this here too
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
@@ -432,7 +436,7 @@ class GGUFWriter:
self.write_padding(fout, fout.tell())
tensor.tofile(fout)
self.write_padding(fout, tensor.nbytes)
self.write_padding(fout, fout.tell())
self.state = WriterState.WEIGHTS
@@ -467,18 +471,14 @@ class GGUFWriter:
for name, ti in tensors.items():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
if self.use_reflinks and len(getattr(ti.tensor, "_ranges", ())) > 0:
logger.debug(f"using reflinks for {name}")
start_offset = fout.tell()
reflink_tensor_ranges(fout, ranges, self.data_alignment)
self.write_padding(fout, fout.tell() - start_offset)
else:
ti.tensor.tofile(fout)
self.write_padding(fout, ti.nbytes)
ti.tensor.tofile(fout)
if shard_bar is not None:
shard_bar.update(ti.nbytes)
if bar is not None:
bar.update(ti.nbytes)
self.write_padding(fout, fout.tell())
ti.tensor = None
else:
self.temp_file.seek(0)

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod
from io import BufferedWriter
import logging
from typing import Any, Callable
import numpy as np
from numpy.typing import DTypeLike
from .utility import LocalTensorRange
from .utility import LocalTensorRange, copy_tensor_ranges
logger = logging.getLogger(__name__)
@@ -224,8 +225,11 @@ class LazyNumpyTensor(LazyBase):
ranges = self._ranges if self._meta.dtype == dtype else ()
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
return eager.tofile(*args, **kwargs)
def tofile(self, fid, *args, **kwargs):
if isinstance(fid, BufferedWriter) and len(self._ranges) > 0:
return copy_tensor_ranges(fid, self._ranges)
else:
eager = LazyNumpyTensor.to_eager(self)
return eager.tofile(fid, *args, **kwargs)
# TODO: __array_function__

View File

@@ -7,6 +7,7 @@ from typing import Literal
import os
import json
import shutil
import logging
import numpy as np
@@ -281,99 +282,116 @@ class SafetensorRemote:
@dataclass
class LocalTensorRange:
filename: Path
block_size: int
offset: int
size: int
def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
def best_extra_offset(ranges: tuple[LocalTensorRange, ...], current_offset: int) -> int:
hist: dict[int, int] = {}
max_block_size = 0
for r in ranges:
align_offset = r.offset % alignment
if align_offset not in hist:
hist[align_offset] = 0
hist[align_offset] += r.size
# Ensure minimal alignment is 8 bytes (common with safetensors)
# and that the block size is valid
if r.offset % 8 == 0 and r.block_size > 0:
align_offset = r.offset % r.block_size
if align_offset not in hist:
hist[align_offset] = 0
hist[align_offset] += r.size
if r.block_size > max_block_size:
max_block_size = r.block_size
best_offset = 0
best_size = 0
for offset, size in hist.items():
# Ensure minimal alignment is 8-bytes (common with safetensors)
if size > best_size and offset % 8 == 0:
if size > best_size:
best_size = size
best_offset = offset
if max_block_size > 0:
# the offset needs to be aligned properly
# or else there's probably a block size mismatch
assert current_offset % max_block_size == 0, current_offset % max_block_size
return best_offset
# (assuming this is only called where os.copy_file_range is present)
#
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
# to make it more likely that copy-on-write is used where possible.
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
#
# Falls back to shutil.copyfileobj when os.copy_file_range is not present.
def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...]):
assert len(ranges) > 0
dst_offset = fout.tell()
assert dst_offset % alignment == 0, dst_offset % alignment
align_offset = best_alignment_offset(ranges, alignment)
if len(ranges) == 1:
r = ranges[0]
with open(r.filename, "rb") as src:
offset_src = r.offset - align_offset
offset_src_end = r.offset + r.size
if offset_src_end % alignment != 0:
offset_src_end += alignment - (offset_src_end % alignment)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size + align_offset
else:
# All ranges need to have the same alignment offset
# Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
src_files: dict[Path, BufferedReader] = {}
for r in ranges:
if r.filename not in src_files:
src_files[r.filename] = open(r.filename, "rb")
extra_offset = best_extra_offset(ranges, dst_offset)
for i, r in enumerate(ranges):
this_align_offset = r.offset % alignment
src = src_files[r.filename]
if this_align_offset != align_offset:
logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
# relying on os.copy_file_range to fallback to a non-aligned copy
if extra_offset > 0:
# initial padding
fout.write(b"\x00" * extra_offset)
# Block 0, 1, 2, 3, 4,
# |___0000|0000000|0001111|1111111|111____|
#
# 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
# 2. block 2 is partially overwritten with contents from range[1]
# 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
#
# (2 and 3 are repeated with further blocks if there are more ranges)
if i == 0:
extra_size = -align_offset
elif dst_offset % alignment == 0:
extra_size = 0
dst_offset += extra_offset
start_offset = dst_offset
src_files: dict[Path, BufferedReader] = {}
for r in ranges:
if r.filename not in src_files:
src_files[r.filename] = open(r.filename, "rb")
has_copy_file_range = hasattr(os, "copy_file_range")
for i, r in enumerate(ranges):
src = src_files[r.filename]
if has_copy_file_range:
if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
# Attempting to align copies for reflinking
# Block 0, 1, 2, 3, 4,
# |___0000|0000000|0001111|1111111|111____|
#
# 1. block 0 is partially overwritten with contents from range[0]
# 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
# 3. block 2 is partially overwritten with contents from range[1]
# 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
# (repeated for further ranges)
if dst_offset % r.block_size == 0:
extra_size = 0
else:
extra_size = r.block_size - (dst_offset % r.block_size)
extra_size = min(extra_size, r.size)
src.seek(r.offset)
buf = src.read(extra_size)
fout.seek(dst_offset)
fout.write(buf)
dst_offset += extra_size
if extra_size == r.size:
continue
assert dst_offset % r.block_size == 0, dst_offset % r.block_size
offset_src = r.offset + extra_size
offset_src_end = r.offset + r.size
if offset_src_end % r.block_size != 0:
offset_src_end += r.block_size - (offset_src_end % r.block_size)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size - extra_size
else:
extra_size = alignment - (dst_offset % alignment)
extra_size = min(extra_size, r.size)
src.seek(r.offset)
buf = src.read(extra_size)
fout.seek(dst_offset)
fout.write(buf)
dst_offset += extra_size
if extra_size == r.size:
continue
if r.block_size > 0:
logger.debug(f"misaligned for reflinking, falling back to copy ({i}/{len(ranges)})")
# not trying to use reflinks, but still using os.copy_file_range for speed
os.copy_file_range(src.fileno(), fout.fileno(), r.size, r.offset, dst_offset)
dst_offset += r.size
else:
# not using reflinks, fallback when os.copy_file_range is not supported
src.seek(r.offset)
fout.seek(dst_offset)
shutil.copyfileobj(src, fout, r.size)
dst_offset += r.size
assert dst_offset % alignment == 0, dst_offset % alignment
offset_src = r.offset + extra_size
offset_src_end = r.offset + r.size
if offset_src_end % alignment != 0:
offset_src_end += alignment - (offset_src_end % alignment)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size - extra_size
for f in src_files.values():
f.close()
for f in src_files.values():
f.close()
fout.seek(dst_offset)
@@ -399,10 +417,13 @@ class SafetensorsLocal:
tensors: dict[str, LocalTensor]
def __init__(self, filename: Path):
def __init__(self, filename: Path, *, reflink: bool = False):
stat = os.stat(filename)
# using the preferred block size to signal whether reflinks are desired when copying
block_size = stat.st_blksize if reflink else -1
with open(filename, "rb") as f:
metadata_length = int.from_bytes(f.read(8), byteorder='little')
file_size = os.stat(filename).st_size
file_size = stat.st_size
if file_size < 8 + metadata_length:
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
@@ -427,9 +448,10 @@ class SafetensorsLocal:
dtype=meta["dtype"],
shape=tuple(meta["shape"]),
data_range=LocalTensorRange(
filename,
data_start_offset + meta["data_offsets"][0],
meta["data_offsets"][1] - meta["data_offsets"][0],
filename=filename,
block_size=block_size,
offset=data_start_offset + meta["data_offsets"][0],
size=meta["data_offsets"][1] - meta["data_offsets"][0],
),
)