convert : parse safetensors directly

This commit is contained in:
Francis Couture-Harpin
2025-08-29 11:49:09 -04:00
parent 0d5cfed596
commit ca8f736fe4
2 changed files with 93 additions and 9 deletions

View File

@@ -195,8 +195,7 @@ class ModelBase:
logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if is_safetensors:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
@@ -205,18 +204,18 @@ class ModelBase:
for name in model_part.keys():
if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
if self.lazy:
data = model_part.get_slice(name)
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_meta(data) # noqa: E731
else:
data = model_part.get_tensor(name)
data_gen = lambda data=data: data # noqa: E731
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else:
data = model_part[name]
data_torch: Tensor = model_part[name]
if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data: data # noqa: E731
data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files
@@ -8860,6 +8859,16 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)
@classmethod
def from_safetensors_meta(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
dtype = cls._dtype_str_map[tensor.dtype]
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
return cast(torch.Tensor, lazy)
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype]