From df9c196ca1d60605de68410d308024330246f992 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 11 Jul 2024 14:42:11 +0100
Subject: [PATCH] basics

ghstack-source-id: 8c34373f4fcd788636be7b87e8a554017df0c746
Pull Request resolved: https://github.com/pytorch/tensordict/pull/873
---
 .../tensorclass/test_torch_functions.py |  96 ------
 tensordict/__init__.py                  |   2 +
 tensordict/_contextlib.py               |  15 +-
 tensordict/_lazy.py                     |  28 +-
 tensordict/_pytree.py                   |  10 +-
 tensordict/_td.py                       | 235 ++++++++------
 tensordict/_torch_func.py               |  38 ++-
 tensordict/base.py                      | 173 +++++++---
 tensordict/functional.py                |   3 +-
 tensordict/nn/params.py                 |   6 +-
 tensordict/tensorclass.py               |   6 +-
 tensordict/utils.py                     | 262 ++++++++++++---
 test/test_compile.py                    | 297 ++++++++++++++++++
 test/test_tensordict.py                 |   6 +-
 14 files changed, 834 insertions(+), 343 deletions(-)
 delete mode 100644 benchmarks/tensorclass/test_torch_functions.py
 create mode 100644 test/test_compile.py

diff --git a/benchmarks/tensorclass/test_torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py
deleted file mode 100644
index 7ed717872..000000000
--- a/benchmarks/tensorclass/test_torch_functions.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pytest
-import torch
-
-from tensordict import tensorclass
-
-
-@tensorclass
-class MyData:
-    a: torch.Tensor
-    b: torch.Tensor
-    other: str
-    nested: "MyData" = None
-
-
-@pytest.fixture
-def a():
-    return torch.zeros(300, 400, 50)
-
-
-@pytest.fixture
-def b():
-    return torch.zeros(300, 400, 50)
-
-
-@pytest.fixture
-def tc(a, b):
-    return MyData(
-        a=a,
-        b=b,
-        other="hello",
-        nested=MyData(
-            a=a.clone(), b=b.clone(), other="goodbye", batch_size=[300, 400, 50]
-        ),
-        batch_size=[300, 400],
-    )
-
-
-def test_unbind(benchmark, tc):
-    benchmark(torch.unbind, tc, 0)
-
-
-def test_full_like(benchmark, tc):
-    benchmark(torch.full_like, tc, 2.0)
-
-
-def test_zeros_like(benchmark, tc):
-    benchmark(
-        torch.zeros_like,
-        tc,
-    )
-
-
-def test_ones_like(benchmark, tc):
-    benchmark(
-        torch.ones_like,
-        tc,
-    )
-
-
-def test_clone(benchmark, tc):
-    benchmark(
-        torch.clone,
-        tc,
-    )
-
-
-def test_squeeze(benchmark, tc):
-    benchmark(
-        torch.squeeze,
-        tc,
-    )
-
-
-def test_unsqueeze(benchmark, tc):
-    benchmark(torch.unsqueeze, tc, 0)
-
-
-def test_split(benchmark, tc):
-    benchmark(torch.split, tc, [200, 100])
-
-
-def test_permute(benchmark, tc):
-    benchmark(torch.permute, tc, [1, 0])
-
-
-def test_stack(benchmark, tc):
-    benchmark(torch.stack, [tc] * 3, 0)
-
-
-def test_cat(benchmark, tc):
-    benchmark(torch.cat, [tc] * 3, 0)
diff --git a/tensordict/__init__.py b/tensordict/__init__.py
index 7640bde08..f11328442 100644
--- a/tensordict/__init__.py
+++ b/tensordict/__init__.py
@@ -19,6 +19,7 @@
 from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
 from tensordict.utils import (
     assert_allclose_td,
+    assert_close,
     is_batchedtensor,
     is_tensorclass,
     lazy_legacy,
@@ -43,6 +44,7 @@
     "TensorDict",
     "TensorDictBase",
     "assert_allclose_td",
+    "assert_close",
     "dense_stack_tds",
     "is_batchedtensor",
     "is_tensor_collection",
diff --git a/tensordict/_contextlib.py b/tensordict/_contextlib.py
index a70537028..36d21a35a 100644
--- a/tensordict/_contextlib.py
+++ b/tensordict/_contextlib.py
@@ -91,13 +91,14 @@ def context_decorator(ctx, func):
     be a multi-shot context manager that can be directly invoked multiple times) or a callable
     that produces a context manager.
     """
-    assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
-        f"Passed in {ctx} is both callable and also a valid context manager "
-        "(has __enter__), making it ambiguous which interface to use. If you "
-        "intended to pass a context manager factory, rewrite your call as "
-        "context_decorator(lambda: ctx()); if you intended to pass a context "
-        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
-    )
+    if callable(ctx) and hasattr(ctx, "__enter__"):
+        raise RuntimeError(
+            f"Passed in {ctx} is both callable and also a valid context manager "
+            "(has __enter__), making it ambiguous which interface to use. If you "
+            "intended to pass a context manager factory, rewrite your call as "
+            "context_decorator(lambda: ctx()); if you intended to pass a context "
+            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
+        )
 
     if not callable(ctx):
diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index c9fe7cbd8..378805174 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -45,10 +45,6 @@
     from tensordict.utils import _ftdim_mock as ftdim
 
     _has_funcdim = False
-from tensordict._C import (  # @manual=//tensordict:_C
-    _unravel_key_to_tuple,
-    unravel_key_list,
-)
 from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
 from tensordict.base import (
     _is_leaf_nontensor,
@@ -72,6 +68,7 @@
     _renamed_inplace_method,
     _shape,
     _td_fields,
+    _unravel_key_to_tuple,
     as_decorator,
     cache,
     convert_ellipsis_to_idx,
@@ -85,6 +82,7 @@
     KeyedJaggedTensor,
     lock_blocked,
     NestedKey,
+    unravel_key_list,
 )
 from torch import Tensor
@@ -1168,7 +1166,9 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
             td._fast_apply(
                 lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
                 batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
-                names=[name for i, name in enumerate(td.names) if i != in_dim],
+                names=[name for i, name in enumerate(td.names) if i != in_dim]
+                if self._has_names()
+                else None,
             )
             for td in td.tensordicts
         ]
@@ -1313,7 +1313,7 @@ def contiguous(self) -> T:
             source=source,
             batch_size=batch_size,
             device=device,
-            names=self.names,
+            names=self.names if self._has_names() else None,
             lock=self.is_locked,
         )
         return out
@@ -1377,10 +1377,6 @@ def _check_new_batch_size(self, new_size: torch.Size) -> None:
         super()._check_new_batch_size(new_size)
 
     def _change_batch_size(self, new_size: torch.Size) -> None:
-        if not hasattr(self, "_orig_batch_size"):
-            self._orig_batch_size = self.batch_size
-        elif self._orig_batch_size == new_size:
-            del self._orig_batch_size
         self._batch_size = new_size
 
     def keys(
@@ -1552,7 +1548,7 @@ def _multithread_rebuild(
         # We know batch_size is None, this has been checked earlier
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         out: TensorDictBase | None = None,
@@ -1603,7 +1599,7 @@ def _multithread_rebuild(
             )
         else:
             out = self
-        if names is not None:
+        if names is not NO_DEFAULT:
             out.names = names
         return out
@@ -1613,7 +1609,7 @@ def _apply_nest(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         call_on_nested: bool = False,
@@ -1719,7 +1715,7 @@ def _apply_nest(
             )
         else:
             out = self
-        if names is not None:
+        if names is not NO_DEFAULT:
             out.names = names
         return out
@@ -2980,10 +2976,6 @@ def _rename_subtds(self, names):
             )
 
     def _change_batch_size(self, new_size: torch.Size) -> None:
-        if not hasattr(self, "_orig_batch_size"):
-            self._orig_batch_size = self.batch_size
-        elif self._orig_batch_size == new_size:
-            del self._orig_batch_size
         self._batch_size = new_size
 
     def _get_str(self, key, default):
diff --git a/tensordict/_pytree.py b/tensordict/_pytree.py
index 6109a3e5f..5fa594e7e 100644
--- a/tensordict/_pytree.py
+++ b/tensordict/_pytree.py
@@ -29,8 +29,10 @@
 
 def _str_to_dict(str_spec: str) -> Tuple[List[str], str]:
-    assert str_spec[1] == "("
-    assert str_spec[-1] == ")"
+    if str_spec[1] != "(" or str_spec[-1] != ")":
+        raise ValueError(
+            f"string must have '(' as a second character and ')' in last position. Got {str_spec}."
+        )
     context_and_child_strings = str_spec[2:-1]
 
     child_strings = []
@@ -92,7 +94,7 @@ def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
     return values, {
         "keys": keys,
         "batch_size": d.batch_size,
-        "names": d.names,
+        "names": d.names if d._has_names() else None,
         "device": d.device,
         "constructor": _constructor(type(d)),
         "non_tensor_data": d.non_tensor_items(),
@@ -159,7 +161,7 @@ def _td_flatten_with_keys(
     return [(MappingKey(k), v) for k, v in zip(keys, values)], {
         "keys": keys,
         "batch_size": d.batch_size,
-        "names": d.names,
+        "names": d.names if d._has_names() else None,
         "device": d.device,
         "constructor": _constructor(type(d)),
         "non_tensor_data": d.non_tensor_items(),
diff --git a/tensordict/_td.py b/tensordict/_td.py
index be8f2dc67..6ab1db2f5 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -48,7 +48,6 @@
     _get_shape_from_args,
     _getitem_batch_size,
     _index_preserve_data_ptr,
-    _is_number,
     _is_shared,
     _is_tensorclass,
     _KEY_ERROR,
@@ -61,6 +60,7 @@
     _set_max_batch_size,
     _shape,
     _STRDTYPE2DTYPE,
+    _StringKeys,
     _StringOnlyDict,
     _sub_index,
     _unravel_key_to_tuple,
@@ -80,6 +80,8 @@
     unravel_key_list,
 )
 from torch import Tensor
+
+from torch._dynamo import graph_break
 from torch.jit._shape_functions import infer_size_impl
 from torch.utils._pytree import tree_map
@@ -225,6 +227,8 @@ def __init__(
         non_blocking: bool = None,
         lock: bool = False,
     ) -> None:
+        if names and torch.compiler.is_dynamo_compiling():
+            graph_break()
         has_device = False
         sub_non_blocking = False
         if device is not None:
@@ -269,6 +273,15 @@ def _new_unsafe(
         lock: bool = False,
         nested: bool = True,
     ) -> TensorDict:
+        if torch.compiler.is_dynamo_compiling():
+            return TensorDict(
+                source,
+                batch_size=batch_size,
+                device=device,
+                names=names,
+                non_blocking=non_blocking,
+                lock=lock,
+            )
         self = cls.__new__(cls)
         sub_non_blocking = False
         if device is not None:
@@ -294,6 +307,8 @@ def _new_unsafe(
                 non_blocking=sub_non_blocking,
             )
             _tensordict[key] = value
+        # assert names is None or len(names) == self.batch_dims, (names, batch_size)
+        # assert (names is None) or (not all(name is None for name in names))
         self._td_dim_names = names
         if lock:
             self.lock_()
@@ -1008,7 +1023,7 @@ def _multithread_rebuild(
         *,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         out: TensorDictBase | None = None,
@@ -1043,9 +1058,12 @@ def _multithread_rebuild(
         else:
 
             def make_result(names=names, batch_size=batch_size):
-                if batch_size is not None and names is None:
-                    # erase names
-                    names = [None] * len(batch_size)
+                if names is NO_DEFAULT:
+                    if batch_size is not None:
+                        # erase names
+                        names = None
+                    elif batch_size is None:
+                        names = self.names if self._has_names() else None
                 return self.empty(batch_size=batch_size, device=device, names=names)
 
             result = make_result()
@@ -1127,7 +1145,7 @@ def _apply_nest(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         call_on_nested: bool = False,
@@ -1159,9 +1177,12 @@ def _apply_nest(
         else:
 
             def make_result(names=names, batch_size=batch_size):
-                if batch_size is not None and names is None:
-                    # erase names
-                    names = [None] * len(batch_size)
+                if names is NO_DEFAULT:
+                    if batch_size is not None:
+                        # erase names
+                        names = None
+                    else:
+                        names = self.names if self._has_names() else None
                 return self.empty(batch_size=batch_size, device=device, names=names)
 
             result = None
@@ -1274,7 +1295,9 @@ def _add_batch_dim_wrapper(key, value):
             batch_size=torch.Size(
                 [b for i, b in enumerate(td.batch_size) if i != in_dim]
             ),
-            names=[name for i, name in enumerate(td.names) if i != in_dim],
+            names=[name for i, name in enumerate(td.names) if i != in_dim]
+            if self._has_names()
+            else None,
             lock=self.is_locked,
         )
         return out
@@ -1375,6 +1398,11 @@ def expand(self, *args, **kwargs) -> T:
                 f"as the original length. target_shape = {shape}, existing_shape = {self.batch_size}"
             )
 
+        if self._has_names():
+            names = [None] * (len(shape) - tensordict_dims) + self.names
+        else:
+            names = None
+
         def _expand(tensor):
             tensor_shape = tensor.shape
             tensor_dims = len(tensor_shape)
@@ -1385,7 +1413,6 @@ def _expand(tensor):
                 new_shape = shape
             return tensor.expand(new_shape)
 
-        names = [None] * (len(shape) - tensordict_dims) + self.names
         return self._fast_apply(
             _expand,
             batch_size=shape,
@@ -1400,6 +1427,12 @@ def _unbind(self, dim: int):
         if self._has_names():
             names = copy(self.names)
             names = [name for i, name in enumerate(names) if i != dim]
+            # We could use any() but dynamo doesn't like generators
+            for name in names:
+                if name is not None:
+                    break
+            else:
+                names = None
         device = self.device
 
         is_shared = self._is_shared
@@ -1510,7 +1543,7 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBas
         else:
             raise TypeError(WRONG_TYPE)
         index = (slice(None),) * dim
-        names = self.names
+        names = self.names if self._has_names() else None
         return tuple(
             self._index_tensordict(index + (ss,), new_batch_size=bs, names=names)
             for ss, bs in zip(split_sizes, batch_sizes)
@@ -1570,7 +1603,10 @@ def _reshape(tensor):
             return tensor.reshape((*shape, *tensor.shape[batch_dims:]))
 
         return self._fast_apply(
-            _reshape, batch_size=shape, call_on_nested=True, propagate_lock=True
+            _reshape,
+            batch_size=shape,
+            call_on_nested=True,
+            propagate_lock=True,
         )
 
     def _transpose(self, dim0, dim1):
@@ -1643,10 +1679,17 @@ def _permute(tensor):
     def _squeeze(self, dim=None):
         batch_size = self.batch_size
         if dim is None:
-            names = list(self.names)
-            batch_size, names = zip(
-                *[(size, name) for size, name in zip(batch_size, names) if size != 1]
-            )
+            names = copy(self.names) if self._has_names() else None
+            if names is not None:
+                batch_size, names = zip(
+                    *[
+                        (size, name)
+                        for size, name in zip(batch_size, names)
+                        if size != 1
+                    ]
+                )
+            else:
+                batch_size = [size for size in batch_size if size != 1]
             batch_size = torch.Size(batch_size)
             if batch_size == self.batch_size:
                 return self
@@ -1681,8 +1724,9 @@ def _squeeze(tensor):
         batch_size = list(batch_size)
         batch_size.pop(dim)
         batch_size = list(batch_size)
-        names = list(self.names)
-        names.pop(dim)
+        names = copy(self.names) if self._has_names() else None
+        if names:
+            names.pop(dim)
 
         result = self._fast_apply(
             lambda x: x.squeeze(newdim),
@@ -1712,8 +1756,9 @@ def _unsqueeze(self, dim):
         batch_size.insert(newdim, 1)
         batch_size = torch.Size(batch_size)
 
-        names = copy(self.names)
-        names.insert(newdim, None)
+        names = copy(self.names) if self._has_names() else None
+        if names:
+            names.insert(newdim, None)
 
         def _unsqueeze(tensor):
             return tensor.unsqueeze(newdim)
@@ -1768,7 +1813,7 @@ def _from_dict_validated(
             input_dict,
             batch_size=torch.Size(batch_size),
             device=torch.device(device) if device is not None else device,
-            names=names,
+            names=names if any(name is not None for name in names) else None,
         )
 
     def from_dict_instance(
@@ -1848,70 +1893,32 @@ def names(self):
         names = self._td_dim_names
         if names is None:
             return [None for _ in range(self.batch_dims)]
-        return names
-
-    def _get_names_idx(self, idx):
-        if not self._has_names():
-            names = None
-        else:
-
-            def is_boolean(idx):
-                try:
-                    from functorch import dim as ftdim
-
-                except ImportError:
-                    from tensordict.utils import _ftdim_mock as ftdim
-
-                if isinstance(idx, ftdim.Dim):
-                    return None
-                if isinstance(idx, tuple) and len(idx) == 1:
-                    return is_boolean(idx[0])
-                if hasattr(idx, "dtype") and idx.dtype is torch.bool:
-                    return idx.ndim
-                return None
-
-            num_boolean_dim = is_boolean(idx)
-            names = self.names
-            if num_boolean_dim:
-                names = [None] + names[num_boolean_dim:]
-            else:
-                if not isinstance(idx, tuple):
-                    idx = (idx,)
-                if len([_idx for _idx in idx if _idx is not None]) < self.ndim:
-                    idx = (*idx, Ellipsis)
-                idx_names = convert_ellipsis_to_idx(idx, self.batch_size)
-                # this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3]
-                count = 0
-                idx_to_take = []
-                no_more_tensors = False
-                for _idx in idx_names:
-                    if _idx is None:
-                        idx_to_take.append(None)
-                    elif _is_number(_idx):
-                        count += 1
-                    elif isinstance(_idx, (torch.Tensor, np.ndarray)):
-                        if not no_more_tensors:
-                            idx_to_take.extend([count] * _idx.ndim)
-                            count += 1
-                            no_more_tensors = True
-                        else:
-                            # skip this one
-                            count += 1
-                    else:
-                        idx_to_take.append(count)
-                        count += 1
-                names = [names[i] if i is not None else None for i in idx_to_take]
+        # assert len(names) == self.batch_dims, (names, self.batch_dims)
         return names
 
     @names.setter
     def names(self, value):
+        if torch.compiler.is_dynamo_compiling() or torch.compiler.is_compiling():
+            if value is not None:
+                graph_break()
+            else:
+                # We have already made sure that the tensordict was not named
+                return
+
         # we don't run checks on types for efficiency purposes
         if value is None:
             self._rename_subtds(value)
             self._erase_names()
             return
         value = list(value)
-        num_none = sum(v is None for v in value)
+        # Faster but incompatible with dynamo
+        # num_none = sum(v is None for v in value)
+        num_none = 0
+        for v in value:
+            num_none += v is None
+        if num_none == self.batch_dims:
+            self.names = None
+            return
         if num_none:
             num_none -= 1
         if len(set(value)) != len(value) - num_none:
@@ -1963,10 +1970,6 @@ def batch_size(self, new_size: torch.Size) -> None:
         self._batch_size_setter(new_size)
 
     def _change_batch_size(self, new_size: torch.Size) -> None:
-        if not hasattr(self, "_orig_batch_size"):
-            self._orig_batch_size = self.batch_size
-        elif self._orig_batch_size == new_size:
-            del self._orig_batch_size
         self._batch_size = new_size
 
     # Checks
@@ -2776,7 +2779,7 @@ def _clone(self, recurse: bool = True) -> T:
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
            device=self.device,
-            names=copy(self._td_dim_names),
+            names=copy(self._td_dim_names) if self._has_names() else None,
         )
         # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too
         # This may be undesirable, not sure if this should be the default behaviour
@@ -2793,12 +2796,12 @@ def contiguous(self) -> T:
             source=source,
             batch_size=batch_size,
             device=device,
-            names=self.names,
+            names=self.names if self._has_names() else None,
         )
         return out
 
     def empty(
-        self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=None
+        self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=NO_DEFAULT
     ) -> T:
         if not recurse:
             return TensorDict._new_unsafe(
@@ -2807,7 +2810,9 @@ def empty(
                 if batch_size is None
                 else torch.Size(batch_size),
                 source={},
-                names=self._td_dim_names if names is None else names,
+                names=(self.names if self._has_names() else None)
+                if names is NO_DEFAULT
+                else names,
             )
         return super().empty(recurse=recurse)
@@ -2915,7 +2920,7 @@ def keys(
         is_leaf: Callable[[Type], bool] | None = None,
     ) -> _TensorDictKeysView:
         if not include_nested and not leaves_only:
-            return self._tensordict.keys()
+            return _StringKeys(self._tensordict.keys())
         else:
             return self._nested_keys(
                 include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
@@ -2946,20 +2951,55 @@ def items(
             return self._tensordict.items()
         elif include_nested and leaves_only:
             is_leaf = _default_is_leaf if is_leaf is None else is_leaf
+            result = []
+            if torch.compiler.is_dynamo_compiling():
 
-            def fast_iter():
-                for k, val in self._tensordict.items():
-                    if not is_leaf(val.__class__):
-                        yield from (
-                            ((k, *((_key,) if isinstance(_key, str) else _key)), _val)
+                def fast_iter():
+                    for key, val in self._tensordict.items():
+                        if not is_leaf(val.__class__):
                             for _key, _val in val.items(
                                 include_nested=include_nested,
                                 leaves_only=leaves_only,
                                 is_leaf=is_leaf,
+                            ):
+                                result.append(
+                                    (
+                                        (
+                                            key,
+                                            *(
+                                                (_key,)
+                                                if isinstance(_key, str)
+                                                else _key
+                                            ),
+                                        ),
+                                        _val,
+                                    )
+                                )
+                        else:
+                            result.append((key, val))
+                    return result
+
+            else:
+                # dynamo doesn't like generators
+                def fast_iter():
+                    for key, val in self._tensordict.items():
+                        if not is_leaf(val.__class__):
+                            yield from (
+                                (
+                                    (
+                                        key,
+                                        *((_key,) if isinstance(_key, str) else _key),
+                                    ),
+                                    _val,
+                                )
+                                for _key, _val in val.items(
+                                    include_nested=include_nested,
+                                    leaves_only=leaves_only,
+                                    is_leaf=is_leaf,
+                                )
-                        )
-                    else:
-                        yield k, val
+                            )
+                        else:
+                            yield (key, val)
 
             return fast_iter()
         else:
@@ -2976,7 +3016,8 @@ def values(
         if not include_nested and not leaves_only:
             return self._tensordict.values()
         else:
-            return super().values(
+            return TensorDictBase.values(
+                self,
                 include_nested=include_nested,
                 leaves_only=leaves_only,
                 is_leaf=is_leaf,
@@ -3285,10 +3326,6 @@ def to(self, *args, **kwargs: Any) -> T:
         return self.to_tensordict().to(*args, **kwargs)
 
     def _change_batch_size(self, new_size: torch.Size) -> None:
-        if not hasattr(self, "_orig_batch_size"):
-            self._orig_batch_size = self.batch_size
-        elif self._orig_batch_size == new_size:
-            del self._orig_batch_size
         self._batch_size = new_size
 
     def get(
@@ -3541,7 +3578,7 @@ def contiguous(self) -> T:
             batch_size=self.batch_size,
             source={key: value.contiguous() for key, value in self.items()},
             device=self.device,
-            names=self.names,
+            names=self.names if self._has_names() else None,
         )
 
     def _select(
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 796d510c3..97acfc728 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
+import contextlib
 import functools
 from typing import Any, Callable, Sequence, TypeVar
@@ -275,12 +276,14 @@ def _cat(
     if out is None:
         out = {}
         for key in keys:
-            with _ErrorInteceptor(
-                key, "Attempted to concatenate tensors on different devices at key"
-            ):
-                out[key] = torch.cat(
-                    [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], dim
-                )
+            items = [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts]
+            if not torch.compiler.is_dynamo_compiling():
+                with _ErrorInteceptor(
+                    key, "Attempted to concatenate tensors on different devices at key"
+                ):
+                    out[key] = torch.cat(items, dim)
+            else:
+                out[key] = torch.cat(items, dim)
         if device is None:
             device = list_of_tensordicts[0].device
             for td in list_of_tensordicts[1:]:
@@ -306,7 +309,7 @@ def _cat(
         for key in keys:
             with _ErrorInteceptor(
                 key, "Attempted to concatenate tensors on different devices at key"
-            ):
+            ) if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
                 if isinstance(out, TensorDict):
                     torch.cat(
                         [td.get(key) for td in list_of_tensordicts],
@@ -403,11 +406,14 @@ def _stack(
 ) -> T:
     if not list_of_tensordicts:
         raise RuntimeError("list_of_tensordicts cannot be empty")
-
+    is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
     if all(is_non_tensor(td) for td in list_of_tensordicts):
         from tensordict.tensorclass import NonTensorData
 
         return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
+    elif is_tc:
+        tc_type = type(list_of_tensordicts[0])
+        list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
 
     batch_size = list_of_tensordicts[0].batch_size
     if dim < 0:
@@ -504,9 +510,12 @@ def _stack(
                 tensor = _tensordict._get_str(key, default=NO_DEFAULT)
                 if is_tensor is None:
                     tensor_cls = type(tensor)
-                    is_tensor = (
-                        not _is_tensor_collection(tensor_cls)
-                    ) or is_tensorclass(tensor_cls)
+                    # is_tensor = (
+                    #     not _is_tensor_collection(tensor_cls)
+                    # ) or is_tensorclass(tensor_cls)
+                    # TODO: make sense of this, dynamo cannot pass through stack (and it's unsafe)
+                    # only tensors should be tensors
+                    is_tensor = not _is_tensor_collection(tensor_cls)
                 if is_not_init is None:
                     is_not_init = isinstance(tensor, UninitializedTensorMixin)
                 if not is_not_init:
@@ -537,7 +546,7 @@ def stack_fn(key, values, is_not_init, is_tensor):
                     return torch.stack(values, dim)
                 with _ErrorInteceptor(
                     key, "Attempted to stack tensors on different devices at key"
-                ):
+                ) if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
                     return _stack(values, dim, maybe_dense_stack=maybe_dense_stack)
 
             out = {
                 key: stack_fn(key, values, is_not_init, is_tensor)
                 for key, (values, is_not_init, is_tensor) in out.items()
             }
 
-            return TensorDict._new_unsafe(
+            result = TensorDict._new_unsafe(
                 out,
                 batch_size=LazyStackedTensorDict._compute_batch_size(
                     batch_size, dim, len(list_of_tensordicts)
                 ),
                 device=device,
             )
+            if is_tc:
+                return tc_type.from_tensordict(result)
+            return result
         else:
             out = LazyStackedTensorDict(
                 *list_of_tensordicts,
diff --git a/tensordict/base.py b/tensordict/base.py
index 62e72b26c..eaed033a7 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -9,6 +9,7 @@
 import collections
 import concurrent.futures
 import contextlib
+import enum
 import gc
 import importlib
 import numbers
@@ -52,8 +53,10 @@
     _GENERIC_NESTED_ERR,
     _get_shape_from_args,
     _is_non_tensor,
+    _is_number,
     _is_tensorclass,
     _KEY_ERROR,
+    _lock_warn,
     _make_dtype_promotion,
     _parse_to,
     _prefix_last_key,
@@ -93,17 +96,11 @@
 # NO_DEFAULT is used as a placeholder whenever the default is not provided.
 # Using None is not an option since `td.get(key, default=None)` is a valid usage.
-class _NoDefault:
-    def __new__(cls):
-        if not hasattr(cls, "instance"):
-            cls.instance = super(_NoDefault, cls).__new__(cls)
-        return cls.instance
-
-    def __bool__(self):
-        return False
+class _NoDefault(enum.IntEnum):
+    ZERO = 0
 
-NO_DEFAULT = _NoDefault()
+NO_DEFAULT = _NoDefault.ZERO
 
 T = TypeVar("T", bound="TensorDictBase")
@@ -1382,12 +1379,18 @@ def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
                 key, value, inplace=True, validated=True, non_blocking=False
             )
         self._check_new_batch_size(new_batch_size)
-        self._change_batch_size(new_batch_size)
-        if self._has_names():
+        has_names = self._has_names()
+        if has_names:
             # if the tensordict has dim names and the new batch-size has more dims,
             # we can simply add empty names after the current ones.
             # Otherwise, we discard the extra existing names.
             names = self.names
+            self._erase_names()
+        self._change_batch_size(new_batch_size)
+        if has_names:
+            # if the tensordict has dim names and the new batch-size has more dims,
+            # we can simply add empty names after the current ones.
+            # Otherwise, we discard the extra existing names.
             if len(names) < len(new_batch_size):
                 self.names = names + [None] * (len(new_batch_size) - len(names))
             else:
@@ -2159,6 +2162,60 @@ def names(self):
         """
         ...
 
+    def _get_names_idx(self, idx):
+        if not self._has_names():
+            return None
+
+        def is_boolean(idx):
+            try:
+                from functorch import dim as ftdim
+
+            except ImportError:
+                from tensordict.utils import _ftdim_mock as ftdim
+
+            if isinstance(idx, ftdim.Dim):
+                return None
+            if isinstance(idx, tuple) and len(idx) == 1:
+                return is_boolean(idx[0])
+            if hasattr(idx, "dtype") and idx.dtype is torch.bool:
+                return idx.ndim
+            return None
+
+        num_boolean_dim = is_boolean(idx)
+        names = self.names
+        if num_boolean_dim:
+            names = [None] + names[num_boolean_dim:]
+        else:
+            if not isinstance(idx, tuple):
+                idx = (idx,)
+            if len([_idx for _idx in idx if _idx is not None]) < self.ndim:
+                idx = (*idx, Ellipsis)
+            idx_names = convert_ellipsis_to_idx(idx, self.batch_size)
+            # this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3]
+            count = 0
+            idx_to_take = []
+            no_more_tensors = False
+            for _idx in idx_names:
+                if _idx is None:
+                    idx_to_take.append(None)
+                elif _is_number(_idx):
+                    count += 1
+                elif isinstance(_idx, (torch.Tensor, np.ndarray)):
+                    if not no_more_tensors:
+                        idx_to_take.extend([count] * _idx.ndim)
+                        count += 1
+                        no_more_tensors = True
+                    else:
+                        # skip this one
+                        count += 1
+                else:
+                    idx_to_take.append(count)
+                    count += 1
+            names = [names[i] if i is not None else None for i in idx_to_take]
+        if all(name is None for name in names):
+            return None
+        return names
+
     @abc.abstractmethod
     def _erase_names(self):
         """Erases the dimension names from a tensordict."""
@@ -3854,7 +3911,7 @@ def _set_non_tensor(self, key: NestedKey, value: Any):
         self._set_str(
             key,
             NonTensorData(
-                value,
+                data=value,
                 batch_size=self.batch_size,
                 device=self.device,
                 names=self.names if self._has_names() else None,
@@ -4709,8 +4766,12 @@ def _values_list(
             is_leaf=is_leaf,
             collapse=collapse,
         )
-        source = dict(zip(keys, vals))
-        return [source[key] for key in sorting_keys]
+        if torch.compiler.is_dynamo_compiling():
+            key_to_index = {key: i for i, key in enumerate(keys)}
+            return [vals[key_to_index[key]] for key in sorting_keys]
+        else:
+            source = dict(zip(keys, vals))
+            return [source[key] for key in sorting_keys]
 
     @cache  # noqa: B019
     def _items_list(
@@ -4721,16 +4782,23 @@ def _items_list(
         collapse: bool = False,
         is_leaf: Callable[[Type], bool] | None = None,
     ) -> Tuple[List, List]:
-        return tuple(
-            tuple(key_or_val)
-            for key_or_val in zip(
-                *self.items(
-                    include_nested=include_nested,
-                    leaves_only=leaves_only,
-                    is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else is_leaf,
-                )
-            )
+        # return tuple(
+        #     tuple(key_or_val)
+        #     for key_or_val in zip(
+        #         *self.items(
+        #             include_nested=include_nested,
+        #             leaves_only=leaves_only,
+        #             is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else is_leaf,
+        #         )
+        #     )
+        # )
+        items = self.items(
+            include_nested=include_nested,
+            leaves_only=leaves_only,
+            is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else None,
         )
+        keys, vals = zip(*items)
+        return list(keys), list(vals)
 
     @cache  # noqa: B019
     def _grad(self):
@@ -4798,7 +4866,7 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
             self.del_(key)
         except KeyError as err:
             # if default provided, 'out' value will return, else raise error
-            if default == NO_DEFAULT:
+            if default is NO_DEFAULT:
                 raise KeyError(
                     f"You are trying to pop key `{key}` which is not in dict "
                     f"without providing default value."
@@ -4876,7 +4944,6 @@ def flatten(tensor):
         else:
             batch_size = [nelt] + list(self.batch_size[end_dim + 1 :])
         # TODO: check that this works with nested tds of different batch size
-        out = self._fast_apply(flatten, batch_size=batch_size, propagate_lock=True)
         if self._has_names():
             names = [
                 name
@@ -4884,7 +4951,11 @@ def flatten(tensor):
                 if (i < start_dim or i > end_dim)
             ]
             names.insert(start_dim, None)
-            out.names = names
+        else:
+            names = None
+        out = self._fast_apply(
+            flatten, batch_size=batch_size, propagate_lock=True, names=names
+        )
         return out
 
     @as_decorator()
@@ -5515,7 +5586,7 @@ def apply(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         default: Any = NO_DEFAULT,
         filter_empty: bool | None = None,
@@ -5669,7 +5740,7 @@ def named_apply(
         nested_keys: bool = False,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         default: Any = NO_DEFAULT,
         filter_empty: bool | None = None,
@@ -5866,7 +5937,7 @@ def _multithread_rebuild(
         *,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         out: TensorDictBase | None = None,
@@ -5885,7 +5956,7 @@ def _multithread_apply_nest(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
        inplace: bool = False,
         checked: bool = False,
         call_on_nested: bool = False,
@@ -5962,7 +6033,7 @@ def _apply_nest(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         call_on_nested: bool = False,
@@ -5983,7 +6054,7 @@ def _fast_apply(
         *others: T,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         call_on_nested: bool = False,
         default: Any = NO_DEFAULT,
@@ -8488,7 +8559,12 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
                 result.lock_()
             return result
         else:
-            for key in list(self.keys()):
+            if not torch.compiler.is_dynamo_compiling():
+                key_list = list(self.keys())
+            else:
+                key_list = [k for k in self.keys()]  # noqa
+
+            for key in key_list:
                 if separator in key:
                     new_key = tuple(key.split(separator))
                     try:
@@ -8522,22 +8598,27 @@ def is_locked(self, value: bool) -> None:
 
     def _propagate_lock(self, lock_parents_weakrefs=None):
         """Registers the parent tensordict that handles the lock."""
+        self._is_locked = True
         if self._is_locked and lock_parents_weakrefs is not None:
             lock_parents_weakrefs = [
                 ref
                 for ref in lock_parents_weakrefs
                 if not any(refref is ref for refref in self._lock_parents_weakrefs)
             ]
-        self._is_locked = True
-        is_root = lock_parents_weakrefs is None
-        if is_root:
-            lock_parents_weakrefs = []
+        is_compiling = torch.compiler.is_dynamo_compiling()
+        if not is_compiling:
+            is_root = lock_parents_weakrefs is None
+            if is_root:
+                lock_parents_weakrefs = []
+            else:
+                self._lock_parents_weakrefs = (
+                    self._lock_parents_weakrefs + lock_parents_weakrefs
+                )
+            lock_parents_weakrefs = list(lock_parents_weakrefs)
+            lock_parents_weakrefs.append(weakref.ref(self))
         else:
-            self._lock_parents_weakrefs = (
-                self._lock_parents_weakrefs + lock_parents_weakrefs
-            )
-            lock_parents_weakrefs = copy(lock_parents_weakrefs)
-            lock_parents_weakrefs.append(weakref.ref(self))
+            _lock_warn()
+
         for value in self.values():
             if _is_tensor_collection(type(value)):
                 value._propagate_lock(lock_parents_weakrefs)
@@ -8799,7 +8880,8 @@ def to(tensor):
 
     def _sync_all(self):
         if _has_cuda:
-            if torch.cuda.is_initialized():
+            # TODO: dynamo doesn't like torch.cuda.is_initialized
+            if not torch.compiler.is_dynamo_compiling() and torch.cuda.is_initialized():
                 torch.cuda.synchronize()
         elif _has_mps:
             torch.mps.synchronize()
@@ -8959,9 +9041,8 @@ def _register_tensor_class(cls):
 
 def _is_tensor_collection(datatype):
-    try:
-        out = _TENSOR_COLLECTION_MEMO[datatype]
-    except KeyError:
+    out = _TENSOR_COLLECTION_MEMO.get(datatype)
+    if out is None:
         if issubclass(datatype, TensorDictBase):
             out = True
         elif _is_tensorclass(datatype):
diff --git a/tensordict/functional.py b/tensordict/functional.py
index ea11d0423..d7fe797c6 100644
--- a/tensordict/functional.py
+++ b/tensordict/functional.py
@@ -334,8 +334,7 @@ def dense_stack_tds(
     shape = list(td_list[0].shape)
     shape.insert(dim, len(td_list))
 
-    out = td_list[0].unsqueeze(dim).expand(shape).clone()
-    return torch.stack(td_list, dim=dim, out=out)
+    return TensorDict.maybe_dense_stack(td_list, dim=dim)
 
 
 def make_tensordict(
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index a2a9fc2e9..40fd687fa 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -488,7 +488,7 @@ def apply(
         *others: TensorDictBase,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         default: Any = NO_DEFAULT,
         filter_empty: bool | None = None,
@@ -504,7 +504,7 @@ def named_apply(
         *others: TensorDictBase,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         default: Any = NO_DEFAULT,
         filter_empty: bool | None = None,
@@ -540,7 +540,7 @@ def _multithread_rebuild(
         *,
         batch_size: Sequence[int] | None = None,
         device: torch.device | None = NO_DEFAULT,
-        names: Sequence[str] | None = None,
+        names: Sequence[str] | None = NO_DEFAULT,
         inplace: bool = False,
         checked: bool = False,
         out: TensorDictBase | None = None,
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 0011f6424..c51701bf3 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -2172,7 +2172,7 @@ def __post_init__(self):
             data_inner = data.tolist()
             del _tensordict["data"]
             _non_tensordict["data"] = data_inner
-        assert _tensordict.is_empty(), self._tensordict
+        # assert _tensordict.is_empty(), self._tensordict
 
     def __repr__(self):
         data_str = str(self.data)
@@ -2507,13 +2507,13 @@ def clone(self, recurse: bool = True):
                 data=deepcopy(self.data),
                 batch_size=self.batch_size,
                 device=self.device,
-                names=self.names,
+                names=self.names if self._has_names() else None,
             )
         return type(self)(
             data=self.data,
             batch_size=self.batch_size,
             device=self.device,
-            names=self.names,
+            names=self.names if self._has_names() else None,
         )
 
     def share_memory_(self):
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 7383082f5..f447aab21 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -47,10 +47,10 @@
     _has_funcdim = False
 from packaging.version import parse
 from tensordict._C import (  # noqa: F401 # @manual=//tensordict:_C
-    _unravel_key_to_tuple,
-    unravel_key,
-    unravel_key_list,
-    unravel_keys,
+    _unravel_key_to_tuple as _unravel_key_to_tuple_cpp,
+    unravel_key as unravel_key_cpp,
+    unravel_key_list as unravel_key_list_cpp,
+    unravel_keys as unravel_keys_cpp,
 )
 
 from tensordict._contextlib import _DecoratorContextManager
@@ -1125,7 +1125,7 @@ def cache(fun):
 
     @wraps(fun)
     def newfun(_self: "TensorDictBase", *args, **kwargs):
-        if not _self.is_locked:
+        if not _self.is_locked or torch.compiler.is_compiling():
             return fun(_self, *args, **kwargs)
         cache = _self._cache
         if cache is None:
@@ -1162,31 +1162,26 @@ def new_fun(self, *args, **kwargs):
 
 class _StringKeys(KeysView):
-    """A key view where contains is restricted to strings."""
+    """A key view where contains is restricted to strings.
 
-    def __contains__(self, item):
-        if not isinstance(item, str):
-            try:
-                unravel_item = _unravel_key_to_tuple(item)
-                if not unravel_item:  # catch errors during unravel
-                    raise TypeError
-            except Exception:
-                raise TypeError(_NON_STR_KEY_ERR)
-            if len(unravel_item) > 1:
-                raise TypeError(_NON_STR_KEY_TUPLE_ERR)
-            else:
-                item = unravel_item[0]
-        return super().__contains__(item)
+
+    Saving the keys as an attribute is 25% faster than just subclassing KeysView.
+
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
 
+    def __getitem__(self, key):
+        return self.keys.__getitem__(key)
 
-class _StringOnlyDict(dict):
-    """A dict class where contains is restricted to strings."""
+    def __iter__(self):
+        yield from self.keys
+
+    def __repr__(self):
+        return f"{type(self)}({self.keys})"
 
-    # kept here for debugging
-    # def __setitem__(self, key, value):
-    #     if not isinstance(key, str):
-    #         raise RuntimeError
-    #     return super().__setitem__(key, value)
+    def __len__(self):
+        return len(self.keys)
 
     def __contains__(self, item):
         if not isinstance(item, str):
@@ -1200,10 +1195,10 @@ def __contains__(self, item):
                 raise TypeError(_NON_STR_KEY_TUPLE_ERR)
             else:
                 item = unravel_item[0]
-        return super().__contains__(item)
+        return self.keys.__contains__(item)
 
-    def keys(self):
-        return _StringKeys(self)
+
+_StringOnlyDict = dict
 
 
 def lock_blocked(func):
@@ -1218,7 +1213,46 @@ def new_func(self, *args, **kwargs):
     return new_func
 
 
-class as_decorator:
+# class as_decorator:
+#     """Converts a method to a decorator.
+#
+#     Examples:
+#         >>> from tensordict import TensorDict
+#         >>> data = TensorDict({}, [])
+#         >>> with data.lock_(): # lock_ is decorated
+#         ...     assert data.is_locked
+#         >>> assert not data.is_locked
+#     """
+#
+#     def __init__(self, attr=None):
+#         self.attr = attr
+#
+#     def __call__(self, func):
+#         if self.attr is not None:
+#
+#             @wraps(func)
+#             def new_func(_self, *args, **kwargs):
+#                 _attr_pre = getattr(_self, self.attr)
+#                 out = func(_self, *args, **kwargs)
+#                 _attr_post = getattr(_self, self.attr)
+#                 if out is not None:
+#                     if _attr_post is not _attr_pre:
+#                         out._last_op = (new_func.__name__, (args, kwargs, _self))
+#                     else:
+#                         out._last_op = None
+#                 return out
+#
+#         else:
+#
+#             @wraps(func)
+#             def new_func(_self, *args, **kwargs):
+#                 out = func(_self, *args, **kwargs)
+#                 if out is not None:
+#                     out._last_op = (new_func.__name__, (args, kwargs, _self))
+#                 return out
+#
+#         return new_func
+def as_decorator(attr=None):
     """Converts a method to a decorator.
 
     Examples:
        >>> from tensordict import TensorDict
        >>> data = TensorDict({}, [])
        >>> with data.lock_(): # lock_ is decorated
        ...     assert data.is_locked
        >>> assert not data.is_locked
    """
 
-    def __init__(self, attr=None):
-        self.attr = attr
-
-    def __call__(self, func):
-        if self.attr is not None:
+    def __call__(func):
+        if attr is not None:
 
             @wraps(func)
             def new_func(_self, *args, **kwargs):
-                _attr_pre = getattr(_self, self.attr)
+                _attr_pre = getattr(_self, attr)
                 out = func(_self, *args, **kwargs)
-                _attr_post = getattr(_self, self.attr)
+                _attr_post = getattr(_self, attr)
                 if out is not None:
                     if _attr_post is not _attr_pre:
                         out._last_op = (new_func.__name__, (args, kwargs, _self))
@@ -1258,6 +1289,8 @@ def new_func(_self, *args, **kwargs):
 
         return new_func
 
+    return __call__
+
 
 def _find_smallest_uint(N):
     if not hasattr(torch, "uint32"):
@@ -1370,9 +1403,27 @@ def _parse_to(*args, **kwargs):
     pin_memory = kwargs.pop("pin_memory", False)
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
-    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-        *args, **kwargs
-    )
+    if not torch.compiler.is_dynamo_compiling():
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+            *args, **kwargs
+        )
+    else:
+        device = None
+        dtype = None
+        non_blocking = kwargs.get("non_blocking", False)
+        convert_to_format = kwargs.get("convert_to_format", None)
+        if len(args) > 0:
+            device = torch.device(args[0])
+            if len(args) > 1:
+                dtype = args[1]
+            else:
+                dtype = kwargs.get("dtype", None)
+        else:
+            device = kwargs.get("device", None)
+            dtype = kwargs.get("dtype", None)
+        if device is not None:
+            device = torch.device(device)
+
     if other is not None:
         if device is not None and device != other.device:
             raise ValueError("other and device cannot be both passed")
@@ -1480,7 +1531,7 @@ def _get_leaf_tensordict(
     return tensordict, key[0]
 
 
-def assert_allclose_td(
+def assert_close(
     actual: T,
     expected: T,
     rtol: float | None = None,
@@ -1631,13 +1682,16 @@ def _check_keys(
     if not len(list_of_tensordicts):
         return set()
 
-    keys: set[str] = set(
-        list_of_tensordicts[0].keys(
-            include_nested=include_nested,
-            leaves_only=leaves_only,
-            is_leaf=_is_leaf_nontensor,
-        )
+    keys = list_of_tensordicts[0].keys(
+        include_nested=include_nested,
+        leaves_only=leaves_only,
+        is_leaf=_is_leaf_nontensor,
     )
+    # TODO: compile doesn't like set() over an arbitrary object
+    if torch.compiler.is_dynamo_compiling():
+        keys = {k for k in keys}  # noqa: C416
+    else:
+        keys: set[str] = set(keys)
     for td in list_of_tensordicts[1:]:
         k = td.keys(
             include_nested=include_nested,
            leaves_only=leaves_only,
            is_leaf=_is_leaf_nontensor,
        )
         if not strict:
             keys = keys.intersection(k)
         else:
-            if set(k) != keys:
+            if torch.compiler.is_dynamo_compiling():
+                k = {v for v in k}  # noqa: C416
+            else:
+                k = set(k)
+            if k != keys:
                 raise KeyError(
                     f"got keys {keys} and {set(td.keys())} which are incompatible"
                 )
@@ -1886,8 +1944,14 @@ def _getitem_batch_size(batch_size, index):
             boolean = False
             if isinstance(idx, (range, list)):
                 shape = len(idx)
-            elif isinstance(idx, (torch.Tensor, np.ndarray)):
-                if idx.dtype == torch.bool or idx.dtype == np.dtype("bool"):
+            elif isinstance(idx, torch.Tensor):
+                if idx.dtype == torch.bool:
+                    shape = torch.Size([idx.sum()])
+                    boolean = True
+                else:
+                    shape = idx.shape
+            elif isinstance(idx, np.ndarray):
+                if idx.dtype == np.dtype("bool"):
                     shape = torch.Size([idx.sum()])
                     boolean = True
                 else:
                     shape = idx.shape
@@ -1927,7 +1991,10 @@
             continue
         elif isinstance(idx, slice):
             batch = batch_size[count]
-            out.append(len(range(*idx.indices(batch))))
+            if torch.compiler.is_dynamo_compiling():
+                out.append(len(range(*_slice_indices(idx, batch))))
+            else:
+                out.append(len(range(*idx.indices(batch))))
             count += 1
     if batch_size[count:]:
         out.extend(batch_size[count:])
@@ -2388,6 +2455,86 @@ def new_func(self):
     return new_func
 
 
+def _unravel_key_to_tuple(key):
+    if not torch.compiler.is_dynamo_compiling():
+        return _unravel_key_to_tuple_cpp(key)
+    if isinstance(key, str):
+        return (key,)
+    if not isinstance(key, tuple):
+        return ()
+    return tuple(subk for k in key for subk in _unravel_key_to_tuple(k))
+
+
+def unravel_key(key):
+    """Unravel a nested key.
+
+    Examples:
+        >>> unravel_key("a")
+        "a"
+        >>> unravel_key(("a",))
+        "a"
+        >>> unravel_key((("a", ("b",))))
+        ("a", "b")
+
+    """
+    if not torch.compiler.is_dynamo_compiling():
+        return unravel_key_cpp(key)
+    if isinstance(key, str):
+        return key
+    if isinstance(key, tuple):
+        if len(key) == 1:
+            return unravel_key(key[0])
+        return tuple(unravel_key(_key) for _key in key)
+    raise ValueError("the key must be a str or a tuple of str")
+
+
+def unravel_keys(*keys):
+    """Unravels a sequence of keys."""
+    if not torch.compiler.is_dynamo_compiling():
+        return unravel_keys_cpp(*keys)
+    return tuple(unravel_key(key) for key in keys)
+
+
+def unravel_key_list(keys):
+    """Unravels a list of keys."""
+    if not torch.compiler.is_dynamo_compiling():
+        return unravel_key_list_cpp(keys)
+    return [unravel_key(key) for key in keys]
+
+
+def _slice_indices(index: slice, len: int):
+    """A pure python implementation of slice.indices(len) since torch.compile doesn't recognise it."""
+    step = index.step
+    if step is None:
+        step = 1
+    elif step == 0:
+        raise ValueError("Step cannot be zero.")
+
+    start = index.start
+    stop = index.stop
+    if start is None:
+        if step > 0:
+            start = 0
+        else:
+            start = len - 1
+    elif start < 0:
+        start = max(0, len + start)
+
+    if stop is None:
+        if step > 0:
+            stop = len
+        else:
+            stop = -1
+    elif stop > 0:
+        stop = min(len, stop)
+    elif step < 0 or (step > 0 and start >= 0):
+        stop = len + stop
+    return start, stop, step
+
+
+assert_allclose_td = assert_close
+
+
 def _prefix_last_key(key, prefix):
     if isinstance(key, str):
         return prefix + key
@@ -2403,3 +2550,16 @@
 )
 
 _DEVICE2STRDEVICE = KeyDependentDefaultDict(lambda key: str(key))
+
+
+def _lock_warn():
+    warnings.warn(
+        "Using lock_() in a compiled graph should "
+        "only be done if users make sure that the code runs in eager mode. "
+        "torch.compile doesn't support weakrefs which are used to reference root tensordicts "
+        "to sub-tensordict and prevent unlocking a node when the graph is locked. "
+        "Such operation will fail in eager mode but won't be captured by torch.compile."
+    )
+
+
+_lock_warn = torch.compiler.assume_constant_result(_lock_warn)
diff --git a/test/test_compile.py b/test/test_compile.py
new file mode 100644
index 000000000..a513085e6
--- /dev/null
+++ b/test/test_compile.py
@@ -0,0 +1,297 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
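+
+# Each test below compiles a small TensorDict function with fullgraph=True and
+# checks that the compiled output matches eager execution (indexing, stack/cat,
+# reshape/view/transpose, key iteration, cloning, key (un)flattening, device
+# casts and locking).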
+import argparse
+import contextlib
+
+import pytest
+
+import torch
+
+from tensordict import assert_close, TensorDict
+
+TORCH_VERSION = torch.__version__
+
+
+@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
+@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
+class TestTD:
+    def test_tensor_output(self, mode):
+        def add_one(td):
+            return td["a", "b"] + 1
+
+        add_one_c = torch.compile(add_one, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": 0}})
+        assert add_one(data) == 1
+        assert add_one_c(data) == 1
+        assert add_one_c(data + 1) == 2
+
+    def test_td_output(self, mode):
+        def add_one(td):
+            td["a", "c"] = td["a", "b"] + 1
+            return td
+
+        add_one_c = torch.compile(add_one, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": 0}})
+        assert add_one(data.clone())["a", "c"] == 1
+        assert add_one_c(data.clone())["a", "c"] == 1
+        assert add_one_c(data) is data
+
+    @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"])
+    def test_td_index(self, index_type, mode):
+        if index_type == "slice":
+
+            def add_one(td):
+                return td[:2] + 1
+
+        elif index_type == "tensor":
+
+            def add_one(td):
+                return td[torch.tensor([0, 1])] + 1
+
+        elif index_type == "int":
+
+            def add_one(td):
+                return td[0] + 1
+
+        add_one_c = torch.compile(add_one, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(3)}}, [3])
+        if index_type == "int":
+            assert (add_one(data)["a", "b"] == 1).all()
+            assert (add_one_c(data)["a", "b"] == 1).all()
+            assert add_one_c(data).shape == torch.Size([])
+        else:
+            assert (add_one(data)["a", "b"] == torch.arange(1, 3)).all()
+            assert (add_one_c(data)["a", "b"] == torch.arange(1, 3)).all()
+            assert add_one_c(data).shape == torch.Size([2])
+
+    def test_stack(self, mode):
+        def stack_tds(td0, td1):
+            # return TensorDict.stack([td0, td1])
+            return torch.stack([td0, td1])
+
+        stack_tds_c = torch.compile(stack_tds, fullgraph=True, mode=mode)
+        data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3])
+        data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3])
+        assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all()
+
+    def test_cat(self, mode):
+        def cat_tds(td0, td1):
+            # return TensorDict.cat([td0, td1])
+            return torch.cat([td0, td1])
+
+        cat_tds_c = torch.compile(cat_tds, fullgraph=True, mode=mode)
+        data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3])
+        data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3])
+        assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all()
+
+    def test_reshape(self, mode):
+        def reshape(td):
+            return td.reshape(2, 2)
+
+        reshape_c = torch.compile(reshape, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
+        data_reshape = reshape(data)
+        _ = reshape_c(data)
+        data_reshape_c = reshape_c(data)
+        assert (data_reshape == data_reshape_c).all()
+
+    def test_view(self, mode):
+        def view(td):
+            return td.view(2, 2)
+
+        view_c = torch.compile(view, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
+        data_view = view(data)
+        _ = view_c(data)
+        data_view_c = view_c(data)
+        assert (data_view == data_view_c).all()
+
+    def test_transpose(self, mode):
+        def transpose(td):
+            return td.transpose(0, 1)
+
+        transpose_c = torch.compile(transpose, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(6).view(2, 3)}}, [2, 3])
+        data_transpose = transpose(data)
+        _ = transpose_c(data)
+        data_transpose_c = transpose_c(data)
+        assert (data_transpose == data_transpose_c).all()
+
+    def test_unbind(self, mode):
+        def unbind(td):
+            return td.unbind(0)
+
+        unbind_c = torch.compile(unbind, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
+        assert (unbind(data)[-1] == unbind_c(data)[-1]).all()
+
+    def test_items(self, mode):
+        def items(td):
+            keys, vals = zip(*td.items(True, True))
+            return keys, vals
+
+        items_c = torch.compile(items, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
+        keys, vals = items(data)
+        keys_c, vals_c = items_c(data)
+
+        def assert_eq(x, y):
+            assert (x == y).all()
+
+        assert keys == keys_c
+        torch.utils._pytree.tree_map(assert_eq, vals, vals_c)
+
+    @pytest.mark.parametrize("recurse", [True, False])
+    @pytest.mark.parametrize("lock", [True, False])
+    def test_clone(self, recurse, lock, mode):
+        def clone(td: TensorDict):
+            return td.clone(recurse=recurse)
+
+        clone_c = torch.compile(clone, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": 0, "c": 1}})
+        if lock:
+            data = data.lock_()
+        data_c = clone(data)
+        _ = clone_c(data)
+        data_c_c = clone_c(data)
+        assert_close(data_c, data_c_c)
+        assert clone_c(data) is not data
+        if recurse:
+            assert clone_c(data)["a", "b"] is not data["a", "b"]
+        else:
+            assert clone_c(data)["a", "b"] is data["a", "b"]
+
+    @pytest.mark.parametrize("recurse", [True, False])
+    def test_flatten_keys(self, recurse, mode):
+        def flatten_keys(td: TensorDict):
+            return td.flatten_keys()
+
+        flatten_keys_c = torch.compile(flatten_keys, fullgraph=True, mode=mode)
+        data = TensorDict({"a": {"b": 0, "c": 1}})
+        data_f = flatten_keys(data)
+        _ = flatten_keys_c(data)
+        data_f_c = flatten_keys_c(data)
+        assert_close(data_f, data_f_c)
+        assert flatten_keys_c(data) is not data
+        assert flatten_keys_c(data)["a.b"] is data["a", "b"]
+
+    @pytest.mark.parametrize("recurse", [True, False])
+    def test_unflatten_keys(self, recurse, mode):
+        def unflatten_keys(td: TensorDict):
+            return td.unflatten_keys()
+
+        unflatten_keys_c = torch.compile(unflatten_keys, fullgraph=True, mode=mode)
+        data = TensorDict({"a.b": 0, "a.c": 1})
+        data_t = unflatten_keys(data)
+        _ = unflatten_keys_c(data)
+        data_t_c = unflatten_keys_c(data)
+        assert_close(data_t, data_t_c)
+        assert unflatten_keys_c(data) is not data
+        assert unflatten_keys_c(data)["a", "b"] is data["a.b"]
+
+    def test_names(self, mode):
+        import torch._dynamo.exc
+
+        def make_td_with_names(data):
+            return TensorDict(data, batch_size=[1, 2], names=["d0", "d1"])
+
+        data_dict = {
+            "a": torch.randn(1, 2, 3),
+            "b": torch.zeros(1, 2, 3, dtype=torch.bool),
+        }
+        make_td_with_names_c = torch.compile(
+            make_td_with_names, fullgraph=True, mode=mode
+        )
+        make_td_with_names(data_dict)
+        with pytest.raises(torch._dynamo.exc.Unsupported):
+            make_td_with_names_c(data_dict)
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="cuda required to test device casting"
+    )
+    @pytest.mark.parametrize("has_device", [True, False])
+    def test_to(self, has_device, mode):
+        device = "cuda:0"
+
+        def test_to_device(td):
+            return td.to(device)
+
+        td = TensorDict(
+            {"a": torch.randn(1, 2, 3), "b": torch.zeros(1, 2, 3, dtype=torch.bool)},
+            batch_size=[1, 2],
+            device="cpu" if has_device else None,
+        )
+        test_to_device_c = torch.compile(test_to_device, fullgraph=True, mode=mode)
+        # td_device = test_to_device(td)
+        _ = test_to_device_c(td)
+        td_device_c = test_to_device_c(td)
+        assert td_device_c.batch_size == td.batch_size
+        assert td_device_c.device == torch.device(device)
+
+    def test_lock(self, mode):
+        def locked_op(td):
+            # Adding stuff uses cache, check that this doesn't break
+            td2 = td + 1
+            td3 = td + td2
+            return td3
+
+        td = TensorDict(
+            {"a": torch.randn(1, 2, 3), "b": torch.zeros(1, 2, 3, dtype=torch.bool)},
+            batch_size=[1, 2],
+            device="cpu",
+            lock=True,
+        )
+        locked_op_c = torch.compile(locked_op, fullgraph=True, mode=mode)
+        td_op = locked_op(td)
+        # no warning the second time this is run
+        with pytest.warns(
+            UserWarning, match="Using lock_"
+        ) if mode is None else contextlib.nullcontext():
+            _ = locked_op_c(td)
+        td_op_c = locked_op_c(td)
+        assert (td_op == td_op_c).all()
+
+    def test_lock_inplace(self, mode):
+        def locked_op(td):
+            # Adding stuff uses cache, check that this doesn't break
+            td += 1
+            td += td
+            return td
+
+        td = TensorDict(
+            {"a": torch.randn(1, 2, 3), "b": torch.ones(1, 2, 3, dtype=torch.int64)},
+            batch_size=[1, 2],
+            device="cpu",
+            lock=True,
+        )
+        locked_op_c = torch.compile(locked_op, fullgraph=True, mode=mode)
+        td_op = locked_op(td)
+        # no warning the second time this is run
+        _ = locked_op_c(td)
+        td_op_c = locked_op_c(td)
+        assert (td_op == td_op_c).all()
+
+    # Memmap is currently not supported
+    # def test_memmap(self, mode, tmpdir):
+    #     def locked_op(td):
+    #         # Adding stuff uses cache, check that this doesn't break
+    #         return td.apply(lambda x: x+1)
+    #
+    #     td = TensorDict(
+    #         {"a": torch.randn(1, 2, 3), "b": torch.ones(1, 2, 3, dtype=torch.int64)},
+    #         batch_size=[1, 2],
+    #         device="cpu",
+    #     ).memmap_(tmpdir)
+    #     locked_op_c = torch.compile(locked_op, fullgraph=True, mode=mode)
+    #     td_op = locked_op(td)
+    #     # no warning the second time this is run
+    #     _ = locked_op_c(td)
+    #     td_op_c = locked_op_c(td)
+    #     assert (td_op == td_op_c).all()
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index be6490646..f933bac0d 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -9435,7 +9435,11 @@ def test_memmap_stack(self, tmpdir, json_serializable, device):
 
     def test_memmap_stack_updates(self, tmpdir):
         data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0)
-        data = torch.stack([data] * 3).clone()
+        assert is_non_tensor(data)
+        data = torch.stack([data] * 3)
+        assert is_non_tensor(data)
+        data = data.clone()
+        assert is_non_tensor(data)
         data.memmap_(tmpdir)
         data_recon = TensorDict.load_memmap(tmpdir)
         assert data.tolist() == data_recon.tolist()
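
A minimal usage sketch of what this patch enables (illustrative only; it mirrors
the patterns in test/test_compile.py above rather than introducing any new API):

    import torch
    from tensordict import TensorDict

    def add_one(td):
        return td["a", "b"] + 1

    add_one_c = torch.compile(add_one, fullgraph=True)
    data = TensorDict({"a": {"b": torch.zeros(())}})
    assert add_one_c(data) == 1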