convert : parse safetensors directly

This commit is contained in:
Francis Couture-Harpin
2025-08-29 11:49:09 -04:00
parent 0d5cfed596
commit ca8f736fe4
2 changed files with 93 additions and 9 deletions

View File

@@ -195,8 +195,7 @@ class ModelBase:
logger.info(f"gguf: indexing model part '{part_name}'") logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any] ctx: ContextManager[Any]
if is_safetensors: if is_safetensors:
from safetensors import safe_open ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else: else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
@@ -205,18 +204,18 @@ class ModelBase:
for name in model_part.keys(): for name in model_part.keys():
if is_safetensors: if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
if self.lazy: if self.lazy:
data = model_part.get_slice(name) data_gen = lambda data=data: LazyTorchTensor.from_safetensors_meta(data) # noqa: E731
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
else: else:
data = model_part.get_tensor(name) dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data: data # noqa: E731 data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else: else:
data = model_part[name] data_torch: Tensor = model_part[name]
if self.lazy: if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731 data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else: else:
data_gen = lambda data=data: data # noqa: E731 data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files # verify tensor name presence and identify potentially missing files
@@ -8860,6 +8859,16 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy) return cast(torch.Tensor, lazy)
@classmethod
def from_safetensors_meta(cls, t: gguf.utility.LocalTensor) -> Tensor:
    """Wrap a locally parsed safetensors entry in a lazy tensor.

    The lazy tensor's meta is built from the dtype and shape recorded in
    ``t``. The loader given as ``func`` memory-maps the tensor's raw bytes
    and reinterprets them with that dtype and shape; presumably it only
    runs once the lazy tensor is evaluated (confirm against LazyBase).
    """
    def load_tensor(r: gguf.utility.LocalTensor) -> Tensor:
        # Raw bytes -> typed view -> final shape.
        raw = torch.from_numpy(r.mmap_bytes())
        return raw.view(cls._dtype_str_map[r.dtype]).reshape(r.shape)

    meta = cls.meta_with_dtype_and_shape(cls._dtype_str_map[t.dtype], t.shape)
    return cast(torch.Tensor, cls(meta=meta, args=(t,), func=load_tensor))
@classmethod @classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype] dtype = cls._dtype_str_map[remote_tensor.dtype]

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Literal from typing import Literal
import os import os
import json import json
import numpy as np
def fill_templated_filename(filename: str, output_type: str | None) -> str: def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -266,3 +268,76 @@ class SafetensorRemote:
if os.environ.get("HF_TOKEN"): if os.environ.get("HF_TOKEN"):
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
return headers return headers
@dataclass
class LocalTensorRange:
    """Location of a contiguous run of bytes inside a local file."""

    # File that contains the bytes.
    filename: Path
    # Position of the first byte, counted from the start of the file.
    offset: int
    # Length of the run, in bytes.
    size: int
@dataclass
class LocalTensor:
    """Description of one tensor stored in a local file."""

    # Dtype name as a string (looked up in a dtype map by the callers).
    dtype: str
    shape: tuple[int, ...]
    # Where the tensor's raw bytes live on disk.
    data_range: LocalTensorRange

    def mmap_bytes(self) -> np.ndarray:
        """Memory-map the tensor's raw bytes without reading them eagerly.

        Returns:
            A 1-D ``uint8`` array of ``data_range.size`` elements backed by
            the file, starting at ``data_range.offset``.

        The mapping is copy-on-write (``mode="c"``): the array is writable
        in memory, but the file is opened read-only and never modified.
        numpy's default mode ("r+") would open the file read-write, which
        fails on read-only files and lets in-place writes reach the file
        on disk.
        """
        return np.memmap(
            self.data_range.filename,
            mode="c",
            offset=self.data_range.offset,
            shape=self.data_range.size,
        )
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Only the header is read here; tensor data stays on disk and is described
    by LocalTensor entries. Used as a context manager, it yields a dict that
    maps tensor names to LocalTensor, ordered by file offset.
    """

    # Retained for backward compatibility. It is not applied when reading:
    # tensor offsets are relative to the first byte after the header.
    ALIGNMENT = 8  # bytes

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path):
        """Parse the header of ``filename`` and index its tensors.

        Raises:
            ValueError: if the file is too short to hold the announced
                header, the header is not valid JSON, or a tensor's byte
                range is malformed or extends past the end of the file.
        """
        with open(filename, "rb") as f:
            # Size of the file that is actually open, not a second path lookup.
            file_size = os.fstat(f.fileno()).st_size
            metadata_length = int.from_bytes(f.read(8), byteorder='little')
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
            metadata_str = f.read(metadata_length).decode('utf-8')

        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e

        # The data section begins immediately after the header. Any padding
        # is part of the header itself, so no extra alignment is added here.
        data_start_offset = 8 + metadata_length

        tensors: dict[str, LocalTensor] = {}
        for name, meta in metadata.items():
            if name == "__metadata__":
                # ignore metadata, it's not a tensor
                continue

            begin = meta["data_offsets"][0]
            end = meta["data_offsets"][1]
            # Fail early instead of mapping bytes outside the file later on.
            if not 0 <= begin <= end or data_start_offset + end > file_size:
                raise ValueError(
                    f"Tensor {name!r} has invalid data offsets [{begin}, {end}) "
                    f"for a data section of {file_size - data_start_offset} bytes"
                )

            tensors[name] = LocalTensor(
                dtype=meta["dtype"],
                shape=tuple(meta["shape"]),
                data_range=LocalTensorRange(
                    filename,
                    data_start_offset + begin,
                    end - begin,
                ),
            )

        # order by offset
        self.tensors = dict(sorted(tensors.items(), key=lambda t: t[1].data_range.offset))

    def __enter__(self, *args, **kwargs):
        """Return the name -> LocalTensor mapping; nothing is kept open."""
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        """Nothing to release; exceptions are not suppressed."""
        del args, kwargs  # unused