convert : parse safetensors directly (#15667)

* convert : parse safetensors directly

* gguf-py : order safetensors tensors by name

Applies to both local and remote safetensors custom parsing.
This matches the behavior of the official safetensors implementation.

* convert : rename from_safetensors_meta to from_local_tensor

For consistency with from_remote_tensor

* convert : fix no-lazy dtypes from direct safetensors
This commit is contained in:
compilade
2025-11-09 09:49:40 -05:00
committed by GitHub
parent 1c07c0c68c
commit 802cef44bf
2 changed files with 98 additions and 9 deletions

View File

@@ -218,8 +218,7 @@ class ModelBase:
logger.info(f"gguf: indexing model part '{part_name}'") logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any] ctx: ContextManager[Any]
if is_safetensors: if is_safetensors:
from safetensors import safe_open ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else: else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
@@ -228,18 +227,18 @@ class ModelBase:
for name in model_part.keys(): for name in model_part.keys():
if is_safetensors: if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
if self.lazy: if self.lazy:
data = model_part.get_slice(name) data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
else: else:
data = model_part.get_tensor(name) dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data: data # noqa: E731 data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else: else:
data = model_part[name] data_torch: Tensor = model_part[name]
if self.lazy: if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731 data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else: else:
data_gen = lambda data=data: data # noqa: E731 data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files # verify tensor name presence and identify potentially missing files
@@ -10079,6 +10078,16 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:]) lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
return cast(torch.Tensor, lazy) return cast(torch.Tensor, lazy)
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
dtype = cls._dtype_str_map[tensor.dtype]
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
return cast(torch.Tensor, lazy)
@classmethod @classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype] dtype = cls._dtype_str_map[remote_tensor.dtype]

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Literal from typing import Literal
import os import os
import json import json
import numpy as np
def fill_templated_filename(filename: str, output_type: str | None) -> str: def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -177,6 +179,10 @@ class SafetensorRemote:
except KeyError as e: except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
# order by name (same as default safetensors behavior)
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
res = dict(sorted(res.items(), key=lambda t: t[0]))
return res return res
@classmethod @classmethod
@@ -266,3 +272,77 @@ class SafetensorRemote:
if os.environ.get("HF_TOKEN"): if os.environ.get("HF_TOKEN"):
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
return headers return headers
@dataclass
class LocalTensorRange:
filename: Path
offset: int
size: int
@dataclass
class LocalTensor:
dtype: str
shape: tuple[int, ...]
data_range: LocalTensorRange
def mmap_bytes(self) -> np.ndarray:
return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
class SafetensorsLocal:
"""
Read a safetensors file from the local filesystem.
Custom parsing gives a bit more control over the memory usage.
The official safetensors library doesn't expose file ranges.
"""
ALIGNMENT = 8 # bytes
tensors: dict[str, LocalTensor]
def __init__(self, filename: Path):
with open(filename, "rb") as f:
metadata_length = int.from_bytes(f.read(8), byteorder='little')
file_size = os.stat(filename).st_size
if file_size < 8 + metadata_length:
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
metadata_str = f.read(metadata_length).decode('utf-8')
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
data_start_offset = f.tell()
alignment = self.ALIGNMENT
if data_start_offset % alignment != 0:
data_start_offset += alignment - (data_start_offset % alignment)
tensors: dict[str, LocalTensor] = {}
for name, meta in metadata.items():
if name == "__metadata__":
# ignore metadata, it's not a tensor
continue
tensors[name] = LocalTensor(
dtype=meta["dtype"],
shape=tuple(meta["shape"]),
data_range=LocalTensorRange(
filename,
data_start_offset + meta["data_offsets"][0],
meta["data_offsets"][1] - meta["data_offsets"][0],
),
)
# order by name (same as default safetensors behavior)
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
def __enter__(self, *args, **kwargs):
del args, kwargs # unused
return self.tensors
def __exit__(self, *args, **kwargs):
del args, kwargs # unused