mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	gguf-py : improve reflink size logging
* gguf-py : move reflinking functions to lazy
This commit is contained in:
		@@ -29,8 +29,8 @@ from .constants import (
 | 
				
			|||||||
    ExpertGatingFuncType,
 | 
					    ExpertGatingFuncType,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .lazy import best_extra_offset, count_reflinkable_size
 | 
				
			||||||
from .quants import quant_shape_from_byte_shape
 | 
					from .quants import quant_shape_from_byte_shape
 | 
				
			||||||
from .utility import LocalTensorRange, best_extra_offset
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -192,7 +192,7 @@ class GGUFWriter:
 | 
				
			|||||||
                    # insert at the start of the key-values
 | 
					                    # insert at the start of the key-values
 | 
				
			||||||
                    if Keys.General.ALIGNMENT in kv:
 | 
					                    if Keys.General.ALIGNMENT in kv:
 | 
				
			||||||
                        del kv[Keys.General.ALIGNMENT]
 | 
					                        del kv[Keys.General.ALIGNMENT]
 | 
				
			||||||
                    self.kv_data[i] = { Keys.General.ALIGNMENT: GGUFValue(block_size, GGUFValueType.UINT32), **kv }
 | 
					                    self.kv_data[i] = {Keys.General.ALIGNMENT: GGUFValue(block_size, GGUFValueType.UINT32), **kv}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_plan(self) -> list[Path]:
 | 
					    def print_plan(self) -> list[Path]:
 | 
				
			||||||
        logger.info("Writing the following files:")
 | 
					        logger.info("Writing the following files:")
 | 
				
			||||||
@@ -200,7 +200,9 @@ class GGUFWriter:
 | 
				
			|||||||
        filenames = self.format_shard_names(self.path)
 | 
					        filenames = self.format_shard_names(self.path)
 | 
				
			||||||
        assert len(filenames) == len(self.tensors)
 | 
					        assert len(filenames) == len(self.tensors)
 | 
				
			||||||
        for name, tensors in zip(filenames, self.tensors):
 | 
					        for name, tensors in zip(filenames, self.tensors):
 | 
				
			||||||
            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
 | 
					            total_size = sum(ti.nbytes for ti in tensors.values())
 | 
				
			||||||
 | 
					            reflinkable_size = count_reflinkable_size(ti.tensor for ti in tensors.values()) if self.use_reflinks else 0
 | 
				
			||||||
 | 
					            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(total_size)}{', reflinked = ' + GGUFWriter.format_n_bytes_to_str(total_size - reflinkable_size) if self.use_reflinks else ''}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.dry_run:
 | 
					        if self.dry_run:
 | 
				
			||||||
            logger.info("Dry run, not writing files")
 | 
					            logger.info("Dry run, not writing files")
 | 
				
			||||||
@@ -275,9 +277,7 @@ class GGUFWriter:
 | 
				
			|||||||
            for name, ti in tensors.items():
 | 
					            for name, ti in tensors.items():
 | 
				
			||||||
                extra_offset = 0
 | 
					                extra_offset = 0
 | 
				
			||||||
                if self.use_reflinks:
 | 
					                if self.use_reflinks:
 | 
				
			||||||
                    ranges: tuple[LocalTensorRange, ...] = getattr(ti.tensor, "_ranges", ())
 | 
					                    extra_offset = best_extra_offset(ti.tensor, offset_tensor)
 | 
				
			||||||
                    if len(ranges) > 0:
 | 
					 | 
				
			||||||
                        extra_offset = best_extra_offset(ranges, offset_tensor)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
 | 
					                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
 | 
				
			||||||
                n_dims = len(ti.shape)
 | 
					                n_dims = len(ti.shape)
 | 
				
			||||||
