convert : use reflinks for faster conversion

This commit is contained in:
Francis Couture-Harpin
2025-09-01 20:45:57 -04:00
parent e996f3aef8
commit 562aa42c12
6 changed files with 266 additions and 60 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any, Callable
import numpy as np
from numpy.typing import DTypeLike
from .utility import LocalTensorRange
logger = logging.getLogger(__name__)
@@ -20,10 +21,11 @@ class LazyMeta(ABCMeta):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
@@ -39,8 +41,9 @@ class LazyMeta(ABCMeta):
def wrapped_special_op(self, *args, **kwargs):
return type(self)._wrap_fn(
getattr(type(self)._tensor_type, op_name),
use_self=self,
meta_noop=meta_noop,
)(self, *args, **kwargs)
)(*args, **kwargs)
return wrapped_special_op
# special methods bypass __getattr__, so they need to be added manually
@@ -76,14 +79,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
_args: tuple
_kwargs: dict[str, Any]
_func: Callable[[Any], Any] | None
_ranges: tuple[LocalTensorRange, ...]
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
super().__init__()
self._meta = meta
self._data = data
self._args = args
self._kwargs = kwargs if kwargs is not None else {}
self._func = func
self._ranges = ranges
assert self._func is not None or self._data is not None
def __init_subclass__(cls) -> None:
@@ -107,7 +112,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
return o
@classmethod
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
@@ -116,6 +121,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
# TODO: maybe handle tensors in kwargs too
ranges = use_self._ranges if use_self is not None and data_noop else ()
if isinstance(meta_noop, bool) and not meta_noop:
try:
res = fn(*meta_args, **kwargs)
@@ -138,7 +145,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
# share the evaluation between lazy tuple elements
shared_args: list = [args, None]
@@ -214,7 +221,8 @@ class LazyNumpyTensor(LazyBase):
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
ranges = self._ranges if self._meta.dtype == dtype else ()
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)