commit 8708696 (1 parent: 9e8c325)
amend
vmoens committed Aug 2, 2023
Showing 2 changed files with 50 additions and 26 deletions.
tensordict/tensorstack.py (24 additions, 2 deletions)

@@ -45,8 +45,23 @@ def _elementiwse_broadcast(func):

     def new_func(self, other):
         if self._nested:
+            if isinstance(other, torch.Tensor) and not other.is_nested:
+                shape = torch.broadcast_shapes(other.shape, self._shape_no0)
+                if shape != other.shape:
+                    other = other.expand(shape)
+                if shape != self._shape_no0:
+                    self_expand = self.expand(shape).as_nestedtensor()
+                else:
+                    self_expand = self
+                sd = self.stack_dim - self.ndim
+                other = other.unbind(sd)
+                other = LazyStackedTensors(other, stack_dim=sd).get_nestedtensor()
+            else:
+                self_expand = self
+            # print("op", func_name, "\nt", self.tensors, "\nother", other)
+            # print("result", getattr(torch.Tensor, func_name)(self.tensors, other))
             return type(self)(
-                getattr(torch.Tensor, func_name)(self.tensors, other),
+                getattr(torch.Tensor, func_name)(self_expand.tensors, other),
                 stack_dim=self.stack_dim,
             )
         if isinstance(other, (torch.Tensor,)):
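For context, the nested branch added above reproduces ordinary broadcasting before the op is handed to torch: the dense right-hand side is expanded to the broadcast shape, unbound along the stack dimension, and re-packed as a nested tensor so that the elementwise op can run leaf by leaf. A minimal standalone sketch of that conversion using plain PyTorch nested tensors (illustrative shapes and values, not the tensordict API):

    import torch

    # Two leaves that agree everywhere except one ragged dimension.
    x = torch.nested.nested_tensor([torch.ones(2, 5), torch.ones(3, 5)])
    other = torch.full((5,), 2.0)  # dense rhs to broadcast against every leaf

    # Expand the rhs per leaf, then re-nest it so the nested structures match.
    other_nested = torch.nested.nested_tensor(
        [other.expand(t.shape) for t in x.unbind(0)]
    )
    result = x * other_nested  # elementwise op applied leaf by leaf
    print([t.shape for t in result.unbind(0)])  # [torch.Size([2, 5]), torch.Size([3, 5])]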
@@ -58,6 +73,7 @@ def new_func(self, other):
             else:
                 self_expand = self
             other = other.unbind(self_expand.stack_dim)
+            new_stack_dim = self.stack_dim + len(shape) - self.ndim
         elif isinstance(other, (LazyStackedTensors,)):
             shape = torch.broadcast_shapes(other._shape_no0, self._shape_no0)
             if shape != other._shape_no0:
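The new_stack_dim arithmetic introduced here compensates for the leading dimensions that broadcasting prepends: the stack axis keeps its position relative to the trailing dimensions, so it shifts right by len(shape) - self.ndim. A quick check with illustrative numbers:

    import torch

    # Broadcasting a 3-D stack against a 5-D tensor prepends two dimensions,
    # so a stack axis at dim 0 of the 3-D shape lands at dim 2 of the result.
    stack_dim, ndim = 0, 3
    shape = torch.broadcast_shapes((2, 1, 3, 4, 5), (3, 4, 5))  # -> (2, 1, 3, 4, 5)
    new_stack_dim = stack_dim + len(shape) - ndim
    assert new_stack_dim == 2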
@@ -67,15 +83,17 @@ def new_func(self, other):
             else:
                 self_expand = self
             other = other.unbind(self_expand.stack_dim)
+            new_stack_dim = self.stack_dim + len(shape) - self.ndim
         else:
             self_expand = self
             other = (other,) * self.n
+            new_stack_dim = self.stack_dim
         return type(self)(
             [
                 getattr(torch.Tensor, func_name)(t, _other)
                 for t, _other in zip(self_expand.tensors, other)
             ],
-            self.stack_dim,
+            stack_dim=new_stack_dim,
         )
 
     return new_func
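The rebuild at the end of the dense path is a plain unbind/zip pattern: the broadcast right-hand side is split along the stack axis and the resolved torch.Tensor method is applied to each (leaf, slice) pair. A plain-tensor sketch of that pattern (assumed shapes, not the actual class):

    import torch

    # Leaves of a hypothetical stack and a dense rhs already broadcast to it.
    tensors = [torch.ones(4, 5), torch.zeros(4, 5), torch.ones(4, 5)]
    other = torch.full((3, 4, 5), 2.0)  # stack_dim = 0, so unbind along dim 0

    op = getattr(torch.Tensor, "__add__")  # resolved the same way as in new_func
    leaves = [op(t, o) for t, o in zip(tensors, other.unbind(0))]
    assert all(leaf.shape == (4, 5) for leaf in leaves)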
@@ -348,6 +366,10 @@ def __eq__(self, other):
     def __ne__(self, other):
         ...
 
+    @_elementiwse_broadcast
+    def __mod__(self, other):
+        ...
+
     @property
     def n(self):
         return self.shape[self.stack_dim]
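With __mod__ now routed through the decorator, % gets the same scalar and broadcasting behavior as the other overloads; at the leaf level each call resolves to torch.Tensor.__mod__:

    import torch

    # What each per-leaf call resolves to when op is "__mod__".
    mod = getattr(torch.Tensor, "__mod__")
    x = torch.arange(10.0).reshape(2, 5)
    assert torch.equal(mod(x, 3), x % 3)                      # scalar rhs
    assert torch.equal(mod(x, torch.full((5,), 3.0)), x % 3)  # broadcast rhs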
test/test_tensorstack.py (26 additions, 24 deletions)

@@ -236,30 +236,32 @@ def test_permute(self, unbind, nt):
     #         == TensorStack.from_tensors([y, z])
     #     ).all()
     #
-    # @pytest.mark.parametrize(
-    #     "op",
-    #     ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"],
-    # )
-    # def test_elementwise(self, _tensorstack, op):
-    #     t, (x, y, z) = _tensorstack
-    #     t2 = getattr(t, op)(2)
-    #     torch.testing.assert_close(t2[0], getattr(x, op)(2))
-    #     torch.testing.assert_close(t2[1], getattr(y, op)(2))
-    #     torch.testing.assert_close(t2[2], getattr(z, op)(2))
-    #     t2 = getattr(t, op)(torch.ones(5) * 2)
-    #     torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2))
-    #     torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2))
-    #     torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2))
-    #     # check broadcasting
-    #     assert t2[0].shape == x.shape
-    #     v = torch.ones(2, 1, 1, 1, 5) * 2
-    #     t2 = getattr(t, op)(v)
-    #     assert t2.shape == torch.Size([2, 3, 3, -1, 5])
-    #     torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0]))
-    #     torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0]))
-    #     torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0]))
-    #     # check broadcasting
-    #     assert t2[:, 0].shape == torch.Size((2, *x.shape))
+    @pytest.mark.parametrize(
+        "op",
+        ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"],
+    )
+    @pytest.mark.parametrize("nt", [False, True])
+    @pytest.mark.parametrize("stack_dim", [0])
+    def test_elementwise(self, stack_dim, nt, op):
+        t, (x, y, z) = _tensorstack(stack_dim, nt)
+        t2 = getattr(t, op)(2)
+        torch.testing.assert_close(t2[0], getattr(x, op)(2))
+        torch.testing.assert_close(t2[1], getattr(y, op)(2))
+        torch.testing.assert_close(t2[2], getattr(z, op)(2))
+        t2 = getattr(t, op)(torch.ones(5) * 2)
+        torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2))
+        torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2))
+        torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2))
+        # check broadcasting
+        assert t2[0].shape == x.shape
+        v = torch.ones(17, 1, 1, 1, 5) * 2
+        t2 = getattr(t, op)(v)
+        assert t2.shape == torch.Size([17, 3, 3, -1, 5])
+        torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0]))
+        torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0]))
+        torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0]))
+        # check broadcasting
+        assert t2[:, 0].shape == torch.Size((17, *x.shape))
     #
     # def test_permute(self):
     #     w = torch.randint(10, (3, 5, 5))
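For reference, the revived test calls a module-level _tensorstack helper that returns the stack and its three leaves. A hypothetical sketch of such a helper, under the shapes the assertions imply (three leaves of shape (3, h, 5) with a ragged second dimension); the real helper, constructor signature, and import path may differ:

    import torch

    from tensordict.tensorstack import LazyStackedTensors  # assumed import path


    def _tensorstack(stack_dim, nt):
        # Three leaves that differ only in one ragged dimension, so the stack
        # has the (3, 3, -1, 5) shape the assertions above rely on.
        x = torch.randint(10, (3, 1, 5))
        y = torch.randint(10, (3, 2, 5))
        z = torch.randint(10, (3, 3, 5))
        t = LazyStackedTensors([x, y, z], stack_dim=stack_dim)
        if nt:
            t = t.as_nestedtensor()  # assumed conversion, as used in the diff above
        return t, (x, y, z)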