convert : fix reflinks for stacked MoE tensors

This commit is contained in:
Francis Couture-Harpin
2025-09-02 15:22:01 -04:00
parent 562aa42c12
commit d921057027
4 changed files with 31 additions and 14 deletions

View File

@@ -297,7 +297,8 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
best_offset = 0
best_size = 0
for offset, size in hist.items():
if size > best_size:
# Ensure the minimal alignment is 8 bytes (common with safetensors)
if size > best_size and offset % 8 == 0:
best_size = size
best_offset = offset
return best_offset
@@ -307,7 +308,7 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
# to make it more likely that copy-on-write is used where possible.
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
assert len(ranges) > 0
dst_offset = fout.tell()
assert dst_offset % alignment == 0, dst_offset % alignment
@@ -335,26 +336,40 @@ def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...
src = src_files[r.filename]
if this_align_offset != align_offset:
logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
if i > 0 and dst_offset % alignment != 0:
# Write the correct data between blocks even when they are non-consecutive,
# relying on os.copy_file_range to fall back to a non-aligned copy
# Block 0, 1, 2, 3, 4,
# |___0000|0000000|0001111|1111111|111____|
#
# 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
# 2. block 2 is partially overwritten with contents from range[1]
# 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
#
# (2 and 3 are repeated with further blocks if there are more ranges)
if i == 0:
extra_size = -align_offset
elif dst_offset % alignment == 0:
extra_size = 0
else:
extra_size = alignment - (dst_offset % alignment)
extra_size = min(extra_size, r.size)
src.seek(r.offset)
buf = src.read(extra_size)
fout.seek(dst_offset)
fout.write(buf)
dst_offset += extra_size
assert dst_offset % alignment == 0, dst_offset % alignment
offset_src = r.offset + extra_size
else:
# TODO: is this always correct?
offset_src = r.offset - align_offset
if extra_size == r.size:
continue
assert dst_offset % alignment == 0, dst_offset % alignment
offset_src = r.offset + extra_size
offset_src_end = r.offset + r.size
if offset_src_end % alignment != 0:
offset_src_end += alignment - (offset_src_end % alignment)
size = offset_src_end - offset_src
os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
dst_offset += r.size
dst_offset += r.size - extra_size
for f in src_files.values():
f.close()