mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-30 08:42:00 +00:00
gguf-py : order safetensors tensors by name
Applies to both local and remote safetensors custom parsing. This matches the behavior of the official safetensors implementation. * convert : rename from_safetensors_meta to from_local_tensor, for consistency with from_remote_tensor.
This commit is contained in:
@@ -206,7 +206,7 @@ class ModelBase:
|
|||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
data: gguf.utility.LocalTensor = model_part[name]
|
data: gguf.utility.LocalTensor = model_part[name]
|
||||||
if self.lazy:
|
if self.lazy:
|
||||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_meta(data) # noqa: E731
|
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
|
||||||
else:
|
else:
|
||||||
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
|
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
|
||||||
data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
|
data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
|
||||||
@@ -8860,7 +8860,7 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||||||
return cast(torch.Tensor, lazy)
|
return cast(torch.Tensor, lazy)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_safetensors_meta(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
||||||
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
||||||
dtype = cls._dtype_str_map[tensor.dtype]
|
dtype = cls._dtype_str_map[tensor.dtype]
|
||||||
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
||||||
|
|||||||
@@ -179,6 +179,10 @@ class SafetensorRemote:
|
|||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
||||||
|
|
||||||
|
# order by name (same as default safetensors behavior)
|
||||||
|
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
||||||
|
res = dict(sorted(res.items(), key=lambda t: t[0]))
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -332,8 +336,9 @@ class SafetensorsLocal:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# order by offset
|
# order by name (same as default safetensors behavior)
|
||||||
self.tensors = dict(sorted(tensors.items(), key=lambda t: t[1].data_range.offset))
|
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
||||||
|
self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
|
||||||
|
|
||||||
def __enter__(self, *args, **kwargs):
|
def __enter__(self, *args, **kwargs):
|
||||||
del args, kwargs # unused
|
del args, kwargs # unused
|
||||||
|
|||||||
Reference in New Issue
Block a user