From 6b189b157c5f009068db73f5fca333f76c2436ca Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 31 Jul 2023 18:35:45 +0100 Subject: [PATCH 1/6] init --- tensordict/tensordict.py | 136 ++++++++++- tensordict/tensorstack.py | 475 ++++++++++++++++++++++++++++++++++++++ test/test_tensorstack.py | 199 ++++++++++++++++ 3 files changed, 800 insertions(+), 10 deletions(-) create mode 100644 tensordict/tensorstack.py create mode 100644 test/test_tensorstack.py diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index f0eb07bf4..abf3cf27f 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -9,6 +9,7 @@ import collections import functools import numbers +import operator import os import re import textwrap @@ -2055,6 +2056,93 @@ def __ne__(self, other: object) -> TensorDictBase: # def __hash__(self): # ... + def __and__(self, other): + """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. + + Returns: + a new TensorDict instance with all tensors are boolean + tensors of the same shape as the original tensors. + + """ + if is_tensorclass(other): + return other & self + if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__): + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + d = {} + for key, item1 in self.items(): + d[key] = item1 & other.get(key) + return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, (numbers.Number, Tensor)): + return TensorDict( + {key: value & other for key, value in self.items()}, + self.batch_size, + device=self.device, + ) + raise NotImplementedError( + f"Cannot compare objects of type {type(self)} and {type(other)}" + ) + + def __or__(self, other): + """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. + + Returns: + a new TensorDict instance with all tensors are boolean + tensors of the same shape as the original tensors. + + """ + if is_tensorclass(other): + return other | self + if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__): + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + d = {} + for key, item1 in self.items(): + d[key] = item1 | other.get(key) + return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, (numbers.Number, Tensor)): + return TensorDict( + {key: value | other for key, value in self.items()}, + self.batch_size, + device=self.device, + ) + raise NotImplementedError( + f"Cannot compare objects of type {type(self)} and {type(other)}" + ) + + def __xor__(self, other): + """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. + + Returns: + a new TensorDict instance with all tensors are boolean + tensors of the same shape as the original tensors. 
+ + """ + if is_tensorclass(other): + return other ^ self + if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__): + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + d = {} + for key, item1 in self.items(): + d[key] = item1 ^ other.get(key) + return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, (numbers.Number, Tensor)): + return TensorDict( + {key: value ^ other for key, value in self.items()}, + self.batch_size, + device=self.device, + ) + raise NotImplementedError( + f"Cannot compare objects of type {type(self)} and {type(other)}" + ) + def __eq__(self, other: object) -> TensorDictBase: """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. @@ -7075,14 +7163,28 @@ def all(self, dim: int = None) -> bool | TensorDictBase: "smaller than tensordict.batch_dims" ) if dim is not None: - # TODO: we need to adapt this to LazyStackedTensorDict too if dim < 0: dim = self.batch_dims + dim - return TensorDict( - source={key: value.all(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, + if dim > self.stack_dim: + dim = dim - 1 + new_stack_dim = self.stack_dim + elif dim == self.stack_dim: + if len(self.tensordicts) == 1: + return self.tensordicts[0].apply(lambda x: x.bool()) + + val = functools.reduce( + operator.and_, + [td.apply(lambda x: x.bool()) for td in self.tensordicts], + ) + return val + else: + new_stack_dim = self.stack_dim - 1 + + out = LazyStackedTensorDict( + *[td.all(dim) for td in self.tensordicts], stack_dim=new_stack_dim ) + out._td_dim_name = self._td_dim_name + return out return all(value.all() for value in self.tensordicts) def any(self, dim: int = None) -> bool | TensorDictBase: @@ -7092,14 +7194,28 @@ def any(self, dim: int = None) -> bool | TensorDictBase: "smaller than tensordict.batch_dims" ) if dim is not None: - # TODO: we need to adapt this to LazyStackedTensorDict too if dim < 0: dim = self.batch_dims + dim - return TensorDict( - source={key: value.any(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, + if dim > self.stack_dim: + dim = dim - 1 + new_stack_dim = self.stack_dim + elif dim == self.stack_dim: + if len(self.tensordicts) == 1: + return self.tensordicts[0].apply(lambda x: x.bool()) + + val = functools.reduce( + operator.or_, + [td.apply(lambda x: x.bool()) for td in self.tensordicts], + ) + return val + else: + new_stack_dim = self.stack_dim - 1 + + out = LazyStackedTensorDict( + *[td.any(dim) for td in self.tensordicts], stack_dim=new_stack_dim ) + out._td_dim_name = self._td_dim_name + return out return any(value.any() for value in self.tensordicts) def _send(self, dst: int, _tag: int = -1, pseudo_rand: bool = False) -> int: diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py new file mode 100644 index 000000000..5ef112bba --- /dev/null +++ b/tensordict/tensorstack.py @@ -0,0 +1,475 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import functools +import operator +from copy import copy +from typing import Sequence + +import numpy as np +import torch +from tensordict.tensordict import _broadcast_tensors, _is_number + +from tensordict.utils import convert_ellipsis_to_idx +from torch import Tensor + + +def _get_shape( + tensor_data, +): + shape = list(tensor_data[0].shape) + for t in tensor_data[1:]: + tshape = t.shape + for i, (s1, s2) in enumerate(list(zip(shape, tshape))): + shape[i] = s1 if s1 == s2 else -1 + return shape + + +def _get_shape_nested( + tensor_data, +): + out = [] + for i in range(tensor_data.ndim): + try: + s = tensor_data.size(i) + except Exception: + s = -1 + out.append(s) + shape = torch.Size(out) + return shape + + +def _elementiwse_broadcast(func): + func_name = func.__name__ + + def new_func(self, other): + if self._nested: + return type(self)( + getattr(torch.Tensor, func_name)(self.tensors, other), + stack_dim=self.stack_dim, + ) + if isinstance(other, (torch.Tensor,)): + shape = torch.broadcast_shapes(other.shape, self._shape_no0) + if shape != other.shape: + other = other.expand(shape) + if shape != self._shape_no0: + self_expand = self.expand(shape) + else: + self_expand = self + other = other.unbind(self_expand.stack_dim) + elif isinstance(other, (LazyStackedTensors,)): + shape = torch.broadcast_shapes(other._shape_no0, self._shape_no0) + if shape != other._shape_no0: + other = other.expand(shape) + if shape != self._shape_no0: + self_expand = self.expand(shape) + else: + self_expand = self + other = other.unbind(self_expand.stack_dim) + else: + self_expand = self + other = (other,) * self.n + return type(self)( + [ + getattr(torch.Tensor, func_name)(t, _other) + for t, _other in zip(self_expand.tensors, other) + ], + self.stack_dim, + ) + + return new_func + + +class LazyStackedTensors: + def __init__(self, tensors, stack_dim=0): + self.tensors = tensors + self.stack_dim = stack_dim + self._nested = isinstance(tensors, torch.Tensor) and tensors.is_nested + self._shape = self._get_shape() + + @property + def shape(self): + return self._shape + + @property + def _shape_no0(self): + return torch.Size([s if s >= 0 else 1 for s in self._shape]) + + def _get_shape(self): + tensors = self.tensors + if self._nested: + shape = _get_shape_nested(tensors) + if self.stack_dim < 0: + self.stack_dim = len(shape) + self.stack_dim + if self.stack_dim > len(shape) or self.stack_dim < 0: + raise RuntimeError + if self.stack_dim != 0: + n, *shape = list(shape) + shape.insert(self.stack_dim, n) + else: + shape = _get_shape(tensors) + if self.stack_dim < 0: + self.stack_dim = len(shape) + self.stack_dim + 1 + if self.stack_dim > len(shape) or self.stack_dim < 0: + raise RuntimeError + shape.insert(self.stack_dim, len(tensors)) + return torch.Size(shape) + + def get_nestedtensor(self): + return torch.nested.nested_tensor(list(self.tensors)) + + def as_nestedtensor(self): + if self._nested: + return self + return type(self)(self.get_nestedtensor(), stack_dim=self.stack_dim) + + @classmethod + def from_nested_tensor(cls, nt, stack_dim=0): + return cls(nt, stack_dim=stack_dim) + + def __getitem__(self, index): + split_index = self._split_index(index) + converted_idx = split_index["index_dict"] + isinteger = split_index["isinteger"] + has_bool = split_index["has_bool"] + is_nd_tensor = split_index["is_nd_tensor"] + num_single = split_index.get("num_single", 0) + num_none = split_index.get("num_none", 0) + num_squash = split_index.get("num_squash", 0) + if has_bool: + mask_unbind = 
split_index["individual_masks"] + cat_dim = split_index["mask_loc"] - num_single + out = [] + if mask_unbind[0].ndim == 0: + # we can return a stack + for (i, _idx), mask in zip(converted_idx.items(), mask_unbind): + if mask.any(): + if mask.all() and self.tensors[i].ndim == 0: + out.append(self.tensors[i]) + else: + out.append(self.tensors[i][_idx]) + out[-1] = out[-1].squeeze(cat_dim) + return LazyStackedTensors(out, cat_dim) + else: + for (i, _idx) in converted_idx.items(): + self_idx = (slice(None),) * split_index["mask_loc"] + (i,) + out.append(self[self_idx][_idx]) + return torch.cat(out, cat_dim) + elif is_nd_tensor: + new_stack_dim = self.stack_dim - num_single + num_none + return LazyStackedTensors( + [self[idx] for idx in converted_idx.values()], new_stack_dim + ) + else: + if isinteger: + for ( + i, + _idx, + ) in ( + converted_idx.items() + ): # for convenience but there's only one element + out = self.tensors[i] + if _idx is not None and _idx != (): + out = out[_idx] + return out + else: + out = [] + new_stack_dim = self.stack_dim - num_single + num_none - num_squash + for (i, _idx) in converted_idx.items(): + out.append(self.tensors[i][_idx]) + out = LazyStackedTensors(out, new_stack_dim) + return out + + def _split_index(self, index): + """Given a tuple index, split it in as many indices as the number of tensordicts. + + Returns: + a dictionary with {index-of-td: index-within-td} + the number of single dim indices until stack dim + a boolean indicating if the index along the stack dim is an integer + """ + if not isinstance(index, tuple): + index = (index,) + index = convert_ellipsis_to_idx(index, self.shape) + index = _broadcast_tensors(index) + out = [] + num_single = 0 + num_none = 0 + isinteger = False + is_nd_tensor = False + cursor = 0 # the dimension cursor + selected_td_idx = range(self.n) + has_bool = False + num_squash = 0 + for i, idx in enumerate(index): # noqa: B007 + cursor_incr = 1 + if idx is None: + out.append(None) + num_none += cursor <= self.stack_dim + continue + if cursor == self.stack_dim: + # we need to check which tds need to be indexed + if isinstance(idx, slice) or _is_number(idx): + selected_td_idx = range(self.n)[idx] + if not isinstance(selected_td_idx, range): + isinteger = True + selected_td_idx = [selected_td_idx] + elif isinstance(idx, (list, range)): + selected_td_idx = idx + elif isinstance(idx, (torch.Tensor, np.ndarray)): + if idx.dtype in (np.dtype("bool"), torch.bool): + # we mark that we need to dispatch the indices across stack idx + has_bool = True + # split mask along dim + individual_masks = idx = idx.unbind(0) + selected_td_idx = range(self.n) + out.append(idx) + split_dim = self.stack_dim - num_single + mask_loc = i + else: + if isinstance(idx, np.ndarray): + idx = torch.tensor(idx) + is_nd_tensor = True + selected_td_idx = range(len(idx)) + out.append(idx.unbind(0)) + else: + raise TypeError(f"Invalid index type: {type(idx)}.") + else: + if _is_number(idx) and cursor < self.stack_dim: + num_single += 1 + if isinstance( + idx, + ( + int, + slice, + list, + range, + ), + ): + out.append(idx) + elif isinstance(idx, (np.ndarray, torch.Tensor)): + if idx.dtype in (np.dtype("bool"), torch.bool): + cursor_incr = idx.ndim + if cursor < self.stack_dim: + num_squash += cursor_incr - 1 + if ( + cursor < self.stack_dim + and cursor + cursor_incr > self.stack_dim + ): + # we mark that we need to dispatch the indices across stack idx + has_bool = True + # split mask along dim + # relative_stack_dim = self.stack_dim - cursor - cursor_incr + 
individual_masks = idx = idx.unbind(0) + selected_td_idx = range(self.shape[i]) + split_dim = cursor - num_single + mask_loc = i + out.append(idx) + else: + raise TypeError(f"Invalid index type: {type(idx)}.") + cursor += cursor_incr + if has_bool: + out = tuple( + tuple(idx if not isinstance(idx, tuple) else idx[i] for idx in out) + for i in selected_td_idx + ) + return { + "index_dict": {i: out[i] for i in selected_td_idx}, + "num_single": num_single, + "isinteger": isinteger, + "has_bool": has_bool, + "individual_masks": individual_masks, + "split_dim": split_dim, + "mask_loc": mask_loc, + "is_nd_tensor": is_nd_tensor, + "num_none": num_none, + "num_squash": num_squash, + } + elif is_nd_tensor: + + def isindexable(idx): + if isinstance(idx, (torch.Tensor, np.ndarray)): + if idx.dtype in (torch.bool, np.dtype("bool")): + return False + return True + if isinstance(idx, (tuple, list, range)): + return True + return False + + out = tuple( + tuple(idx if not isindexable(idx) else idx[i] for idx in out) + for i in selected_td_idx + ) + return { + "index_dict": dict(enumerate(out)), + "num_single": num_single, + "isinteger": isinteger, + "has_bool": has_bool, + "is_nd_tensor": is_nd_tensor, + "num_none": num_none, + "num_squash": num_squash, + } + return { + "index_dict": {i: tuple(out) for i in selected_td_idx}, + "num_single": num_single, + "isinteger": isinteger, + "has_bool": has_bool, + "is_nd_tensor": is_nd_tensor, + "num_none": num_none, + "num_squash": num_squash, + } + + @_elementiwse_broadcast + def __add__(self, other): + ... + + @_elementiwse_broadcast + def __sub__(self, other): + ... + + @_elementiwse_broadcast + def __truediv__(self, other): + ... + + @_elementiwse_broadcast + def __div__(self, other): + ... + + @_elementiwse_broadcast + def __mul__(self, other): + ... + + @_elementiwse_broadcast + def __eq__(self, other): + ... + + @_elementiwse_broadcast + def __ne__(self, other): + ... + + @property + def n(self): + return self.shape[self.stack_dim] + + def __len__(self): + return self._shape[0] + + @property + def ndim(self): + return len(self.shape) + + def ndimension(self): + return self.ndim + + def expand(self, *shape: int): + dims = self.ndim + + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = tuple(shape[0]) + + # new shape dim check + if len(shape) < len(self.shape): + raise RuntimeError( + "the number of sizes provided ({shape_dim}) must be greater or equal to the number of " + "dimensions in the tensor ({t_dim})".format( + shape_dim=len(shape), t_dim=dims + ) + ) + + # new shape compatability check + for old_dim, new_dim in zip(self.shape, shape[-dims:]): + if old_dim not in (1, -1) and new_dim != old_dim: + raise RuntimeError( + "Incompatible expanded shape: The expanded shape length at non-singleton dimension should be same " + "as the original length. target_shape = {new_shape}, existing_shape = {old_shape}".format( + new_shape=shape, old_shape=self.shape + ) + ) + + stack_dim = len(shape) + self.stack_dim - self.ndimension() + new_shape_t = [v for i, v in enumerate(shape) if i != stack_dim] + tensors = [ + t.expand(*torch.broadcast_shapes(new_shape_t, t.shape)) + for t in self.tensors + ] + return type(self)(tensors, stack_dim=stack_dim) + + def unbind(self, dim: int): + if dim < 0: + dim = self.shape + dim + if dim < 0 or dim >= self.ndim: + raise ValueError( + f"Cannot unbind along dimension {dim} with shape {self.shape}." 
+ ) + if dim == self.stack_dim: + if self._nested: + return self.tensors.unbind(0) + return tuple(self.tensors) + else: + # return a stack of unbound tensordicts + out = [] + new_dim = dim if dim < self.stack_dim else dim - 1 + new_stack_dim = ( + self.stack_dim if dim > self.stack_dim else self.stack_dim - 1 + ) + for t in self.tensors: + out.append(t.unbind(new_dim)) + return tuple(LazyStackedTensors(vals, new_stack_dim) for vals in zip(*out)) + + def all(self, dim: int = None): + if dim is not None and (dim >= self.ndim or dim < -self.ndim): + raise RuntimeError( + "dim must be greater than or equal to -tensordict.batch_dims and " + "smaller than tensordict.batch_dims" + ) + if dim is not None: + if dim < 0: + dim = self.ndim + dim + if dim > self.stack_dim: + dim = dim - 1 + new_stack_dim = self.stack_dim + elif dim == self.stack_dim: + if len(self.tensors) == 1: + return self.tensors[0].bool() + + val = functools.reduce(operator.and_, [t.bool() for t in self.tensors]) + return val + else: + new_stack_dim = self.stack_dim - 1 + + out = LazyStackedTensors( + [t.all(dim) for t in self.tensors], stack_dim=new_stack_dim + ) + return out + return all(value.all() for value in self.tensors) + + def any(self, dim: int = None): + if dim is not None and (dim >= self.ndim or dim < -self.ndim): + raise RuntimeError( + "dim must be greater than or equal to -tensordict.batch_dims and " + "smaller than tensordict.batch_dims" + ) + if dim is not None: + if dim < 0: + dim = self.ndim + dim + if dim > self.stack_dim: + dim = dim - 1 + new_stack_dim = self.stack_dim + elif dim == self.stack_dim: + if len(self.tensors) == 1: + return self.tensors[0].bool() + + val = functools.reduce(operator.or_, [t.bool() for t in self.tensors]) + return val + else: + new_stack_dim = self.stack_dim - 1 + + out = LazyStackedTensors( + [t.any(dim) for t in self.tensors], stack_dim=new_stack_dim + ) + return out + return any(value.any() for value in self.tensors) + + def __repr__(self): + return f"{self.__class__.__name__}({self.get_nestedtensor()})" diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py new file mode 100644 index 000000000..c11344b69 --- /dev/null +++ b/test/test_tensorstack.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse + +import pytest +import torch + +from tensordict.tensorstack import LazyStackedTensors as TensorStack + + +def _tensorstack(stack_dim, nt, init="randint"): + torch.manual_seed(0) + if init == "randint": + x = torch.randint(10, (3, 1, 5)) + y = torch.randint(10, (3, 2, 5)) + z = torch.randint(10, (3, 3, 5)) + elif init == "zeros": + x = torch.zeros((3, 1, 5)) + y = torch.zeros((3, 2, 5)) + z = torch.zeros((3, 3, 5)) + elif init == "ones": + x = torch.ones((3, 1, 5)) + y = torch.ones((3, 2, 5)) + z = torch.ones((3, 3, 5)) + if not nt: + t = TensorStack([x, y, z], stack_dim=stack_dim) + else: + t = TensorStack(torch.nested.nested_tensor([x, y, z]), stack_dim=stack_dim) + return t, (x, y, z) + + +class TestTensorStack: + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -1, -2, -3, -4]) + @pytest.mark.parametrize("nt", [True, False]) + def test_indexing_int(self, stack_dim, nt): + t, (x, y, z) = _tensorstack(stack_dim, nt) + sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + init_slice = (slice(None),) * sd + assert (t[init_slice + (0,)] == x).all() + assert (t[init_slice + (1,)] == y).all() + assert (t[init_slice + (2,)] == z).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -1, -2, -3, -4]) + @pytest.mark.parametrize("nt", [True, False]) + def test_all(self, stack_dim, nt): + t, (x, y, z) = _tensorstack(stack_dim, nt, "zeros") + # sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + # init_slice = (slice(None),) * sd + assert not t.all() + assert not t.any() + t, (x, y, z) = _tensorstack(stack_dim, nt, "ones") + # sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + # init_slice = (slice(None),) * sd + assert t.all() + assert t.any() + + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -1, -2, -3, -4]) + def test_indexing_slice(self, stack_dim, nt): + t, (x, y, z) = _tensorstack(stack_dim, nt) + sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + init_slice = (slice(None),) * sd + assert (t[init_slice + (slice(1),)][init_slice + (0,)] == x).all(), ( + t[init_slice + (slice(3),)][0], + x, + ) + assert (t[init_slice + (slice(2),)][init_slice + (1,)] == y).all() + assert (t[init_slice + (slice(3),)][init_slice + (2,)] == z).all() + assert (t[init_slice + (slice(-3, None),)][init_slice + (0,)] == x).all() + assert (t[init_slice + (slice(-2, None),)][init_slice + (0,)] == y).all() + assert (t[init_slice + (slice(-1, None),)][init_slice + (0,)] == z).all() + + assert ( + TensorStack([x, y], stack_dim=t.stack_dim) == t[init_slice + (slice(2),)] + ).all() + assert ( + TensorStack([y, z], stack_dim=t.stack_dim) + == t[init_slice + (slice(-2, None),)] + ).all() + assert ( + TensorStack([x, z], stack_dim=t.stack_dim) + == t[init_slice + (slice(0, 3, 2),)] + ).all() + + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -1, -2, -3, -4]) + def test_indexing_range(self, stack_dim, nt): + t, (x, y, z) = _tensorstack(stack_dim, nt) + sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + init_slice = (slice(None),) * sd + assert (t[init_slice + (slice(1),)][init_slice + (0,)] == x).all(), ( + t[init_slice + (slice(3),)][0], + x, + ) + assert (t[init_slice + (range(2),)][init_slice + (1,)] == y).all() + assert (t[init_slice + (range(3),)][init_slice + (2,)] == z).all() + assert (t[init_slice + (range(-3, 1),)][init_slice + (0,)] == x).all() + assert (t[init_slice + (range(-2, 1),)][init_slice + (0,)] == y).all() + assert (t[init_slice + (range(-1, 1),)][init_slice + (0,)] == 
z).all() + + assert ( + TensorStack([x, y], stack_dim=t.stack_dim) == t[init_slice + (range(2),)] + ).all() + assert ( + TensorStack([y, z], stack_dim=t.stack_dim) + == t[init_slice + (range(-2, 0),)] + ).all() + assert ( + TensorStack([x, z], stack_dim=t.stack_dim) + == t[init_slice + (range(0, 3, 2),)] + ).all() + + # def test_indexing_tensor(self, _tensorstack): + # t, (x, y, z) = _tensorstack + # assert (t[torch.tensor([0, 2])][0] == x).all() + # assert (t[torch.tensor([0, 2])][1] == z).all() + # assert (t[torch.tensor([0, 2, 0, 2])][2] == x).all() + # assert (t[torch.tensor([0, 2, 0, 2])][3] == z).all() + # + # assert (t[torch.tensor([[0, 2], [0, 2]])][0, 0] == x).all() + # assert (t[torch.tensor([[0, 2], [0, 2]])][0, 1] == z).all() + # assert (t[torch.tensor([[0, 2], [0, 2]])][1, 0] == x).all() + # assert (t[torch.tensor([[0, 2], [0, 2]])][1, 1] == z).all() + # + # def test_indexing_composite(self, _tensorstack): + # _, (x, y, z) = _tensorstack + # t = TensorStack.from_tensors([[x, y, z], [x, y, z]]) + # assert (t[0, 0] == x).all() + # assert (t[torch.tensor([0]), torch.tensor([0])] == x).all() + # assert (t[torch.tensor([0]), torch.tensor([1])] == y).all() + # assert (t[torch.tensor([0]), torch.tensor([2])] == z).all() + # assert (t[:, torch.tensor([0])] == x).all() + # assert (t[:, torch.tensor([1])] == y).all() + # assert (t[:, torch.tensor([2])] == z).all() + # assert ( + # t[torch.tensor([0]), torch.tensor([1, 2])] + # == TensorStack.from_tensors([y, z]) + # ).all() + # with pytest.raises(IndexError, match="Cannot index along"): + # assert ( + # t[..., torch.tensor([1, 2]), :, :, :] + # == TensorStack.from_tensors([y, z]) + # ).all() + # + # @pytest.mark.parametrize( + # "op", + # ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], + # ) + # def test_elementwise(self, _tensorstack, op): + # t, (x, y, z) = _tensorstack + # t2 = getattr(t, op)(2) + # torch.testing.assert_close(t2[0], getattr(x, op)(2)) + # torch.testing.assert_close(t2[1], getattr(y, op)(2)) + # torch.testing.assert_close(t2[2], getattr(z, op)(2)) + # t2 = getattr(t, op)(torch.ones(5) * 2) + # torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2)) + # torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2)) + # torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2)) + # # check broadcasting + # assert t2[0].shape == x.shape + # v = torch.ones(2, 1, 1, 1, 5) * 2 + # t2 = getattr(t, op)(v) + # assert t2.shape == torch.Size([2, 3, 3, -1, 5]) + # torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0])) + # torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0])) + # torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) + # # check broadcasting + # assert t2[:, 0].shape == torch.Size((2, *x.shape)) + # + # def test_permute(self): + # w = torch.randint(10, (3, 5, 5)) + # x = torch.randint(10, (3, 4, 5)) + # y = torch.randint(10, (3, 5, 5)) + # z = torch.randint(10, (3, 4, 5)) + # ts = TensorStack.from_tensors([[w, x], [y, z]]) + # tst = ts.permute(1, 0, 2, 3, 4) + # assert (tst[0, 1] == ts[1, 0]).all() + # assert (tst[1, 0] == ts[0, 1]).all() + # assert (tst[1, 1] == ts[1, 1]).all() + # assert (tst[0, 0] == ts[0, 0]).all() + # + # def test_transpose(self): + # w = torch.randint(10, (3, 5, 5)) + # x = torch.randint(10, (3, 4, 5)) + # y = torch.randint(10, (3, 5, 5)) + # z = torch.randint(10, (3, 4, 5)) + # ts = TensorStack.from_tensors([[w, x], [y, z]]) + # tst = ts.transpose(1, 0) + # assert (tst[0, 1] == ts[1, 0]).all() + # assert 
(tst[1, 0] == ts[0, 1]).all() + # assert (tst[1, 1] == ts[1, 1]).all() + # assert (tst[0, 0] == ts[0, 0]).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 43964fb2bc5fc653c4a926e95990a0441c52739c Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 1 Aug 2023 11:13:37 +0100 Subject: [PATCH 2/6] new attempt --- tensordict/tensorstack.py | 8 ++++++++ test/test_tensorstack.py | 30 +++++++++++++++++++----------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 5ef112bba..fc54e7dde 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -471,5 +471,13 @@ def any(self, dim: int = None): return out return any(value.any() for value in self.tensors) + def transpose(self, dim0, dim1): + if dim0 < 0: + dim0 = self.ndim + dim0 + if dim1 < 0: + dim1 = self.ndim + dim1 + if dim0 < 0 or dim1 < 0 or dim0 >= self.ndim or dim1 > self.ndim: + + def __repr__(self): return f"{self.__class__.__name__}({self.get_nestedtensor()})" diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index c11344b69..7091624bf 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -112,17 +112,25 @@ def test_indexing_range(self, stack_dim, nt): == t[init_slice + (range(0, 3, 2),)] ).all() - # def test_indexing_tensor(self, _tensorstack): - # t, (x, y, z) = _tensorstack - # assert (t[torch.tensor([0, 2])][0] == x).all() - # assert (t[torch.tensor([0, 2])][1] == z).all() - # assert (t[torch.tensor([0, 2, 0, 2])][2] == x).all() - # assert (t[torch.tensor([0, 2, 0, 2])][3] == z).all() - # - # assert (t[torch.tensor([[0, 2], [0, 2]])][0, 0] == x).all() - # assert (t[torch.tensor([[0, 2], [0, 2]])][0, 1] == z).all() - # assert (t[torch.tensor([[0, 2], [0, 2]])][1, 0] == x).all() - # assert (t[torch.tensor([[0, 2], [0, 2]])][1, 1] == z).all() + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -1, -2, -3, -4]) + def test_indexing_tensor(self, stack_dim, nt): + t, (x, y, z) = _tensorstack(stack_dim, nt) + sd = stack_dim if stack_dim >= 0 else 4 + stack_dim + init_slice = (slice(None),) * sd + assert (t[init_slice + (slice(1),)][init_slice + (0,)] == x).all(), ( + t[init_slice + (slice(3),)][0], + x, + ) + assert (t[init_slice + (torch.tensor([0, 2]),)][init_slice + (0,)] == x).all() + assert (t[init_slice + (torch.tensor([0, 2]),)][init_slice + (1,)] == z).all() + assert (t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (2,)] == x).all() + assert (t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (3,)] == z).all() + + assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 0)] == x).all() + assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 1)] == z).all() + assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 0)] == x).all() + assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 1)] == z).all() # # def test_indexing_composite(self, _tensorstack): # _, (x, y, z) = _tensorstack From 9e8c325528e0da65d86394b2e688f95593a8b4ff Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 1 Aug 2023 12:04:11 +0100 Subject: [PATCH 3/6] amend --- tensordict/tensorstack.py | 78 +++++++++++++++++++++++++++++-- test/test_tensorstack.py | 98 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 165 insertions(+), 11 deletions(-) diff --git a/tensordict/tensorstack.py 
b/tensordict/tensorstack.py index fc54e7dde..5b79f1069 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -471,13 +471,83 @@ def any(self, dim: int = None): return out return any(value.any() for value in self.tensors) - def transpose(self, dim0, dim1): + def transpose(self, dim0, dim1=None): + if isinstance(dim0, (list, tuple)) and dim1 is None: + dim0, dim1 = dim0 + elif isinstance(dim0, (list, tuple)): + raise ValueError( + "Expected one of `transpose((dim0, dim1))` or `transpose(dim0, dim1)`." + ) if dim0 < 0: - dim0 = self.ndim + dim0 + newdim0 = self.ndim + dim0 + else: + newdim0 = dim0 if dim1 < 0: - dim1 = self.ndim + dim1 - if dim0 < 0 or dim1 < 0 or dim0 >= self.ndim or dim1 > self.ndim: + newdim1 = self.ndim + dim1 + else: + newdim1 = dim1 + if newdim0 < 0 or newdim1 < 0 or newdim0 >= self.ndim or newdim1 > self.ndim: + raise ValueError( + f"Dimensions {(dim0, dim1)} are incompatible with a tensor of shape {self.shape}." + ) + if newdim0 == newdim1: + return self + if newdim0 == self.stack_dim: + newdim1 = newdim1 if newdim1 < self.stack_dim else newdim1 - 1 + pdim = [i for i in range(self.ndim - 1) if i != newdim1] + pdim.insert(newdim0 - 1, newdim1) + return LazyStackedTensors( + [t.permute(pdim) for t in self.tensors], stack_dim=newdim1 + ) + elif newdim1 == self.stack_dim: + newdim0 = newdim0 if newdim0 < self.stack_dim else newdim0 - 1 + pdim = [i for i in range(self.ndim - 1) if i != newdim0] + pdim.insert(newdim1 - 1, newdim0) + return LazyStackedTensors( + [t.permute(pdim) for t in self.tensors], stack_dim=newdim0 + ) + else: + newdim0 = newdim0 if newdim0 < self.stack_dim else newdim0 - 1 + newdim1 = newdim1 if newdim1 < self.stack_dim else newdim1 - 1 + return LazyStackedTensors( + [t.transpose(newdim1, newdim0) for t in self.tensors], + stack_dim=self.stack_dim, + ) + def permute(self, *permute_dims): + orig_permute_dims = permute_dims + if isinstance(permute_dims[0], (tuple, list)): + if len(permute_dims) == 1: + permute_dims = permute_dims[0] + else: + raise ValueError( + f"Got incompatible argument permute_dims: {orig_permute_dims}." + ) + permute_dims = [p if p >= 0 else self.ndim + p for p in permute_dims] + if any(p < 0 or p >= self.ndim for p in permute_dims): + raise ValueError( + f"Got incompatible argument permute_dims: {orig_permute_dims}." + ) + if len(permute_dims) != self.ndim: + raise ValueError( + f"permute_dims must have the same length as the number of dimensions of the tensor ({self.ndim}): {orig_permute_dims}." 
+ ) + for i in range(self.ndim): + if permute_dims[i] == self.stack_dim: + new_stack_dim = i + break + else: + # unreachable + raise RuntimeError + permute_dims = [ + p if p < self.stack_dim else p - 1 + for p in permute_dims + if p != self.stack_dim + ] + return LazyStackedTensors( + [t.permute(permute_dims) for t in self.tensors], + stack_dim=new_stack_dim, + ) def __repr__(self): return f"{self.__class__.__name__}({self.get_nestedtensor()})" diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index 7091624bf..2cbc85259 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -124,14 +124,98 @@ def test_indexing_tensor(self, stack_dim, nt): ) assert (t[init_slice + (torch.tensor([0, 2]),)][init_slice + (0,)] == x).all() assert (t[init_slice + (torch.tensor([0, 2]),)][init_slice + (1,)] == z).all() - assert (t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (2,)] == x).all() - assert (t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (3,)] == z).all() + assert ( + t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (2,)] == x + ).all() + assert ( + t[init_slice + (torch.tensor([0, 2, 0, 2]),)][init_slice + (3,)] == z + ).all() + + assert ( + t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 0)] == x + ).all() + assert ( + t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 1)] == z + ).all() + assert ( + t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 0)] == x + ).all() + assert ( + t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 1)] == z + ).all() + + @pytest.mark.parametrize( + "transpose", + [(0, 1), (0, -1), (-1, 0), (1, 3), (1, 2), (2, 1), (0, 2), (2, 0), (2, 2)], + ) + @pytest.mark.parametrize("het", [False, True]) + @pytest.mark.parametrize("nt", [False, True]) + def test_transpose(self, het, transpose, nt): + torch.manual_seed(0) + x = torch.randn(6, 5, 4, 3) + if het: + y = torch.randn(6, 5, 2, 3) + else: + y = torch.randn(6, 5, 4, 3) + if nt: + t = TensorStack(torch.nested.nested_tensor([x, y]), stack_dim=2) + else: + t = TensorStack([x, y], stack_dim=2) + + tt = t.transpose(transpose) + with pytest.raises(ValueError): + t.transpose(transpose, 0) + if transpose == (1, 2) or transpose == (2, 1): + assert (tt[:, 0] == x).all() + assert (tt[:, 1] == y).all() + elif transpose == (0, 2) or transpose == (2, 0): + assert (tt[0] == x.permute(1, 0, 2, 3)).all() + assert (tt[1] == y.permute(1, 0, 2, 3)).all() + elif transpose == (2, 2): + assert (t[:, :, 0] == x).all() + assert (t[:, :, 1] == y).all() + elif transpose == (0, 1): + assert (tt[:, :, 0] == x.transpose(0, 1)).all() + assert (tt[:, :, 1] == y.transpose(0, 1)).all() + elif transpose == (0, -1): + assert (tt[:, :, 0] == x.transpose(0, -1)).all() + assert (tt[:, :, 1] == y.transpose(0, -1)).all() + elif transpose == (1, 3): + assert (tt[:, :, 0] == x.transpose(1, 2)).all() + assert (tt[:, :, 1] == y.transpose(1, 2)).all() + + def test_permute(self): + torch.manual_seed(0) + x = torch.zeros(6, 5, 4, 3) + y = torch.zeros(6, 5, 4, 3) + t = TensorStack((x, y), stack_dim=2) + with pytest.raises(ValueError, match="Got incompatible argument permute_dims"): + t.permute((1, 2, 3), 0) + with pytest.raises(ValueError, match="Got incompatible argument permute_dims"): + t.permute((1, 2, 3, 4, 10)) + with pytest.raises(ValueError, match="permute_dims must have the same length"): + t.permute((1, 2, 3, 4)) + stack = torch.stack([x, y], 2) + for _ in range(128): + pdim = torch.randperm(5).tolist() + tp = t.permute(pdim) + 
assert tp.shape == stack.permute(pdim).shape + assert (tp == stack.permute(pdim)).all() + + @pytest.mark.parametrize("unbind", range(5)) + @pytest.mark.parametrize("nt", [False, True]) + def test_permute(self, unbind, nt): + torch.manual_seed(0) + x = torch.zeros(6, 5, 4, 3) + y = torch.zeros(6, 5, 4, 3) + if nt: + t = TensorStack(torch.nested.nested_tensor([x, y]), stack_dim=2) + else: + t = TensorStack([x, y], stack_dim=2) + stack = torch.stack([x, y], 2) + for v1, v2 in zip(t.unbind(unbind), stack.unbind(unbind)): + assert (v1 == v2).all() - assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 0)] == x).all() - assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (0, 1)] == z).all() - assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 0)] == x).all() - assert (t[init_slice + (torch.tensor([[0, 2], [0, 2]]),)][init_slice + (1, 1)] == z).all() - # # def test_indexing_composite(self, _tensorstack): # _, (x, y, z) = _tensorstack # t = TensorStack.from_tensors([[x, y, z], [x, y, z]]) From 87086960f8a808386640d4abdc596593d11ca6dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 2 Aug 2023 09:25:24 +0100 Subject: [PATCH 4/6] amend --- tensordict/tensorstack.py | 26 ++++++++++++++++++-- test/test_tensorstack.py | 50 ++++++++++++++++++++------------------- 2 files changed, 50 insertions(+), 26 deletions(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 5b79f1069..359f2207c 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -45,8 +45,23 @@ def _elementiwse_broadcast(func): def new_func(self, other): if self._nested: + if isinstance(other, torch.Tensor) and not other.is_nested: + shape = torch.broadcast_shapes(other.shape, self._shape_no0) + if shape != other.shape: + other = other.expand(shape) + if shape != self._shape_no0: + self_expand = self.expand(shape).as_nestedtensor() + else: + self_expand = self + sd = self.stack_dim - self.ndim + other = other.unbind(sd) + other = LazyStackedTensors(other, stack_dim=sd).get_nestedtensor() + else: + self_expand = self + # print("op", func_name, "\nt", self.tensors, "\nother", other) + # print("result", getattr(torch.Tensor, func_name)(self.tensors, other)) return type(self)( - getattr(torch.Tensor, func_name)(self.tensors, other), + getattr(torch.Tensor, func_name)(self_expand.tensors, other), stack_dim=self.stack_dim, ) if isinstance(other, (torch.Tensor,)): @@ -58,6 +73,7 @@ def new_func(self, other): else: self_expand = self other = other.unbind(self_expand.stack_dim) + new_stack_dim = self.stack_dim + len(shape) - self.ndim elif isinstance(other, (LazyStackedTensors,)): shape = torch.broadcast_shapes(other._shape_no0, self._shape_no0) if shape != other._shape_no0: @@ -67,15 +83,17 @@ def new_func(self, other): else: self_expand = self other = other.unbind(self_expand.stack_dim) + new_stack_dim = self.stack_dim + len(shape) - self.ndim else: self_expand = self other = (other,) * self.n + new_stack_dim = self.stack_dim return type(self)( [ getattr(torch.Tensor, func_name)(t, _other) for t, _other in zip(self_expand.tensors, other) ], - self.stack_dim, + stack_dim=new_stack_dim, ) return new_func @@ -348,6 +366,10 @@ def __eq__(self, other): def __ne__(self, other): ... + @_elementiwse_broadcast + def __mod__(self, other): + ... 
+ @property def n(self): return self.shape[self.stack_dim] diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index 2cbc85259..bdbd71747 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -236,30 +236,32 @@ def test_permute(self, unbind, nt): # == TensorStack.from_tensors([y, z]) # ).all() # - # @pytest.mark.parametrize( - # "op", - # ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], - # ) - # def test_elementwise(self, _tensorstack, op): - # t, (x, y, z) = _tensorstack - # t2 = getattr(t, op)(2) - # torch.testing.assert_close(t2[0], getattr(x, op)(2)) - # torch.testing.assert_close(t2[1], getattr(y, op)(2)) - # torch.testing.assert_close(t2[2], getattr(z, op)(2)) - # t2 = getattr(t, op)(torch.ones(5) * 2) - # torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2)) - # torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2)) - # torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2)) - # # check broadcasting - # assert t2[0].shape == x.shape - # v = torch.ones(2, 1, 1, 1, 5) * 2 - # t2 = getattr(t, op)(v) - # assert t2.shape == torch.Size([2, 3, 3, -1, 5]) - # torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0])) - # torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0])) - # torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) - # # check broadcasting - # assert t2[:, 0].shape == torch.Size((2, *x.shape)) + @pytest.mark.parametrize( + "op", + ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], + ) + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0]) + def test_indexing_tensor(self, stack_dim, nt, op): + t, (x, y, z) = _tensorstack(stack_dim, nt) + t2 = getattr(t, op)(2) + torch.testing.assert_close(t2[0], getattr(x, op)(2)) + torch.testing.assert_close(t2[1], getattr(y, op)(2)) + torch.testing.assert_close(t2[2], getattr(z, op)(2)) + t2 = getattr(t, op)(torch.ones(5) * 2) + torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2)) + # check broadcasting + assert t2[0].shape == x.shape + v = torch.ones(17, 1, 1, 1, 5) * 2 + t2 = getattr(t, op)(v) + assert t2.shape == torch.Size([17, 3, 3, -1, 5]) + torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0])) + torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0])) + torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) + # check broadcasting + assert t2[:, 0].shape == torch.Size((17, *x.shape)) # # def test_permute(self): # w = torch.randint(10, (3, 5, 5)) From 3b9869b8aa3775f595cb348c9ddddd4ccea9b74c Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 2 Aug 2023 10:35:24 +0100 Subject: [PATCH 5/6] amend --- tensordict/tensorstack.py | 6 +++++- test/test_tensorstack.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 359f2207c..2b03ee815 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -55,6 +55,10 @@ def new_func(self, other): self_expand = self sd = self.stack_dim - self.ndim other = other.unbind(sd) + other = tuple( + _other.expand_as(_self) + for (_other, _self) in zip(other, self_expand.tensors) + ) other = LazyStackedTensors(other, stack_dim=sd).get_nestedtensor() else: self_expand = self @@ -62,7 +66,7 @@ def new_func(self, other): # print("result", getattr(torch.Tensor, 
func_name)(self.tensors, other)) return type(self)( getattr(torch.Tensor, func_name)(self_expand.tensors, other), - stack_dim=self.stack_dim, + stack_dim=self.stack_dim - self.ndim, ) if isinstance(other, (torch.Tensor,)): shape = torch.broadcast_shapes(other.shape, self._shape_no0) diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index bdbd71747..2c7949332 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -243,6 +243,9 @@ def test_permute(self, unbind, nt): @pytest.mark.parametrize("nt", [False, True]) @pytest.mark.parametrize("stack_dim", [0]) def test_indexing_tensor(self, stack_dim, nt, op): + if nt and op in ("__eq__", "__ne__", "__mod__"): + # not implemented + return t, (x, y, z) = _tensorstack(stack_dim, nt) t2 = getattr(t, op)(2) torch.testing.assert_close(t2[0], getattr(x, op)(2)) @@ -262,6 +265,7 @@ def test_indexing_tensor(self, stack_dim, nt, op): torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) # check broadcasting assert t2[:, 0].shape == torch.Size((17, *x.shape)) + # # def test_permute(self): # w = torch.randint(10, (3, 5, 5)) From 25c14da8b2bed2c347bdd4b0f3e95b5de9da22bf Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 7 Aug 2023 10:51:38 +0200 Subject: [PATCH 6/6] amend --- tensordict/tensorstack.py | 86 +++++++++++++++++++++++++++++++++++++++ test/test_tensorstack.py | 67 +++++++++++------------------- 2 files changed, 109 insertions(+), 44 deletions(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 2b03ee815..db77f4f1a 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -575,5 +575,91 @@ def permute(self, *permute_dims): stack_dim=new_stack_dim, ) + def split(self, split_size_or_sections, dim=0): + if dim < 0: + dim = self.ndim + dim + if dim < 0 or dim > self.ndim-1: + raise ValueError(f"split dimension isn't compatible with the tensor dimensions: dim={dim} and self.shape={self.shape}") + if self.shape[dim] == -1: + new_dim = dim if dim < self.stack_dim else dim-1 + out = [] + for t in self.tensors: + out.append(LazyStackedTensors(t.split(split_size_or_sections, dim=new_dim), stack_dim=self.stack_dim)) + return tuple(out) + if isinstance(split_size_or_sections, int): + split_size_or_sections = [split_size_or_sections] * (self.shape[dim] // split_size_or_sections) + res = self.shape[dim] % split_size_or_sections[0] + if res > 0: + split_size_or_sections += [res] + out = [] + i = 0 + for splits in split_size_or_sections: + idx = (slice(None),) * dim + (range(i, i+splits),) + i += splits + out.append(self[idx]) + return tuple(out) + + def numel(self): + if self._nested: + return self.tensors.numel() + return sum(t.numel() for t in self.tensors) + + def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): + if dim is None: + tensor = self.reshape(-1) + return tensor.unique(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts) + + def reshape(self, *shape): + """Reshapes a tensor to the desired shape. + + Returns a tensor with the same data and number of elements as self + but with the specified shape. This method returns a view if shape is + compatible with the current shape. + See torch.Tensor.view() on when it is possible to return a view. + + with TensorStack, reshaping can only occur up until the heterogeneous + dimension (if any). The trailing dimensions (from the first heterogeneous + till the end) must match (unless the tensor is flatten with ``-1``). + Heterogeneous dimensions must be indicated with ``None``. 
+ + """ + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = tuple(shape[0]) + if shape == (-1,): + return torch.cat([t.reshape(-1) for t in self.tensors], 0) + trailing_dims = self._trailing_dims() + trailing_shape = tuple(d if d >= 0 else None for d in shape[-len(trailing_dims):]) + if trailing_dims != trailing_shape: + raise ValueError(f"Trailing dimensions must match in reshape. Got {trailing_dims} and shape {trailing_dims}") + init_dim = shape[:-len(trailing_dims)] + if self.stack_dim > self.ndim - len(trailing_dims): + # simplest use case + new_stack_dim = self.stack_dim - self.ndim + trailing_dims_pop = [t for i, t in enumerate(trailing_dims) if i - len(trailing_dims) != new_stack_dim] + def trailing_shape(tensor): + return tuple(s for s in tensor.shape[-len(trailing_dims_pop):]) + return type(self)([t.reshape(*init_dim, *trailing_shape(t)) for t in self.tensors], stack_dim=new_stack_dim) + # TODO + # return type(self)( + # [t.reshape] + # ) + + def _trailing_dims(self): + trailing_dims = [] + for d in self.shape: + if d == -1 or len(trailing_dims): + trailing_dims.append(d) + return tuple(trailing_dims) + + def reshape_as(self): + ... + def view(self): + raise NotImplementedError( + "Viewing a TensorStack is not allowed. Use .reshape instead." + ) + + def view_as(self): + raise NotImplementedError("Viewing a TensorStack is not allowed. Use .reshape instead.") + def __repr__(self): return f"{self.__class__.__name__}({self.get_nestedtensor()})" diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index 2c7949332..c9a6416a7 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -216,26 +216,6 @@ def test_permute(self, unbind, nt): for v1, v2 in zip(t.unbind(unbind), stack.unbind(unbind)): assert (v1 == v2).all() - # def test_indexing_composite(self, _tensorstack): - # _, (x, y, z) = _tensorstack - # t = TensorStack.from_tensors([[x, y, z], [x, y, z]]) - # assert (t[0, 0] == x).all() - # assert (t[torch.tensor([0]), torch.tensor([0])] == x).all() - # assert (t[torch.tensor([0]), torch.tensor([1])] == y).all() - # assert (t[torch.tensor([0]), torch.tensor([2])] == z).all() - # assert (t[:, torch.tensor([0])] == x).all() - # assert (t[:, torch.tensor([1])] == y).all() - # assert (t[:, torch.tensor([2])] == z).all() - # assert ( - # t[torch.tensor([0]), torch.tensor([1, 2])] - # == TensorStack.from_tensors([y, z]) - # ).all() - # with pytest.raises(IndexError, match="Cannot index along"): - # assert ( - # t[..., torch.tensor([1, 2]), :, :, :] - # == TensorStack.from_tensors([y, z]) - # ).all() - # @pytest.mark.parametrize( "op", ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], @@ -266,31 +246,30 @@ def test_indexing_tensor(self, stack_dim, nt, op): # check broadcasting assert t2[:, 0].shape == torch.Size((17, *x.shape)) - # - # def test_permute(self): - # w = torch.randint(10, (3, 5, 5)) - # x = torch.randint(10, (3, 4, 5)) - # y = torch.randint(10, (3, 5, 5)) - # z = torch.randint(10, (3, 4, 5)) - # ts = TensorStack.from_tensors([[w, x], [y, z]]) - # tst = ts.permute(1, 0, 2, 3, 4) - # assert (tst[0, 1] == ts[1, 0]).all() - # assert (tst[1, 0] == ts[0, 1]).all() - # assert (tst[1, 1] == ts[1, 1]).all() - # assert (tst[0, 0] == ts[0, 0]).all() - # - # def test_transpose(self): - # w = torch.randint(10, (3, 5, 5)) - # x = torch.randint(10, (3, 4, 5)) - # y = torch.randint(10, (3, 5, 5)) - # z = torch.randint(10, (3, 4, 5)) - # ts = TensorStack.from_tensors([[w, x], [y, z]]) - # tst = ts.transpose(1, 0) - # 
assert (tst[0, 1] == ts[1, 0]).all() - # assert (tst[1, 0] == ts[0, 1]).all() - # assert (tst[1, 1] == ts[1, 1]).all() - # assert (tst[0, 0] == ts[0, 0]).all() + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -3, -2, -1]) + @pytest.mark.parametrize("dim", [0, 1, 2, 3, -3, -2, -1]) + def test_split(self,stack_dim, nt, dim): + t, (x, y, z) = _tensorstack(stack_dim, nt) + tsplit = t.split(3, dim) + assert sum(ts.numel() for ts in tsplit) == t.numel() + uniques = set() + for ts in tsplit: + uniques = uniques.union(ts.unique().tolist()) + assert uniques == set(t.unique().tolist()) + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -3, -2, -1]) + def test_reshape(self, stack_dim, nt, dim): + ... + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -3, -2, -1]) + def test_unique(self, stack_dim, nt, dim): + ... + @pytest.mark.parametrize("nt", [False, True]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2, 3, -3, -2, -1]) + def test_view(self, stack_dim, nt, dim): + ... if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()