mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	convert-hf : support bfloat16 conversion (#7158)
* convert-hf : support bfloat16 conversion * gguf-py : flake8 fixes * convert-hf : add missing space after comma * convert-hf : get bit-exact same output as ./quantize The quantization version was missing. * convert-hf : don't round bf16 NANs * convert-hf : save some memory with np.int16 intermediate bf16 weights * convert-hf : more closely match llama.cpp with which weights to keep in f32 * convert-hf : add --outtype auto-f16 A reason for this to exist is for model quantizers who want an initial GGUF with the most fidelity to the original model while still using a 16-bit float type instead of 32-bit floats. * convert-hf : remove a semicolon because flake8 doesn't like it It's a reflex from when programming in C/C++, I guess. * convert-hf : support outtype templating in outfile name * convert-hf : rename --outtype auto-f16 to --outtype auto
This commit is contained in:
		| @@ -12,7 +12,7 @@ import sys | |||||||
| from enum import IntEnum | from enum import IntEnum | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload | from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| @@ -48,7 +48,6 @@ class Model: | |||||||
|  |  | ||||||
|     dir_model: Path |     dir_model: Path | ||||||
|     ftype: int |     ftype: int | ||||||
|     fname_out: Path |  | ||||||
|     is_big_endian: bool |     is_big_endian: bool | ||||||
|     endianess: gguf.GGUFEndian |     endianess: gguf.GGUFEndian | ||||||
|     use_temp_file: bool |     use_temp_file: bool | ||||||
| @@ -56,20 +55,20 @@ class Model: | |||||||
|     part_names: list[str] |     part_names: list[str] | ||||||
|     is_safetensors: bool |     is_safetensors: bool | ||||||
|     hparams: dict[str, Any] |     hparams: dict[str, Any] | ||||||
|     gguf_writer: gguf.GGUFWriter |  | ||||||
|     block_count: int |     block_count: int | ||||||
|     tensor_map: gguf.TensorNameMap |     tensor_map: gguf.TensorNameMap | ||||||
|     tensor_names: set[str] | None |     tensor_names: set[str] | None | ||||||
|  |     fname_out: Path | ||||||
|  |     gguf_writer: gguf.GGUFWriter | ||||||
|  |  | ||||||
|     # subclasses should define this! |     # subclasses should define this! | ||||||
|     model_arch: gguf.MODEL_ARCH |     model_arch: gguf.MODEL_ARCH | ||||||
|  |  | ||||||
|     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): |     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): | ||||||
|         if self.__class__ == Model: |         if type(self) is Model: | ||||||
|             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated") |             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") | ||||||
|         self.dir_model = dir_model |         self.dir_model = dir_model | ||||||
|         self.ftype = ftype |         self.ftype = ftype | ||||||
|         self.fname_out = fname_out |  | ||||||
|         self.is_big_endian = is_big_endian |         self.is_big_endian = is_big_endian | ||||||
|         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE |         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE | ||||||
|         self.use_temp_file = use_temp_file |         self.use_temp_file = use_temp_file | ||||||
| @@ -79,10 +78,23 @@ class Model: | |||||||
|         if not self.is_safetensors: |         if not self.is_safetensors: | ||||||
|             self.part_names = Model.get_model_part_names(self.dir_model, ".bin") |             self.part_names = Model.get_model_part_names(self.dir_model, ".bin") | ||||||
|         self.hparams = Model.load_hparams(self.dir_model) |         self.hparams = Model.load_hparams(self.dir_model) | ||||||
|         self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) |  | ||||||
|         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) |         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) | ||||||
|         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) |         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) | ||||||
|         self.tensor_names = None |         self.tensor_names = None | ||||||
|  |         if self.ftype == gguf.LlamaFileType.GUESSED: | ||||||
|  |             # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. | ||||||
|  |             _, first_tensor = next(self.get_tensors()) | ||||||
|  |             if first_tensor.dtype == torch.float16: | ||||||
|  |                 logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") | ||||||
|  |                 self.ftype = gguf.LlamaFileType.MOSTLY_F16 | ||||||
|  |             else: | ||||||
|  |                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") | ||||||
|  |                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16 | ||||||
|  |         ftype_up: str = self.ftype.name.partition("_")[2].upper() | ||||||
|  |         ftype_lw: str = ftype_up.lower() | ||||||
|  |         # allow templating the file name with the output ftype, useful with the "auto" ftype | ||||||
|  |         self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) | ||||||
|  |         self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def __init_subclass__(cls): |     def __init_subclass__(cls): | ||||||
| @@ -142,14 +154,27 @@ class Model: | |||||||
|             raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}") |             raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}") | ||||||
|  |  | ||||||
|     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: |     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: | ||||||
|         name: str = gguf.TENSOR_NAMES[key] |  | ||||||
|         if key not in gguf.MODEL_TENSORS[self.model_arch]: |         if key not in gguf.MODEL_TENSORS[self.model_arch]: | ||||||
|             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") |             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") | ||||||
|  |         name: str = gguf.TENSOR_NAMES[key] | ||||||
|         if "{bid}" in name: |         if "{bid}" in name: | ||||||
|             assert bid is not None |             assert bid is not None | ||||||
|             name = name.format(bid=bid) |             name = name.format(bid=bid) | ||||||
|         return name + suffix |         return name + suffix | ||||||
|  |  | ||||||
|  |     def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: | ||||||
|  |         if key not in gguf.MODEL_TENSORS[self.model_arch]: | ||||||
|  |             return False | ||||||
|  |         key_name: str = gguf.TENSOR_NAMES[key] | ||||||
|  |         if "{bid}" in key_name: | ||||||
|  |             if bid is None: | ||||||
|  |                 return False | ||||||
|  |             key_name = key_name.format(bid=bid) | ||||||
|  |         else: | ||||||
|  |             if bid is not None: | ||||||
|  |                 return False | ||||||
|  |         return name == (key_name + suffix) | ||||||
|  |  | ||||||
|     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: |     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: | ||||||
|         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) |         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) | ||||||
|         if new_name is None: |         if new_name is None: | ||||||
| @@ -215,6 +240,23 @@ class Model: | |||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def write_tensors(self): |     def write_tensors(self): | ||||||
|  |         # same as ggml_compute_fp32_to_bf16 in ggml-impl.h | ||||||
|  |         def np_fp32_to_bf16(n: np.ndarray): | ||||||
|  |             # force nan to quiet | ||||||
|  |             n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n) | ||||||
|  |             # flush subnormals to zero | ||||||
|  |             n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n) | ||||||
|  |             # round to nearest even | ||||||
|  |             n = (n + (0x7fff + ((n >> 16) & 1))) >> 16 | ||||||
|  |             return n.astype(np.int16) | ||||||
|  |  | ||||||
|  |         # Doing this row-wise is much, much faster than element-wise, hence the signature | ||||||
|  |         v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)") | ||||||
|  |         if self.lazy: | ||||||
|  |             # TODO: find a way to implicitly wrap np.vectorize functions | ||||||
|  |             # NOTE: the type is changed to reflect otypes passed to np.vectorize above | ||||||
|  |             v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16) | ||||||
|  |  | ||||||
|         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") |         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") | ||||||
|  |  | ||||||
|         for name, data_torch in self.get_tensors(): |         for name, data_torch in self.get_tensors(): | ||||||
| @@ -239,35 +281,60 @@ class Model: | |||||||
|                 data: np.ndarray = data  # type hint |                 data: np.ndarray = data  # type hint | ||||||
|                 n_dims = len(data.shape) |                 n_dims = len(data.shape) | ||||||
|                 data_dtype = data.dtype |                 data_dtype = data.dtype | ||||||
|  |                 data_qtype: gguf.GGMLQuantizationType | None = None | ||||||
|                 # if f32 desired, convert any float16 to float32 |  | ||||||
|                 if self.ftype == 0 and data_dtype == np.float16: |  | ||||||
|                     data = data.astype(np.float32) |  | ||||||
|  |  | ||||||
|                 # when both are True, f32 should win |                 # when both are True, f32 should win | ||||||
|                 extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) |                 extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) | ||||||
|                 extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) |                 extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) | ||||||
|  |  | ||||||
|                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors |                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors | ||||||
|                 extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight") |                 # Conditions should closely match those in llama_model_quantize_internal in llama.cpp | ||||||
|  |                 extra_f32 = any(cond for cond in ( | ||||||
|  |                     extra_f32, | ||||||
|  |                     n_dims == 1, | ||||||
|  |                     new_name.endswith("_norm.weight"), | ||||||
|  |                 )) | ||||||
|  |  | ||||||
|  |                 # Some tensor types are always in float32 | ||||||
|  |                 extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( | ||||||
|  |                     gguf.MODEL_TENSOR.FFN_GATE_INP, | ||||||
|  |                     gguf.MODEL_TENSOR.POS_EMBD, | ||||||
|  |                     gguf.MODEL_TENSOR.TOKEN_TYPES, | ||||||
|  |                 )) | ||||||
|  |  | ||||||
|                 # if f16 desired, convert any float32 2-dim weight tensors to float16 |                 # if f16 desired, convert any float32 2-dim weight tensors to float16 | ||||||
|                 extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2) |                 extra_f16 = any(cond for cond in ( | ||||||
|  |                     extra_f16, | ||||||
|  |                     (name.endswith(".weight") and n_dims >= 2), | ||||||
|  |                 )) | ||||||
|  |  | ||||||
|                 # when both extra_f32 and extra_f16 are False, convert to float32 by default |                 if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: | ||||||
|                 if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16): |                     if self.ftype == gguf.LlamaFileType.MOSTLY_F16: | ||||||
|                     data = data.astype(np.float32) |                         if data_dtype != np.float16: | ||||||
|  |                             data = data.astype(np.float16) | ||||||
|  |                         data_qtype = gguf.GGMLQuantizationType.F16 | ||||||
|  |  | ||||||
|                 if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32: |                     elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: | ||||||
|                     data = data.astype(np.float16) |                         if data_dtype != np.float32: | ||||||
|  |                             data = data.astype(np.float32) | ||||||
|  |                         data = v_fp32_to_bf16(data.view(np.int32)) | ||||||
|  |                         assert data.dtype == np.int16 | ||||||
|  |                         data_qtype = gguf.GGMLQuantizationType.BF16 | ||||||
|  |  | ||||||
|  |                 else:  # by default, convert to float32 | ||||||
|  |                     if data_dtype != np.float32: | ||||||
|  |                         data = data.astype(np.float32) | ||||||
|  |                     data_qtype = gguf.GGMLQuantizationType.F32 | ||||||
|  |  | ||||||
|  |                 assert data_qtype is not None | ||||||
|  |  | ||||||
|                 # reverse shape to make it similar to the internal ggml dimension order |                 # reverse shape to make it similar to the internal ggml dimension order | ||||||
|                 shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" |                 shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" | ||||||
|  |  | ||||||
|                 # n_dims is implicit in the shape |                 # n_dims is implicit in the shape | ||||||
|                 logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}") |                 logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") | ||||||
|  |  | ||||||
|                 self.gguf_writer.add_tensor(new_name, data) |                 self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) | ||||||
|  |  | ||||||
|     def write(self): |     def write(self): | ||||||
|         self.write_tensors() |         self.write_tensors() | ||||||
| @@ -2044,12 +2111,6 @@ class BertModel(Model): | |||||||
|  |  | ||||||
|         return [(self.map_tensor_name(name), data_torch)] |         return [(self.map_tensor_name(name), data_torch)] | ||||||
|  |  | ||||||
|     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: |  | ||||||
|         del new_name, bid, n_dims  # unused |  | ||||||
|  |  | ||||||
|         # not used with get_rows, must be F32 |  | ||||||
|         return name == "embeddings.token_type_embeddings.weight" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @Model.register("NomicBertModel") | @Model.register("NomicBertModel") | ||||||
| class NomicBertModel(BertModel): | class NomicBertModel(BertModel): | ||||||
| @@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel): | |||||||
|  |  | ||||||
|  |  | ||||||
| # tree of lazy tensors | # tree of lazy tensors | ||||||
| class LazyTorchTensor: | class LazyTorchTensor(gguf.LazyBase): | ||||||
|     _meta: Tensor |     _tensor_type = torch.Tensor | ||||||
|     _data: Tensor | None |     # to keep the type-checker happy | ||||||
|     _args: tuple |     dtype: torch.dtype | ||||||
|     _func: Callable[[tuple], Tensor] | None |     shape: torch.Size | ||||||
|  |  | ||||||
|     def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None): |  | ||||||
|         self._meta = meta |  | ||||||
|         self._data = data |  | ||||||
|         self._args = args |  | ||||||
|         self._func = func |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: |  | ||||||
|         # TODO: dict and set |  | ||||||
|         if isinstance(o, (list, tuple)): |  | ||||||
|             L = [] |  | ||||||
|             for item in o: |  | ||||||
|                 L.append(LazyTorchTensor._recurse_apply(item, fn)) |  | ||||||
|             if isinstance(o, tuple): |  | ||||||
|                 L = tuple(L) |  | ||||||
|             return L |  | ||||||
|         elif isinstance(o, LazyTorchTensor): |  | ||||||
|             return fn(o) |  | ||||||
|         else: |  | ||||||
|             return o |  | ||||||
|  |  | ||||||
|     def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]: |  | ||||||
|         def wrapped_fn(*args, **kwargs): |  | ||||||
|             if kwargs is None: |  | ||||||
|                 kwargs = {} |  | ||||||
|             args = ((self,) if use_self else ()) + args |  | ||||||
|  |  | ||||||
|             meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta) |  | ||||||
|  |  | ||||||
|             return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs)) |  | ||||||
|         return wrapped_fn |  | ||||||
|  |  | ||||||
|     def __getattr__(self, __name: str) -> Any: |  | ||||||
|         meta_attr = getattr(self._meta, __name) |  | ||||||
|         if callable(meta_attr): |  | ||||||
|             return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True) |  | ||||||
|         elif isinstance(meta_attr, torch.Tensor): |  | ||||||
|             # for things like self.T |  | ||||||
|             return self._wrap_fn(lambda s: getattr(s, __name))(self) |  | ||||||
|         else: |  | ||||||
|             return meta_attr |  | ||||||
|  |  | ||||||
|  |     # only used when converting a torch.Tensor to a np.ndarray | ||||||
|     _dtype_map: dict[torch.dtype, type] = { |     _dtype_map: dict[torch.dtype, type] = { | ||||||
|         torch.float16: np.float16, |         torch.float16: np.float16, | ||||||
|         torch.float32: np.float32, |         torch.float32: np.float32, | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     def numpy(self) -> gguf.LazyTensor: |     def numpy(self) -> gguf.LazyNumpyTensor: | ||||||
|         dtype = self._dtype_map[self.dtype] |         dtype = self._dtype_map[self.dtype] | ||||||
|         return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape) |         return gguf.LazyNumpyTensor( | ||||||
|  |             meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)), | ||||||
|  |             lazy=self._lazy, | ||||||
|  |             args=(self,), | ||||||
|  |             func=(lambda s: s[0].numpy()) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @overload |     @classmethod | ||||||
|     @staticmethod |     def eager_to_meta(cls, t: Tensor) -> Tensor: | ||||||
|     def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ... |         if t.is_meta: | ||||||
|  |  | ||||||
|     @overload |  | ||||||
|     @staticmethod |  | ||||||
|     def to_eager(t: tuple) -> tuple: ... |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def to_eager(t: Any) -> Any: |  | ||||||
|         def simple_to_eager(_t: LazyTorchTensor) -> Tensor: |  | ||||||
|             # wake up the lazy tensor |  | ||||||
|             if _t._data is None and _t._func is not None: |  | ||||||
|                 # recurse into its arguments |  | ||||||
|                 _t._args = LazyTorchTensor.to_eager(_t._args) |  | ||||||
|                 _t._data = _t._func(_t._args) |  | ||||||
|             if _t._data is not None: |  | ||||||
|                 return _t._data |  | ||||||
|             else: |  | ||||||
|                 raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}") |  | ||||||
|  |  | ||||||
|         # recurse into lists and/or tuples, keeping their structure |  | ||||||
|         return LazyTorchTensor._recurse_apply(t, simple_to_eager) |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def from_eager(t: Tensor) -> Tensor: |  | ||||||
|         if (t.__class__ == LazyTorchTensor): |  | ||||||
|             return t |             return t | ||||||
|         return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore |         return t.detach().to("meta") | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor: | ||||||
|  |         m = m.detach() | ||||||
|  |         if not m.is_meta: | ||||||
|  |             m = m.to("meta") | ||||||
|  |         m.dtype = dtype | ||||||
|  |         return m | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def __torch_function__(cls, func, types, args=(), kwargs=None): |     def __torch_function__(cls, func, types, args=(), kwargs=None): | ||||||
| @@ -2435,28 +2444,8 @@ class LazyTorchTensor: | |||||||
|  |  | ||||||
|         if func is torch.Tensor.numpy: |         if func is torch.Tensor.numpy: | ||||||
|             return args[0].numpy() |             return args[0].numpy() | ||||||
|         if func is torch.equal: |  | ||||||
|             eager_args = LazyTorchTensor.to_eager(args) |  | ||||||
|             return func(*eager_args, **kwargs) |  | ||||||
|  |  | ||||||
|         return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs) |         return LazyTorchTensor._wrap_fn(func)(*args, **kwargs) | ||||||
|  |  | ||||||
|     # special methods bypass __getattr__, so they need to be added manually |  | ||||||
|     # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup |  | ||||||
|     # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used |  | ||||||
|     #       as self._meta is currently used), because then the following |  | ||||||
|     #       operations would by default not be wrapped, and so not propagated |  | ||||||
|     #       when the tensor is made eager. |  | ||||||
|     #       It's better to get non-silent errors for not-yet-supported operators. |  | ||||||
|     # TODO: add more when needed to avoid clutter, or find a more concise way |  | ||||||
|     def __neg__(self, *args):  # mamba |  | ||||||
|         return self._wrap_fn(torch.Tensor.__neg__)(self, *args) |  | ||||||
|  |  | ||||||
|     def __add__(self, *args):  # gemma |  | ||||||
|         return self._wrap_fn(torch.Tensor.__add__)(self, *args) |  | ||||||
|  |  | ||||||
|     def __getitem__(self, *args):  # bloom falcon refact internlm2 |  | ||||||
|         return self._wrap_fn(torch.Tensor.__getitem__)(self, *args) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_args() -> argparse.Namespace: | def parse_args() -> argparse.Namespace: | ||||||
| @@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace: | |||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--outfile", type=Path, |         "--outfile", type=Path, | ||||||
|         help="path to write to; default: based on input", |         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--outtype", type=str, choices=["f32", "f16"], default="f16", |         "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16", | ||||||
|         help="output format - use f32 for float32, f16 for float16", |         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--bigendian", action="store_true", |         "--bigendian", action="store_true", | ||||||
| @@ -2530,16 +2519,18 @@ def main() -> None: | |||||||
|         logger.error(f'Error: {args.model} is not a directory') |         logger.error(f'Error: {args.model} is not a directory') | ||||||
|         sys.exit(1) |         sys.exit(1) | ||||||
|  |  | ||||||
|     ftype_map = { |     ftype_map: dict[str, gguf.LlamaFileType] = { | ||||||
|         "f32": gguf.GGMLQuantizationType.F32, |         "f32": gguf.LlamaFileType.ALL_F32, | ||||||
|         "f16": gguf.GGMLQuantizationType.F16, |         "f16": gguf.LlamaFileType.MOSTLY_F16, | ||||||
|  |         "bf16": gguf.LlamaFileType.MOSTLY_BF16, | ||||||
|  |         "auto": gguf.LlamaFileType.GUESSED, | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if args.outfile is not None: |     if args.outfile is not None: | ||||||
|         fname_out = args.outfile |         fname_out = args.outfile | ||||||
|     else: |     else: | ||||||
|         # output in the same directory as the model by default |         # output in the same directory as the model by default | ||||||
|         fname_out = dir_model / f'ggml-model-{args.outtype}.gguf' |         fname_out = dir_model / 'ggml-model-{ftype}.gguf' | ||||||
|  |  | ||||||
|     logger.info(f"Loading model: {dir_model.name}") |     logger.info(f"Loading model: {dir_model.name}") | ||||||
|  |  | ||||||
| @@ -2555,14 +2546,16 @@ def main() -> None: | |||||||
|         logger.info("Set model tokenizer") |         logger.info("Set model tokenizer") | ||||||
|         model_instance.set_vocab() |         model_instance.set_vocab() | ||||||
|  |  | ||||||
|  |         model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) | ||||||
|  |  | ||||||
|         if args.vocab_only: |         if args.vocab_only: | ||||||
|             logger.info(f"Exporting model vocab to '{fname_out}'") |             logger.info(f"Exporting model vocab to '{model_instance.fname_out}'") | ||||||
|             model_instance.write_vocab() |             model_instance.write_vocab() | ||||||
|         else: |         else: | ||||||
|             logger.info(f"Exporting model to '{fname_out}'") |             logger.info(f"Exporting model to '{model_instance.fname_out}'") | ||||||
|             model_instance.write() |             model_instance.write() | ||||||
|  |  | ||||||
|         logger.info(f"Model successfully exported to '{fname_out}'") |         logger.info(f"Model successfully exported to '{model_instance.fname_out}'") | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from .constants import * | from .constants import * | ||||||
|  | from .lazy import * | ||||||
| from .gguf_reader import * | from .gguf_reader import * | ||||||
| from .gguf_writer import * | from .gguf_writer import * | ||||||
| from .tensor_mapping import * | from .tensor_mapping import * | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ from typing import Any | |||||||
| GGUF_MAGIC             = 0x46554747  # "GGUF" | GGUF_MAGIC             = 0x46554747  # "GGUF" | ||||||
| GGUF_VERSION           = 3 | GGUF_VERSION           = 3 | ||||||
| GGUF_DEFAULT_ALIGNMENT = 32 | GGUF_DEFAULT_ALIGNMENT = 32 | ||||||
|  | GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h | ||||||
|  |  | ||||||
| # | # | ||||||
| # metadata keys | # metadata keys | ||||||
| @@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum): | |||||||
|     BF16    = 30 |     BF16    = 30 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # TODO: add GGMLFileType from ggml_ftype in ggml.h | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # from llama_ftype in llama.h | ||||||
|  | # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. | ||||||
|  | class LlamaFileType(IntEnum): | ||||||
|  |     ALL_F32              = 0 | ||||||
|  |     MOSTLY_F16           = 1   # except 1d tensors | ||||||
|  |     MOSTLY_Q4_0          = 2   # except 1d tensors | ||||||
|  |     MOSTLY_Q4_1          = 3   # except 1d tensors | ||||||
|  |     MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16 | ||||||
|  |     # MOSTLY_Q4_2        = 5   # support has been removed | ||||||
|  |     # MOSTLY_Q4_3        = 6   # support has been removed | ||||||
|  |     MOSTLY_Q8_0          = 7   # except 1d tensors | ||||||
|  |     MOSTLY_Q5_0          = 8   # except 1d tensors | ||||||
|  |     MOSTLY_Q5_1          = 9   # except 1d tensors | ||||||
|  |     MOSTLY_Q2_K          = 10  # except 1d tensors | ||||||
|  |     MOSTLY_Q3_K_S        = 11  # except 1d tensors | ||||||
|  |     MOSTLY_Q3_K_M        = 12  # except 1d tensors | ||||||
|  |     MOSTLY_Q3_K_L        = 13  # except 1d tensors | ||||||
|  |     MOSTLY_Q4_K_S        = 14  # except 1d tensors | ||||||
|  |     MOSTLY_Q4_K_M        = 15  # except 1d tensors | ||||||
|  |     MOSTLY_Q5_K_S        = 16  # except 1d tensors | ||||||
|  |     MOSTLY_Q5_K_M        = 17  # except 1d tensors | ||||||
|  |     MOSTLY_Q6_K          = 18  # except 1d tensors | ||||||
|  |     MOSTLY_IQ2_XXS       = 19  # except 1d tensors | ||||||
|  |     MOSTLY_IQ2_XS        = 20  # except 1d tensors | ||||||
|  |     MOSTLY_Q2_K_S        = 21  # except 1d tensors | ||||||
|  |     MOSTLY_IQ3_XS        = 22  # except 1d tensors | ||||||
|  |     MOSTLY_IQ3_XXS       = 23  # except 1d tensors | ||||||
|  |     MOSTLY_IQ1_S         = 24  # except 1d tensors | ||||||
|  |     MOSTLY_IQ4_NL        = 25  # except 1d tensors | ||||||
|  |     MOSTLY_IQ3_S         = 26  # except 1d tensors | ||||||
|  |     MOSTLY_IQ3_M         = 27  # except 1d tensors | ||||||
|  |     MOSTLY_IQ2_S         = 28  # except 1d tensors | ||||||
|  |     MOSTLY_IQ2_M         = 29  # except 1d tensors | ||||||
|  |     MOSTLY_IQ4_XS        = 30  # except 1d tensors | ||||||
|  |     MOSTLY_IQ1_M         = 31  # except 1d tensors | ||||||
|  |     MOSTLY_BF16          = 32  # except 1d tensors | ||||||
|  |  | ||||||
|  |     GUESSED              = 1024  # not specified in the model file | ||||||
|  |  | ||||||
|  |  | ||||||
| class GGUFEndian(IntEnum): | class GGUFEndian(IntEnum): | ||||||
|     LITTLE = 0 |     LITTLE = 0 | ||||||
|     BIG = 1 |     BIG = 1 | ||||||
|   | |||||||
| @@ -7,7 +7,7 @@ import struct | |||||||
| import tempfile | import tempfile | ||||||
| from enum import Enum, auto | from enum import Enum, auto | ||||||
| from io import BufferedWriter | from io import BufferedWriter | ||||||
| from typing import IO, Any, Callable, Sequence, Mapping | from typing import IO, Any, Sequence, Mapping | ||||||
| from string import ascii_letters, digits | from string import ascii_letters, digits | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -28,47 +28,6 @@ from .constants import ( | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| class LazyTensor: |  | ||||||
|     data: Callable[[], np.ndarray[Any, Any]] |  | ||||||
|     # to avoid too deep recursion |  | ||||||
|     functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]] |  | ||||||
|     dtype: np.dtype[Any] |  | ||||||
|     shape: tuple[int, ...] |  | ||||||
|  |  | ||||||
|     def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]): |  | ||||||
|         self.data = data |  | ||||||
|         self.functions = [] |  | ||||||
|         self.dtype = np.dtype(dtype) |  | ||||||
|         self.shape = shape |  | ||||||
|  |  | ||||||
|     def astype(self, dtype: type, **kwargs) -> LazyTensor: |  | ||||||
|         self.functions.append(lambda n: n.astype(dtype, **kwargs)) |  | ||||||
|         self.dtype = np.dtype(dtype) |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def nbytes(self) -> int: |  | ||||||
|         size = 1 |  | ||||||
|         for n in self.shape: |  | ||||||
|             size *= n |  | ||||||
|         return size * self.dtype.itemsize |  | ||||||
|  |  | ||||||
|     def tofile(self, *args, **kwargs) -> None: |  | ||||||
|         data = self.data() |  | ||||||
|         for f in self.functions: |  | ||||||
|             data = f(data) |  | ||||||
|         assert data.shape == self.shape |  | ||||||
|         assert data.dtype == self.dtype |  | ||||||
|         assert data.nbytes == self.nbytes |  | ||||||
|         self.functions = [] |  | ||||||
|         self.data = lambda: data |  | ||||||
|         data.tofile(*args, **kwargs) |  | ||||||
|  |  | ||||||
|     def byteswap(self, *args, **kwargs) -> LazyTensor: |  | ||||||
|         self.functions.append(lambda n: n.byteswap(*args, **kwargs)) |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WriterState(Enum): | class WriterState(Enum): | ||||||
|     EMPTY   = auto() |     EMPTY   = auto() | ||||||
|     HEADER  = auto() |     HEADER  = auto() | ||||||
| @@ -79,7 +38,7 @@ class WriterState(Enum): | |||||||
| class GGUFWriter: | class GGUFWriter: | ||||||
|     fout: BufferedWriter |     fout: BufferedWriter | ||||||
|     temp_file: tempfile.SpooledTemporaryFile[bytes] | None |     temp_file: tempfile.SpooledTemporaryFile[bytes] | None | ||||||
|     tensors: list[np.ndarray[Any, Any] | LazyTensor] |     tensors: list[np.ndarray[Any, Any]] | ||||||
|     _simple_value_packing = { |     _simple_value_packing = { | ||||||
|         GGUFValueType.UINT8:   "B", |         GGUFValueType.UINT8:   "B", | ||||||
|         GGUFValueType.INT8:    "b", |         GGUFValueType.INT8:    "b", | ||||||
| @@ -278,7 +237,7 @@ class GGUFWriter: | |||||||
|         self.ti_data_count += 1 |         self.ti_data_count += 1 | ||||||
|  |  | ||||||
|     def add_tensor( |     def add_tensor( | ||||||
|         self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None, |         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, | ||||||
|         raw_dtype: GGMLQuantizationType | None = None, |         raw_dtype: GGMLQuantizationType | None = None, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         if self.endianess == GGUFEndian.BIG: |         if self.endianess == GGUFEndian.BIG: | ||||||
| @@ -303,7 +262,7 @@ class GGUFWriter: | |||||||
|         if pad != 0: |         if pad != 0: | ||||||
|             fp.write(bytes([0] * pad)) |             fp.write(bytes([0] * pad)) | ||||||
|  |  | ||||||
|     def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None: |     def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: | ||||||
|         if self.state is not WriterState.TI_DATA: |         if self.state is not WriterState.TI_DATA: | ||||||
|             raise ValueError(f'Expected output file to contain tensor info, got {self.state}') |             raise ValueError(f'Expected output file to contain tensor info, got {self.state}') | ||||||
|  |  | ||||||
| @@ -391,7 +350,7 @@ class GGUFWriter: | |||||||
|     def add_name(self, name: str) -> None: |     def add_name(self, name: str) -> None: | ||||||
|         self.add_string(Keys.General.NAME, name) |         self.add_string(Keys.General.NAME, name) | ||||||
|  |  | ||||||
|     def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None: |     def add_quantization_version(self, quantization_version: int) -> None: | ||||||
|         self.add_uint32( |         self.add_uint32( | ||||||
|             Keys.General.QUANTIZATION_VERSION, quantization_version) |             Keys.General.QUANTIZATION_VERSION, quantization_version) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										225
									
								
								gguf-py/gguf/lazy.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								gguf-py/gguf/lazy.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,225 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  | from abc import ABC, ABCMeta, abstractmethod | ||||||
|  |  | ||||||
|  | import logging | ||||||
|  | from typing import Any, Callable | ||||||
|  | from collections import deque | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | from numpy.typing import DTypeLike | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LazyMeta(ABCMeta): | ||||||
|  |  | ||||||
|  |     def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): | ||||||
|  |         def __getattr__(self, __name: str) -> Any: | ||||||
|  |             meta_attr = getattr(self._meta, __name) | ||||||
|  |             if callable(meta_attr): | ||||||
|  |                 return type(self)._wrap_fn( | ||||||
|  |                     (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)), | ||||||
|  |                     use_self=self, | ||||||
|  |                 ) | ||||||
|  |             elif isinstance(meta_attr, self._tensor_type): | ||||||
|  |                 # e.g. self.T with torch.Tensor should still be wrapped | ||||||
|  |                 return type(self)._wrap_fn(lambda s: getattr(s, __name))(self) | ||||||
|  |             else: | ||||||
|  |                 # no need to wrap non-tensor properties, | ||||||
|  |                 # and they likely don't depend on the actual contents of the tensor | ||||||
|  |                 return meta_attr | ||||||
|  |  | ||||||
|  |         namespace["__getattr__"] = __getattr__ | ||||||
|  |  | ||||||
|  |         # need to make a builder for the wrapped wrapper to copy the name, | ||||||
|  |         # or else it fails with very cryptic error messages, | ||||||
|  |         # because somehow the same string would end up in every closures | ||||||
|  |         def mk_wrap(op_name: str, *, meta_noop: bool = False): | ||||||
|  |             # need to wrap the wrapper to get self | ||||||
|  |             def wrapped_special_op(self, *args, **kwargs): | ||||||
|  |                 return type(self)._wrap_fn( | ||||||
|  |                     getattr(type(self)._tensor_type, op_name), | ||||||
|  |                     meta_noop=meta_noop, | ||||||
|  |                 )(self, *args, **kwargs) | ||||||
|  |             return wrapped_special_op | ||||||
|  |  | ||||||
|  |         # special methods bypass __getattr__, so they need to be added manually | ||||||
|  |         # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup | ||||||
|  |         # NOTE: doing this from a metaclass is very convenient | ||||||
|  |         # TODO: make this even more comprehensive | ||||||
|  |         for binary_op in ( | ||||||
|  |             "lt", "le", "eq", "ne", "ge", "gt", "not" | ||||||
|  |             "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul", | ||||||
|  |             "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor", | ||||||
|  |             "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor", | ||||||
|  |             "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor", | ||||||
|  |         ): | ||||||
|  |             attr_name = f"__{binary_op}__" | ||||||
|  |             # the result of these operators usually has the same shape and dtype as the input, | ||||||
|  |             # so evaluation on the meta tensor can be skipped. | ||||||
|  |             namespace[attr_name] = mk_wrap(attr_name, meta_noop=True) | ||||||
|  |  | ||||||
|  |         for special_op in ( | ||||||
|  |             "getitem", "setitem", "len", | ||||||
|  |         ): | ||||||
|  |             attr_name = f"__{special_op}__" | ||||||
|  |             namespace[attr_name] = mk_wrap(attr_name, meta_noop=False) | ||||||
|  |  | ||||||
|  |         return super().__new__(cls, name, bases, namespace, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Tree of lazy tensors | ||||||
|  | class LazyBase(ABC, metaclass=LazyMeta): | ||||||
|  |     _tensor_type: type | ||||||
|  |     _meta: Any | ||||||
|  |     _data: Any | None | ||||||
|  |     _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager | ||||||
|  |     _args: tuple | ||||||
|  |     _func: Callable[[tuple], Any] | None | ||||||
|  |  | ||||||
|  |     def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None): | ||||||
|  |         super().__init__() | ||||||
|  |         self._meta = meta | ||||||
|  |         self._data = data | ||||||
|  |         self._lazy = lazy if lazy is not None else deque() | ||||||
|  |         self._args = args | ||||||
|  |         self._func = func | ||||||
|  |         assert self._func is not None or self._data is not None | ||||||
|  |         if self._data is None: | ||||||
|  |             self._lazy.append(self) | ||||||
|  |  | ||||||
|  |     def __init_subclass__(cls) -> None: | ||||||
|  |         if "_tensor_type" not in cls.__dict__: | ||||||
|  |             raise TypeError(f"property '_tensor_type' must be defined for {cls!r}") | ||||||
|  |         return super().__init_subclass__() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: | ||||||
|  |         # TODO: dict and set | ||||||
|  |         if isinstance(o, (list, tuple)): | ||||||
|  |             L = [] | ||||||
|  |             for item in o: | ||||||
|  |                 L.append(LazyBase._recurse_apply(item, fn)) | ||||||
|  |             if isinstance(o, tuple): | ||||||
|  |                 L = tuple(L) | ||||||
|  |             return L | ||||||
|  |         elif isinstance(o, LazyBase): | ||||||
|  |             return fn(o) | ||||||
|  |         else: | ||||||
|  |             return o | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]: | ||||||
|  |         def wrapped_fn(*args, **kwargs): | ||||||
|  |             if kwargs is None: | ||||||
|  |                 kwargs = {} | ||||||
|  |             args = ((use_self,) if use_self is not None else ()) + args | ||||||
|  |  | ||||||
|  |             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta) | ||||||
|  |  | ||||||
|  |             if isinstance(meta_noop, bool) and not meta_noop: | ||||||
|  |                 try: | ||||||
|  |                     res = fn(*meta_args, **kwargs) | ||||||
|  |                 except NotImplementedError: | ||||||
|  |                     # running some operations on PyTorch's Meta tensors can cause this exception | ||||||
|  |                     res = None | ||||||
|  |             else: | ||||||
|  |                 # some operators don't need to actually run on the meta tensors | ||||||
|  |                 assert len(args) > 0 | ||||||
|  |                 res = args[0] | ||||||
|  |                 assert isinstance(res, cls) | ||||||
|  |                 res = res._meta | ||||||
|  |                 # allow operations to override the dtype | ||||||
|  |                 if meta_noop is not True: | ||||||
|  |                     res = cls.meta_with_dtype(res, meta_noop) | ||||||
|  |  | ||||||
|  |             if isinstance(res, cls._tensor_type): | ||||||
|  |                 def collect_replace(t: LazyBase): | ||||||
|  |                     if collect_replace.shared_lazy is None: | ||||||
|  |                         collect_replace.shared_lazy = t._lazy | ||||||
|  |                     else: | ||||||
|  |                         collect_replace.shared_lazy.extend(t._lazy) | ||||||
|  |                         t._lazy = collect_replace.shared_lazy | ||||||
|  |  | ||||||
|  |                 # emulating a static variable | ||||||
|  |                 collect_replace.shared_lazy = None | ||||||
|  |  | ||||||
|  |                 LazyBase._recurse_apply(args, collect_replace) | ||||||
|  |  | ||||||
|  |                 shared_lazy = collect_replace.shared_lazy | ||||||
|  |  | ||||||
|  |                 return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs)) | ||||||
|  |             else: | ||||||
|  |                 del res  # not needed | ||||||
|  |                 # non-tensor return likely relies on the contents of the args | ||||||
|  |                 # (e.g. the result of torch.equal) | ||||||
|  |                 eager_args = cls.to_eager(args) | ||||||
|  |                 return fn(*eager_args, **kwargs) | ||||||
|  |         return wrapped_fn | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def to_eager(cls, t: Any) -> Any: | ||||||
|  |         def simple_to_eager(_t: LazyBase) -> Any: | ||||||
|  |             def already_eager_to_eager(_t: LazyBase) -> Any: | ||||||
|  |                 assert _t._data is not None | ||||||
|  |                 return _t._data | ||||||
|  |  | ||||||
|  |             while _t._data is None: | ||||||
|  |                 lt = _t._lazy.popleft() | ||||||
|  |                 if lt._data is not None: | ||||||
|  |                     raise ValueError(f"{lt} did not belong in the lazy queue") | ||||||
|  |                 assert lt._func is not None | ||||||
|  |                 lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) | ||||||
|  |                 lt._data = lt._func(lt._args) | ||||||
|  |                 # sanity check | ||||||
|  |                 assert lt._data.dtype == lt._meta.dtype | ||||||
|  |                 assert lt._data.shape == lt._meta.shape | ||||||
|  |  | ||||||
|  |             return _t._data | ||||||
|  |  | ||||||
|  |         # recurse into lists and/or tuples, keeping their structure | ||||||
|  |         return cls._recurse_apply(t, simple_to_eager) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def eager_to_meta(cls, t: Any) -> Any: | ||||||
|  |         return cls.meta_with_dtype(t, t.dtype) | ||||||
|  |  | ||||||
|  |     # must be overridden, meta tensor init is backend-specific | ||||||
|  |     @classmethod | ||||||
|  |     @abstractmethod | ||||||
|  |     def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_eager(cls, t: Any) -> Any: | ||||||
|  |         if type(t) is cls: | ||||||
|  |             # already eager | ||||||
|  |             return t | ||||||
|  |         elif isinstance(t, cls._tensor_type): | ||||||
|  |             return cls(meta=cls.eager_to_meta(t), data=t) | ||||||
|  |         else: | ||||||
|  |             return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LazyNumpyTensor(LazyBase): | ||||||
|  |     _tensor_type = np.ndarray | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]: | ||||||
|  |         # The initial idea was to use np.nan as the fill value, | ||||||
|  |         # but non-float types like np.int16 can't use that. | ||||||
|  |         # So zero it is. | ||||||
|  |         cheat = np.zeros(1, dtype) | ||||||
|  |         return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape)) | ||||||
|  |  | ||||||
|  |     def astype(self, dtype, *args, **kwargs): | ||||||
|  |         meta = type(self).meta_with_dtype(self._meta, dtype) | ||||||
|  |         full_args = (self, dtype,) + args | ||||||
|  |         # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere. | ||||||
|  |         return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs))) | ||||||
|  |  | ||||||
|  |     def tofile(self, *args, **kwargs): | ||||||
|  |         eager = LazyNumpyTensor.to_eager(self) | ||||||
|  |         return eager.tofile(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     # TODO: __array_function__ | ||||||
		Reference in New Issue
	
	Block a user
	 compilade
					compilade