mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
convert : parse safetensors directly
This commit is contained in:
@@ -195,8 +195,7 @@ class ModelBase:
|
|||||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||||
ctx: ContextManager[Any]
|
ctx: ContextManager[Any]
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
from safetensors import safe_open
|
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
|
||||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
|
||||||
else:
|
else:
|
||||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||||
|
|
||||||
@@ -205,18 +204,18 @@ class ModelBase:
|
|||||||
|
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
|
data: gguf.utility.LocalTensor = model_part[name]
|
||||||
if self.lazy:
|
if self.lazy:
|
||||||
data = model_part.get_slice(name)
|
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_meta(data) # noqa: E731
|
||||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
|
|
||||||
else:
|
else:
|
||||||
data = model_part.get_tensor(name)
|
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
|
||||||
data_gen = lambda data=data: data # noqa: E731
|
data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
|
||||||
else:
|
else:
|
||||||
data = model_part[name]
|
data_torch: Tensor = model_part[name]
|
||||||
if self.lazy:
|
if self.lazy:
|
||||||
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
|
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||||
else:
|
else:
|
||||||
data_gen = lambda data=data: data # noqa: E731
|
data_gen = lambda data=data_torch: data # noqa: E731
|
||||||
tensors[name] = data_gen
|
tensors[name] = data_gen
|
||||||
|
|
||||||
# verify tensor name presence and identify potentially missing files
|
# verify tensor name presence and identify potentially missing files
|
||||||
@@ -8860,6 +8859,16 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||||
return cast(torch.Tensor, lazy)
|
return cast(torch.Tensor, lazy)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_safetensors_meta(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
||||||
|
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
||||||
|
dtype = cls._dtype_str_map[tensor.dtype]
|
||||||
|
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
||||||
|
dtype = cls._dtype_str_map[t.dtype]
|
||||||
|
shape = t.shape
|
||||||
|
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
|
||||||
|
return cast(torch.Tensor, lazy)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
||||||
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||||
@@ -266,3 +268,76 @@ class SafetensorRemote:
|
|||||||
if os.environ.get("HF_TOKEN"):
|
if os.environ.get("HF_TOKEN"):
|
||||||
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalTensorRange:
|
||||||
|
filename: Path
|
||||||
|
offset: int
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalTensor:
|
||||||
|
dtype: str
|
||||||
|
shape: tuple[int, ...]
|
||||||
|
data_range: LocalTensorRange
|
||||||
|
|
||||||
|
def mmap_bytes(self) -> np.ndarray:
|
||||||
|
return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsLocal:
|
||||||
|
"""
|
||||||
|
Read a safetensors file from the local filesystem.
|
||||||
|
|
||||||
|
Custom parsing gives a bit more control over the memory usage.
|
||||||
|
The official safetensors library doesn't expose file ranges.
|
||||||
|
"""
|
||||||
|
ALIGNMENT = 8 # bytes
|
||||||
|
|
||||||
|
tensors: dict[str, LocalTensor]
|
||||||
|
|
||||||
|
def __init__(self, filename: Path):
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
metadata_length = int.from_bytes(f.read(8), byteorder='little')
|
||||||
|
file_size = os.stat(filename).st_size
|
||||||
|
if file_size < 8 + metadata_length:
|
||||||
|
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
|
||||||
|
|
||||||
|
metadata_str = f.read(metadata_length).decode('utf-8')
|
||||||
|
try:
|
||||||
|
metadata = json.loads(metadata_str)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
|
||||||
|
|
||||||
|
data_start_offset = f.tell()
|
||||||
|
alignment = self.ALIGNMENT
|
||||||
|
if data_start_offset % alignment != 0:
|
||||||
|
data_start_offset += alignment - (data_start_offset % alignment)
|
||||||
|
|
||||||
|
tensors: dict[str, LocalTensor] = {}
|
||||||
|
for name, meta in metadata.items():
|
||||||
|
if name == "__metadata__":
|
||||||
|
# ignore metadata, it's not a tensor
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensors[name] = LocalTensor(
|
||||||
|
dtype=meta["dtype"],
|
||||||
|
shape=tuple(meta["shape"]),
|
||||||
|
data_range=LocalTensorRange(
|
||||||
|
filename,
|
||||||
|
data_start_offset + meta["data_offsets"][0],
|
||||||
|
meta["data_offsets"][1] - meta["data_offsets"][0],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# order by offset
|
||||||
|
self.tensors = dict(sorted(tensors.items(), key=lambda t: t[1].data_range.offset))
|
||||||
|
|
||||||
|
def __enter__(self, *args, **kwargs):
|
||||||
|
del args, kwargs # unused
|
||||||
|
return self.tensors
|
||||||
|
|
||||||
|
def __exit__(self, *args, **kwargs):
|
||||||
|
del args, kwargs # unused
|
||||||
|
|||||||
Reference in New Issue
Block a user