@@ -472,11 +472,9 @@ class GGUFWriter:
 | 
				
			|||||||
                    shard_bar.reset(total=(total if total > 0 else None))
 | 
					                    shard_bar.reset(total=(total if total > 0 else None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                # relying on the fact that Python dicts preserve insertion order (since 3.7)
 | 
					                # relying on the fact that Python dicts preserve insertion order (since 3.7)
 | 
				
			||||||
                for name, ti in tensors.items():
 | 
					                for ti in tensors.values():
 | 
				
			||||||
                    assert ti.tensor is not None  # can only iterate once over the tensors
 | 
					                    assert ti.tensor is not None  # can only iterate once over the tensors
 | 
				
			||||||
                    assert ti.tensor.nbytes == ti.nbytes
 | 
					                    assert ti.tensor.nbytes == ti.nbytes
 | 
				
			||||||
                    if self.use_reflinks and len(getattr(ti.tensor, "_ranges", ())) > 0:
 | 
					 | 
				
			||||||
                        logger.debug(f"using reflinks for {name}")
 | 
					 | 
				
			||||||
                    ti.tensor.tofile(fout)
 | 
					                    ti.tensor.tofile(fout)
 | 
				
			||||||
                    if shard_bar is not None:
 | 
					                    if shard_bar is not None:
 | 
				
			||||||
                        shard_bar.update(ti.nbytes)
 | 
					                        shard_bar.update(ti.nbytes)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,18 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					from __future__ import annotations
 | 
				
			||||||
from abc import ABC, ABCMeta, abstractmethod
 | 
					from abc import ABC, ABCMeta, abstractmethod
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from io import BufferedWriter
 | 
					from io import BufferedReader, BufferedWriter
 | 
				
			||||||
import logging
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Any, Callable
 | 
					from typing import Any, Callable, Iterable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import shutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from numpy.typing import DTypeLike
 | 
					from numpy.typing import DTypeLike
 | 
				
			||||||
from .utility import LocalTensorRange, copy_tensor_ranges
 | 
					
 | 
				
			||||||
 | 
					from .utility import LocalTensorRange
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
@@ -210,6 +215,7 @@ class LazyNumpyTensor(LazyBase):
 | 
				
			|||||||
    _tensor_type = np.ndarray
 | 
					    _tensor_type = np.ndarray
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    shape: tuple[int, ...]  # Makes the type checker happy in quants.py
 | 
					    shape: tuple[int, ...]  # Makes the type checker happy in quants.py
 | 
				
			||||||
 | 
					    nbytes: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
 | 
					    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
 | 
				
			||||||
@@ -227,9 +233,140 @@ class LazyNumpyTensor(LazyBase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def tofile(self, fid, *args, **kwargs):
 | 
					    def tofile(self, fid, *args, **kwargs):
 | 
				
			||||||
        if isinstance(fid, BufferedWriter) and len(self._ranges) > 0:
 | 
					        if isinstance(fid, BufferedWriter) and len(self._ranges) > 0:
 | 
				
			||||||
            return copy_tensor_ranges(fid, self._ranges)
 | 
					            return copy_tensor_ranges(self, fid)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            eager = LazyNumpyTensor.to_eager(self)
 | 
					            eager = LazyNumpyTensor.to_eager(self)
 | 
				
			||||||
            return eager.tofile(fid, *args, **kwargs)
 | 
					            return eager.tofile(fid, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO: __array_function__
 | 
					    # TODO: __array_function__
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# For aligning blocks when reflinking
 | 
				
			||||||
 | 
					def best_extra_offset(t: np.ndarray | LazyNumpyTensor | None, current_offset: int) -> int:
 | 
				
			||||||
 | 
					    if not isinstance(t, LazyNumpyTensor):
 | 
				
			||||||
 | 
					        # no file ranges, no need for an offset
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ranges = t._ranges
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    histogram: dict[int, int] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    max_block_size = 0
 | 
				
			||||||
 | 
					    for r in ranges:
 | 
				
			||||||
 | 
					        # Ensure minimal alignment is 8 bytes (common with safetensors)
 | 
				
			||||||
 | 
					        # and that the block size is valid
 | 
				
			||||||
 | 
					        if r.offset % 8 == 0 and r.block_size > 0:
 | 
				
			||||||
 | 
					            align_offset = r.offset % r.block_size
 | 
				
			||||||
 | 
					            if align_offset not in histogram:
 | 
				
			||||||
 | 
					                histogram[align_offset] = 0
 | 
				
			||||||
 | 
					            histogram[align_offset] += r.size
 | 
				
			||||||
 | 
					            if r.block_size > max_block_size:
 | 
				
			||||||
 | 
					                max_block_size = r.block_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    best_offset = 0
 | 
				
			||||||
 | 
					    best_size = 0
 | 
				
			||||||
 | 
					    for offset, size in histogram.items():
 | 
				
			||||||
 | 
					        if size > best_size:
 | 
				
			||||||
 | 
					            best_size = size
 | 
				
			||||||
 | 
					            best_offset = offset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if max_block_size > 0:
 | 
				
			||||||
 | 
					        # the offset needs to be aligned properly
 | 
				
			||||||
 | 
					        # or else there's probably a block size mismatch
 | 
				
			||||||
 | 
					        assert current_offset % max_block_size == 0, current_offset % max_block_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return best_offset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def count_reflinkable_size(tensors: Iterable[np.ndarray | LazyNumpyTensor | None]) -> int:
 | 
				
			||||||
 | 
					    if not hasattr(os, "copy_file_range"):
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    size = 0
 | 
				
			||||||
 | 
					    for t in tensors:
 | 
				
			||||||
 | 
					        if isinstance(t, LazyNumpyTensor) and len(t._ranges) > 0:
 | 
				
			||||||
 | 
					            align_offset = best_extra_offset(t, 0)
 | 
				
			||||||
 | 
					            for range in t._ranges:
 | 
				
			||||||
 | 
					                if range.block_size > 0 and range.offset % range.block_size == align_offset:
 | 
				
			||||||
 | 
					                    size += range.size
 | 
				
			||||||
 | 
					    return size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
 | 
				
			||||||
 | 
					# to make it more likely that copy-on-write is used where possible.
 | 
				
			||||||
 | 
					# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Falls back to shutil.copyfileobj when os.copy_file_range is not present.
 | 
				
			||||||
 | 
					def copy_tensor_ranges(t: LazyNumpyTensor, fout: BufferedWriter):
 | 
				
			||||||
 | 
					    ranges = t._ranges
 | 
				
			||||||
 | 
					    assert len(ranges) > 0
 | 
				
			||||||
 | 
					    dst_offset = fout.tell()
 | 
				
			||||||
 | 
					    extra_offset = best_extra_offset(t, dst_offset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if extra_offset > 0:
 | 
				
			||||||
 | 
					        # initial padding
 | 
				
			||||||
 | 
					        fout.write(b"\x00" * extra_offset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dst_offset += extra_offset
 | 
				
			||||||
 | 
					    start_offset = dst_offset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    src_files: dict[Path, BufferedReader] = {}
 | 
				
			||||||
 | 
					    for r in ranges:
 | 
				
			||||||
 | 
					        if r.filename not in src_files:
 | 
				
			||||||
 | 
					            src_files[r.filename] = open(r.filename, "rb")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    has_copy_file_range = hasattr(os, "copy_file_range")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for i, r in enumerate(ranges):
 | 
				
			||||||
 | 
					        src = src_files[r.filename]
 | 
				
			||||||
 | 
					        if has_copy_file_range:
 | 
				
			||||||
 | 
					            if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
 | 
				
			||||||
 | 
					                # Attempting to align copies for reflinking
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # Block  0,      1,      2,      3,      4,
 | 
				
			||||||
 | 
					                # |___0000|0000000|0001111|1111111|111____|
 | 
				
			||||||
 | 
					                #
 | 
				
			||||||
 | 
					                # 1. block 0 is partially overwritten with contents from range[0]
 | 
				
			||||||
 | 
					                # 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
 | 
				
			||||||
 | 
					                # 3. block 2 is partially overwritten with contents from range[1]
 | 
				
			||||||
 | 
					                # 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
 | 
				
			||||||
 | 
					                # (repeated for further ranges)
 | 
				
			||||||
 | 
					                if dst_offset % r.block_size == 0:
 | 
				
			||||||
 | 
					                    extra_size = 0
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    extra_size = r.block_size - (dst_offset % r.block_size)
 | 
				
			||||||
 | 
					                    extra_size = min(extra_size, r.size)
 | 
				
			||||||
 | 
					                    src.seek(r.offset)
 | 
				
			||||||
 | 
					                    buf = src.read(extra_size)
 | 
				
			||||||
 | 
					                    fout.seek(dst_offset)
 | 
				
			||||||
 | 
					                    fout.write(buf)
 | 
				
			||||||
 | 
					                    dst_offset += extra_size
 | 
				
			||||||
 | 
					                    if extra_size == r.size:
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                assert dst_offset % r.block_size == 0, dst_offset % r.block_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                offset_src = r.offset + extra_size
 | 
				
			||||||
 | 
					                offset_src_end = r.offset + r.size
 | 
				
			||||||
 | 
					                if offset_src_end % r.block_size != 0:
 | 
				
			||||||
 | 
					                    offset_src_end += r.block_size - (offset_src_end % r.block_size)
 | 
				
			||||||
 | 
					                size = offset_src_end - offset_src
 | 
				
			||||||
 | 
					                os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
 | 
				
			||||||
 | 
					                dst_offset += r.size - extra_size
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                if r.block_size > 0:
 | 
				
			||||||
 | 
					                    logger.debug(f"misaligned for reflinking, falling back to copy ({i}/{len(ranges)})")
 | 
				
			||||||
 | 
					                # not trying to use reflinks, but still using os.copy_file_range for speed
 | 
				
			||||||
 | 
					                os.copy_file_range(src.fileno(), fout.fileno(), r.size, r.offset, dst_offset)
 | 
				
			||||||
 | 
					                dst_offset += r.size
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # not using reflinks, fallback when os.copy_file_range is not supported
 | 
				
			||||||
 | 
					            src.seek(r.offset)
 | 
				
			||||||
 | 
					            fout.seek(dst_offset)
 | 
				
			||||||
 | 
					            shutil.copyfileobj(src, fout, r.size)
 | 
				
			||||||
 | 
					            dst_offset += r.size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for f in src_files.values():
 | 
				
			||||||
 | 
					        f.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fout.seek(dst_offset)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,11 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					from __future__ import annotations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from io import BufferedReader, BufferedWriter
 | 
					 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Literal
 | 
					from typing import Literal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import shutil
 | 
					 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -287,115 +285,6 @@ class LocalTensorRange:
 | 
				
			|||||||
    size: int
 | 
					    size: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def best_extra_offset(ranges: tuple[LocalTensorRange, ...], current_offset: int) -> int:
 | 
					 | 
				
			||||||
    hist: dict[int, int] = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    max_block_size = 0
 | 
					 | 
				
			||||||
    for r in ranges:
 | 
					 | 
				
			||||||
        # Ensure minimal alignment is 8 bytes (common with safetensors)
 | 
					 | 
				
			||||||
        # and that the block size is valid
 | 
					 | 
				
			||||||
        if r.offset % 8 == 0 and r.block_size > 0:
 | 
					 | 
				
			||||||
            align_offset = r.offset % r.block_size
 | 
					 | 
				
			||||||
            if align_offset not in hist:
 | 
					 | 
				
			||||||
                hist[align_offset] = 0
 | 
					 | 
				
			||||||
            hist[align_offset] += r.size
 | 
					 | 
				
			||||||
            if r.block_size > max_block_size:
 | 
					 | 
				
			||||||
                max_block_size = r.block_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    best_offset = 0
 | 
					 | 
				
			||||||
    best_size = 0
 | 
					 | 
				
			||||||
    for offset, size in hist.items():
 | 
					 | 
				
			||||||
        if size > best_size:
 | 
					 | 
				
			||||||
            best_size = size
 | 
					 | 
				
			||||||
            best_offset = offset
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if max_block_size > 0:
 | 
					 | 
				
			||||||
        # the offset needs to be aligned properly
 | 
					 | 
				
			||||||
        # or else there's probably a block size mismatch
 | 
					 | 
				
			||||||
        assert current_offset % max_block_size == 0, current_offset % max_block_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return best_offset
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
 | 
					 | 
				
			||||||
# to make it more likely that copy-on-write is used where possible.
 | 
					 | 
				
			||||||
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
# Falls back to shutil.copyfileobj when os.copy_file_range is not present.
 | 
					 | 
				
			||||||
def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...]):
 | 
					 | 
				
			||||||
    assert len(ranges) > 0
 | 
					 | 
				
			||||||
    dst_offset = fout.tell()
 | 
					 | 
				
			||||||
    extra_offset = best_extra_offset(ranges, dst_offset)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if extra_offset > 0:
 | 
					 | 
				
			||||||
        # initial padding
 | 
					 | 
				
			||||||
        fout.write(b"\x00" * extra_offset)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    dst_offset += extra_offset
 | 
					 | 
				
			||||||
    start_offset = dst_offset
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    src_files: dict[Path, BufferedReader] = {}
 | 
					 | 
				
			||||||
    for r in ranges:
 | 
					 | 
				
			||||||
        if r.filename not in src_files:
 | 
					 | 
				
			||||||
            src_files[r.filename] = open(r.filename, "rb")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    has_copy_file_range = hasattr(os, "copy_file_range")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for i, r in enumerate(ranges):
 | 
					 | 
				
			||||||
        src = src_files[r.filename]
 | 
					 | 
				
			||||||
        if has_copy_file_range:
 | 
					 | 
				
			||||||
            if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
 | 
					 | 
				
			||||||
                # Attempting to align copies for reflinking
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Block  0,      1,      2,      3,      4,
 | 
					 | 
				
			||||||
                # |___0000|0000000|0001111|1111111|111____|
 | 
					 | 
				
			||||||
                #
 | 
					 | 
				
			||||||
                # 1. block 0 is partially overwritten with contents from range[0]
 | 
					 | 
				
			||||||
                # 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
 | 
					 | 
				
			||||||
                # 3. block 2 is partially overwritten with contents from range[1]
 | 
					 | 
				
			||||||
                # 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
 | 
					 | 
				
			||||||
                # (repeated for further ranges)
 | 
					 | 
				
			||||||
                if dst_offset % r.block_size == 0:
 | 
					 | 
				
			||||||
                    extra_size = 0
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    extra_size = r.block_size - (dst_offset % r.block_size)
 | 
					 | 
				
			||||||
                    extra_size = min(extra_size, r.size)
 | 
					 | 
				
			||||||
                    src.seek(r.offset)
 | 
					 | 
				
			||||||
                    buf = src.read(extra_size)
 | 
					 | 
				
			||||||
                    fout.seek(dst_offset)
 | 
					 | 
				
			||||||
                    fout.write(buf)
 | 
					 | 
				
			||||||
                    dst_offset += extra_size
 | 
					 | 
				
			||||||
                    if extra_size == r.size:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                assert dst_offset % r.block_size == 0, dst_offset % r.block_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                offset_src = r.offset + extra_size
 | 
					 | 
				
			||||||
                offset_src_end = r.offset + r.size
 | 
					 | 
				
			||||||
                if offset_src_end % r.block_size != 0:
 | 
					 | 
				
			||||||
                    offset_src_end += r.block_size - (offset_src_end % r.block_size)
 | 
					 | 
				
			||||||
                size = offset_src_end - offset_src
 | 
					 | 
				
			||||||
                os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
 | 
					 | 
				
			||||||
                dst_offset += r.size - extra_size
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                if r.block_size > 0:
 | 
					 | 
				
			||||||
                    logger.debug(f"misaligned for reflinking, falling back to copy ({i}/{len(ranges)})")
 | 
					 | 
				
			||||||
                # not trying to use reflinks, but still using os.copy_file_range for speed
 | 
					 | 
				
			||||||
                os.copy_file_range(src.fileno(), fout.fileno(), r.size, r.offset, dst_offset)
 | 
					 | 
				
			||||||
                dst_offset += r.size
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # not using reflinks, fallback when os.copy_file_range is not supported
 | 
					 | 
				
			||||||
            src.seek(r.offset)
 | 
					 | 
				
			||||||
            fout.seek(dst_offset)
 | 
					 | 
				
			||||||
            shutil.copyfileobj(src, fout, r.size)
 | 
					 | 
				
			||||||
            dst_offset += r.size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for f in src_files.values():
 | 
					 | 
				
			||||||
        f.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    fout.seek(dst_offset)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class LocalTensor:
 | 
					class LocalTensor:
 | 
				
			||||||
    dtype: str
 | 
					    dtype: str
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user