mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-30 08:42:00 +00:00
convert : detect filesystem block size for reflinks
* convert : use direct copies when possible Using os.copy_file_range where available, and falling back to shutil.copyfileobj otherwise. * gguf : handle misaligned offset more cleanly
This commit is contained in:
@@ -80,6 +80,7 @@ class ModelBase:
|
|||||||
is_big_endian: bool
|
is_big_endian: bool
|
||||||
endianess: gguf.GGUFEndian
|
endianess: gguf.GGUFEndian
|
||||||
use_temp_file: bool
|
use_temp_file: bool
|
||||||
|
use_reflinks: bool
|
||||||
lazy: bool
|
lazy: bool
|
||||||
dry_run: bool
|
dry_run: bool
|
||||||
hparams: dict[str, Any]
|
hparams: dict[str, Any]
|
||||||
@@ -119,6 +120,7 @@ class ModelBase:
|
|||||||
self.is_big_endian = is_big_endian
|
self.is_big_endian = is_big_endian
|
||||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||||
self.use_temp_file = use_temp_file
|
self.use_temp_file = use_temp_file
|
||||||
|
self.use_reflinks = use_reflinks
|
||||||
self.lazy = not eager or (remote_hf_model_id is not None)
|
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
self.remote_hf_model_id = remote_hf_model_id
|
self.remote_hf_model_id = remote_hf_model_id
|
||||||
@@ -133,7 +135,7 @@ class ModelBase:
|
|||||||
# Configure GGUF Writer
|
# Configure GGUF Writer
|
||||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
|
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
|
||||||
use_reflinks=use_reflinks)
|
use_reflinks=self.use_reflinks)
|
||||||
|
|
||||||
# Mistral specific
|
# Mistral specific
|
||||||
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
|
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
|
||||||
@@ -202,7 +204,7 @@ class ModelBase:
|
|||||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||||
ctx: ContextManager[Any]
|
ctx: ContextManager[Any]
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
|
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name, reflink=self.use_reflinks))
|
||||||
else:
|
else:
|
||||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||||
|
|
||||||
|
|||||||
@@ -624,16 +624,16 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
|||||||
ctx->size = 0;
|
ctx->size = 0;
|
||||||
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
||||||
const gguf_tensor_info & ti = ctx->info[i];
|
const gguf_tensor_info & ti = ctx->info[i];
|
||||||
// HACK: bypass the continuity check
|
// alignment offset is only necessary for GGUF converted with reflinks
|
||||||
ctx->size = ti.offset;
|
const size_t align_offset = ti.offset % ctx->alignment;
|
||||||
if (ti.offset != ctx->size) {
|
if (ti.offset - align_offset != ctx->size) {
|
||||||
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||||
__func__, ti.t.name, ti.offset, ctx->size);
|
__func__, ti.t.name, ti.offset, ctx->size + align_offset);
|
||||||
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
||||||
gguf_free(ctx);
|
gguf_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
|
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t) + align_offset, ctx->alignment);
|
||||||
if (SIZE_MAX - ctx->size < padded_size) {
|
if (SIZE_MAX - ctx->size < padded_size) {
|
||||||
GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
|
GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
|
||||||
__func__, ti.t.name, ctx->size, padded_size);
|
__func__, ti.t.name, ctx->size, padded_size);
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from .constants import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .quants import quant_shape_from_byte_shape
|
from .quants import quant_shape_from_byte_shape
|
||||||
from .utility import LocalTensorRange, best_alignment_offset, reflink_tensor_ranges
|
from .utility import LocalTensorRange, best_extra_offset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ class GGUFWriter:
|
|||||||
self.endianess = endianess
|
self.endianess = endianess
|
||||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||||
self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
|
self.use_reflinks = use_reflinks and hasattr(os, "copy_file_range")
|
||||||
self.use_temp_file = use_temp_file if not self.use_reflinks else False
|
self.use_temp_file = False if self.use_reflinks else use_temp_file
|
||||||
self.temp_file = None
|
self.temp_file = None
|
||||||
self.tensors = [{}]
|
self.tensors = [{}]
|
||||||
self.kv_data = [{}]
|
self.kv_data = [{}]
|
||||||
@@ -110,10 +110,6 @@ class GGUFWriter:
|
|||||||
if self.small_first_shard:
|
if self.small_first_shard:
|
||||||
self.tensors.append({})
|
self.tensors.append({})
|
||||||
|
|
||||||
if self.use_reflinks:
|
|
||||||
# common default block size for COW filesystems
|
|
||||||
self.add_custom_alignment(4096)
|
|
||||||
|
|
||||||
self.add_architecture()
|
self.add_architecture()
|
||||||
|
|
||||||
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
||||||
@@ -185,6 +181,15 @@ class GGUFWriter:
|
|||||||
self.fout = [open(filename, "wb") for filename in filenames]
|
self.fout = [open(filename, "wb") for filename in filenames]
|
||||||
self.state = WriterState.EMPTY
|
self.state = WriterState.EMPTY
|
||||||
|
|
||||||
|
if self.use_reflinks:
|
||||||
|
# reflinks require alignment to the filesystem blocks
|
||||||
|
block_size = os.stat(self.path.parent).st_blksize
|
||||||
|
# necessary to get an appropriate data start offset
|
||||||
|
# when padding for reflinks;
|
||||||
|
# using the real alignment (8 bytes, from safetensors)
|
||||||
|
# would result in a unusable base data offset
|
||||||
|
self.add_custom_alignment(block_size)
|
||||||
|
|
||||||
def print_plan(self) -> list[Path]:
|
def print_plan(self) -> list[Path]:
|
||||||
logger.info("Writing the following files:")
|
logger.info("Writing the following files:")
|
||||||
assert self.path is not None
|
assert self.path is not None
|
||||||
@@ -264,11 +269,11 @@ class GGUFWriter:
|
|||||||
offset_tensor = 0
|
offset_tensor = 0
|
||||||
|
|
||||||
for name, ti in tensors.items():
|
for name, ti in tensors.items():
|
||||||
align_offset = 0
|
extra_offset = 0
|
||||||
if self.use_reflinks:
|
if self.use_reflinks:
|
||||||
ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
|
ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
|
||||||
if len(ranges) > 0:
|
if len(ranges) > 0:
|
||||||
align_offset = best_alignment_offset(ranges, self.data_alignment)
|
extra_offset = best_extra_offset(ranges, offset_tensor)
|
||||||
|
|
||||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||||
n_dims = len(ti.shape)
|
n_dims = len(ti.shape)
|
||||||
@@ -276,8 +281,8 @@ class GGUFWriter:
|
|||||||
for j in range(n_dims):
|
for j in range(n_dims):
|
||||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
|
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
|
||||||
ti_data += self._pack("I", ti.dtype)
|
ti_data += self._pack("I", ti.dtype)
|
||||||
ti_data += self._pack("Q", offset_tensor + align_offset)
|
ti_data += self._pack("Q", offset_tensor + extra_offset)
|
||||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + align_offset, self.data_alignment)
|
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + extra_offset, self.data_alignment)
|
||||||
|
|
||||||
fout.write(ti_data)
|
fout.write(ti_data)
|
||||||
fout.flush()
|
fout.flush()
|
||||||
@@ -405,13 +410,12 @@ class GGUFWriter:
|
|||||||
def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
|
def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
|
||||||
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
|
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
|
||||||
if pad != 0:
|
if pad != 0:
|
||||||
fp.write(bytes([0] * pad))
|
fp.write(b"\x00" * pad)
|
||||||
|
|
||||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||||
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
||||||
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
||||||
assert self.fout is not None
|
assert self.fout is not None
|
||||||
assert not self.use_reflinks # TODO: handle this here too
|
|
||||||
|
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if self.endianess == GGUFEndian.BIG:
|
||||||
tensor.byteswap(inplace=True)
|
tensor.byteswap(inplace=True)
|
||||||
@@ -432,7 +436,7 @@ class GGUFWriter:
|
|||||||
|
|
||||||
self.write_padding(fout, fout.tell())
|
self.write_padding(fout, fout.tell())
|
||||||
tensor.tofile(fout)
|
tensor.tofile(fout)
|
||||||
self.write_padding(fout, tensor.nbytes)
|
self.write_padding(fout, fout.tell())
|
||||||
|
|
||||||
self.state = WriterState.WEIGHTS
|
self.state = WriterState.WEIGHTS
|
||||||
|
|
||||||
@@ -467,18 +471,14 @@ class GGUFWriter:
|
|||||||
for name, ti in tensors.items():
|
for name, ti in tensors.items():
|
||||||
assert ti.tensor is not None # can only iterate once over the tensors
|
assert ti.tensor is not None # can only iterate once over the tensors
|
||||||
assert ti.tensor.nbytes == ti.nbytes
|
assert ti.tensor.nbytes == ti.nbytes
|
||||||
if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
|
if self.use_reflinks and len(getattr(ti.tensor, "_ranges", ())) > 0:
|
||||||
logger.debug(f"using reflinks for {name}")
|
logger.debug(f"using reflinks for {name}")
|
||||||
start_offset = fout.tell()
|
|
||||||
reflink_tensor_ranges(fout, ranges, self.data_alignment)
|
|
||||||
self.write_padding(fout, fout.tell() - start_offset)
|
|
||||||
else:
|
|
||||||
ti.tensor.tofile(fout)
|
ti.tensor.tofile(fout)
|
||||||
self.write_padding(fout, ti.nbytes)
|
|
||||||
if shard_bar is not None:
|
if shard_bar is not None:
|
||||||
shard_bar.update(ti.nbytes)
|
shard_bar.update(ti.nbytes)
|
||||||
if bar is not None:
|
if bar is not None:
|
||||||
bar.update(ti.nbytes)
|
bar.update(ti.nbytes)
|
||||||
|
self.write_padding(fout, fout.tell())
|
||||||
ti.tensor = None
|
ti.tensor = None
|
||||||
else:
|
else:
|
||||||
self.temp_file.seek(0)
|
self.temp_file.seek(0)
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import ABC, ABCMeta, abstractmethod
|
from abc import ABC, ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
from io import BufferedWriter
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import DTypeLike
|
from numpy.typing import DTypeLike
|
||||||
from .utility import LocalTensorRange
|
from .utility import LocalTensorRange, copy_tensor_ranges
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -224,8 +225,11 @@ class LazyNumpyTensor(LazyBase):
|
|||||||
ranges = self._ranges if self._meta.dtype == dtype else ()
|
ranges = self._ranges if self._meta.dtype == dtype else ()
|
||||||
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
|
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
|
||||||
|
|
||||||
def tofile(self, *args, **kwargs):
|
def tofile(self, fid, *args, **kwargs):
|
||||||
|
if isinstance(fid, BufferedWriter) and len(self._ranges) > 0:
|
||||||
|
return copy_tensor_ranges(fid, self._ranges)
|
||||||
|
else:
|
||||||
eager = LazyNumpyTensor.to_eager(self)
|
eager = LazyNumpyTensor.to_eager(self)
|
||||||
return eager.tofile(*args, **kwargs)
|
return eager.tofile(fid, *args, **kwargs)
|
||||||
|
|
||||||
# TODO: __array_function__
|
# TODO: __array_function__
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import shutil
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -281,78 +282,83 @@ class SafetensorRemote:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class LocalTensorRange:
|
class LocalTensorRange:
|
||||||
filename: Path
|
filename: Path
|
||||||
|
block_size: int
|
||||||
offset: int
|
offset: int
|
||||||
size: int
|
size: int
|
||||||
|
|
||||||
|
|
||||||
def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
|
def best_extra_offset(ranges: tuple[LocalTensorRange, ...], current_offset: int) -> int:
|
||||||
hist: dict[int, int] = {}
|
hist: dict[int, int] = {}
|
||||||
|
|
||||||
|
max_block_size = 0
|
||||||
for r in ranges:
|
for r in ranges:
|
||||||
align_offset = r.offset % alignment
|
# Ensure minimal alignment is 8 bytes (common with safetensors)
|
||||||
|
# and that the block size is valid
|
||||||
|
if r.offset % 8 == 0 and r.block_size > 0:
|
||||||
|
align_offset = r.offset % r.block_size
|
||||||
if align_offset not in hist:
|
if align_offset not in hist:
|
||||||
hist[align_offset] = 0
|
hist[align_offset] = 0
|
||||||
hist[align_offset] += r.size
|
hist[align_offset] += r.size
|
||||||
|
if r.block_size > max_block_size:
|
||||||
|
max_block_size = r.block_size
|
||||||
|
|
||||||
best_offset = 0
|
best_offset = 0
|
||||||
best_size = 0
|
best_size = 0
|
||||||
for offset, size in hist.items():
|
for offset, size in hist.items():
|
||||||
# Ensure minimal alignment is 8-bytes (common with safetensors)
|
if size > best_size:
|
||||||
if size > best_size and offset % 8 == 0:
|
|
||||||
best_size = size
|
best_size = size
|
||||||
best_offset = offset
|
best_offset = offset
|
||||||
|
|
||||||
|
if max_block_size > 0:
|
||||||
|
# the offset needs to be aligned properly
|
||||||
|
# or else there's probably a block size mismatch
|
||||||
|
assert current_offset % max_block_size == 0, current_offset % max_block_size
|
||||||
|
|
||||||
return best_offset
|
return best_offset
|
||||||
|
|
||||||
|
|
||||||
# (assuming this is only called where os.copy_file_range is present)
|
|
||||||
#
|
|
||||||
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
|
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
|
||||||
# to make it more likely that copy-on-write is used where possible.
|
# to make it more likely that copy-on-write is used where possible.
|
||||||
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
|
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
|
||||||
def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
|
#
|
||||||
|
# Falls back to shutil.copyfileobj when os.copy_file_range is not present.
|
||||||
|
def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...]):
|
||||||
assert len(ranges) > 0
|
assert len(ranges) > 0
|
||||||
dst_offset = fout.tell()
|
dst_offset = fout.tell()
|
||||||
assert dst_offset % alignment == 0, dst_offset % alignment
|
extra_offset = best_extra_offset(ranges, dst_offset)
|
||||||
align_offset = best_alignment_offset(ranges, alignment)
|
|
||||||
if len(ranges) == 1:
|
if extra_offset > 0:
|
||||||
r = ranges[0]
|
# initial padding
|
||||||
with open(r.filename, "rb") as src:
|
fout.write(b"\x00" * extra_offset)
|
||||||
offset_src = r.offset - align_offset
|
|
||||||
offset_src_end = r.offset + r.size
|
dst_offset += extra_offset
|
||||||
if offset_src_end % alignment != 0:
|
start_offset = dst_offset
|
||||||
offset_src_end += alignment - (offset_src_end % alignment)
|
|
||||||
size = offset_src_end - offset_src
|
|
||||||
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
|
|
||||||
dst_offset += r.size + align_offset
|
|
||||||
else:
|
|
||||||
# All ranges need to have the same alignment offset
|
|
||||||
# Non-consecutive ranges need a patch block in between when the alignment offset is non-zero
|
|
||||||
src_files: dict[Path, BufferedReader] = {}
|
src_files: dict[Path, BufferedReader] = {}
|
||||||
for r in ranges:
|
for r in ranges:
|
||||||
if r.filename not in src_files:
|
if r.filename not in src_files:
|
||||||
src_files[r.filename] = open(r.filename, "rb")
|
src_files[r.filename] = open(r.filename, "rb")
|
||||||
|
|
||||||
|
has_copy_file_range = hasattr(os, "copy_file_range")
|
||||||
|
|
||||||
for i, r in enumerate(ranges):
|
for i, r in enumerate(ranges):
|
||||||
this_align_offset = r.offset % alignment
|
|
||||||
src = src_files[r.filename]
|
src = src_files[r.filename]
|
||||||
if this_align_offset != align_offset:
|
if has_copy_file_range:
|
||||||
logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
|
if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
|
||||||
# relying on os.copy_file_range to fallback to a non-aligned copy
|
# Attempting to align copies for reflinking
|
||||||
|
|
||||||
# Block 0, 1, 2, 3, 4,
|
# Block 0, 1, 2, 3, 4,
|
||||||
# |___0000|0000000|0001111|1111111|111____|
|
# |___0000|0000000|0001111|1111111|111____|
|
||||||
#
|
#
|
||||||
# 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
|
# 1. block 0 is partially overwritten with contents from range[0]
|
||||||
# 2. block 2 is partially overwritten with contents from range[1]
|
# 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
|
||||||
# 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
|
# 3. block 2 is partially overwritten with contents from range[1]
|
||||||
#
|
# 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
|
||||||
# (2 and 3 are repeated with further blocks if there are more ranges)
|
# (repeated for further ranges)
|
||||||
if i == 0:
|
if dst_offset % r.block_size == 0:
|
||||||
extra_size = -align_offset
|
|
||||||
elif dst_offset % alignment == 0:
|
|
||||||
extra_size = 0
|
extra_size = 0
|
||||||
else:
|
else:
|
||||||
extra_size = alignment - (dst_offset % alignment)
|
extra_size = r.block_size - (dst_offset % r.block_size)
|
||||||
extra_size = min(extra_size, r.size)
|
extra_size = min(extra_size, r.size)
|
||||||
src.seek(r.offset)
|
src.seek(r.offset)
|
||||||
buf = src.read(extra_size)
|
buf = src.read(extra_size)
|
||||||
@@ -362,15 +368,27 @@ def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange,
|
|||||||
if extra_size == r.size:
|
if extra_size == r.size:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert dst_offset % alignment == 0, dst_offset % alignment
|
assert dst_offset % r.block_size == 0, dst_offset % r.block_size
|
||||||
|
|
||||||
offset_src = r.offset + extra_size
|
offset_src = r.offset + extra_size
|
||||||
offset_src_end = r.offset + r.size
|
offset_src_end = r.offset + r.size
|
||||||
if offset_src_end % alignment != 0:
|
if offset_src_end % r.block_size != 0:
|
||||||
offset_src_end += alignment - (offset_src_end % alignment)
|
offset_src_end += r.block_size - (offset_src_end % r.block_size)
|
||||||
size = offset_src_end - offset_src
|
size = offset_src_end - offset_src
|
||||||
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
|
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
|
||||||
dst_offset += r.size - extra_size
|
dst_offset += r.size - extra_size
|
||||||
|
else:
|
||||||
|
if r.block_size > 0:
|
||||||
|
logger.debug(f"misaligned for reflinking, falling back to copy ({i}/{len(ranges)})")
|
||||||
|
# not trying to use reflinks, but still using os.copy_file_range for speed
|
||||||
|
os.copy_file_range(src.fileno(), fout.fileno(), r.size, r.offset, dst_offset)
|
||||||
|
dst_offset += r.size
|
||||||
|
else:
|
||||||
|
# not using reflinks, fallback when os.copy_file_range is not supported
|
||||||
|
src.seek(r.offset)
|
||||||
|
fout.seek(dst_offset)
|
||||||
|
shutil.copyfileobj(src, fout, r.size)
|
||||||
|
dst_offset += r.size
|
||||||
|
|
||||||
for f in src_files.values():
|
for f in src_files.values():
|
||||||
f.close()
|
f.close()
|
||||||
@@ -399,10 +417,13 @@ class SafetensorsLocal:
|
|||||||
|
|
||||||
tensors: dict[str, LocalTensor]
|
tensors: dict[str, LocalTensor]
|
||||||
|
|
||||||
def __init__(self, filename: Path):
|
def __init__(self, filename: Path, *, reflink: bool = False):
|
||||||
|
stat = os.stat(filename)
|
||||||
|
# using the preferred block size to signal whether reflinks are desired when copying
|
||||||
|
block_size = stat.st_blksize if reflink else -1
|
||||||
with open(filename, "rb") as f:
|
with open(filename, "rb") as f:
|
||||||
metadata_length = int.from_bytes(f.read(8), byteorder='little')
|
metadata_length = int.from_bytes(f.read(8), byteorder='little')
|
||||||
file_size = os.stat(filename).st_size
|
file_size = stat.st_size
|
||||||
if file_size < 8 + metadata_length:
|
if file_size < 8 + metadata_length:
|
||||||
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
|
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
|
||||||
|
|
||||||
@@ -427,9 +448,10 @@ class SafetensorsLocal:
|
|||||||
dtype=meta["dtype"],
|
dtype=meta["dtype"],
|
||||||
shape=tuple(meta["shape"]),
|
shape=tuple(meta["shape"]),
|
||||||
data_range=LocalTensorRange(
|
data_range=LocalTensorRange(
|
||||||
filename,
|
filename=filename,
|
||||||
data_start_offset + meta["data_offsets"][0],
|
block_size=block_size,
|
||||||
meta["data_offsets"][1] - meta["data_offsets"][0],
|
offset=data_start_offset + meta["data_offsets"][0],
|
||||||
|
size=meta["data_offsets"][1] - meta["data_offsets"][0],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user