[Feature] TensorStack (2) #505

Draft · wants to merge 6 commits into base: main
136 changes: 126 additions & 10 deletions tensordict/tensordict.py
@@ -9,6 +9,7 @@
 import collections
 import functools
 import numbers
+import operator
 import os
 import re
 import textwrap
@@ -2055,6 +2056,93 @@ def __ne__(self, other: object) -> TensorDictBase:
     # def __hash__(self):
     # ...
 
+    def __and__(self, other):
+        """Computes the key-wise logical AND of two tensordicts. The two tensordicts must have the same key set.
+
+        Returns:
+            a new TensorDict instance whose tensors are boolean
+            tensors of the same shape as the original tensors.
+
+        """
+        if is_tensorclass(other):
+            return other & self
+        if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
+            keys1 = set(self.keys())
+            keys2 = set(other.keys())
+            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
+                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
+            d = {}
+            for key, item1 in self.items():
+                d[key] = item1 & other.get(key)
+            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
+        if isinstance(other, (numbers.Number, Tensor)):
+            return TensorDict(
+                {key: value & other for key, value in self.items()},
+                self.batch_size,
+                device=self.device,
+            )
+        raise NotImplementedError(
+            f"Cannot compare objects of type {type(self)} and {type(other)}"
+        )
+
+    def __or__(self, other):
+        """Computes the key-wise logical OR of two tensordicts. The two tensordicts must have the same key set.
+
+        Returns:
+            a new TensorDict instance whose tensors are boolean
+            tensors of the same shape as the original tensors.
+
+        """
+        if is_tensorclass(other):
+            return other | self
+        if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
+            keys1 = set(self.keys())
+            keys2 = set(other.keys())
+            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
+                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
+            d = {}
+            for key, item1 in self.items():
+                d[key] = item1 | other.get(key)
+            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
+        if isinstance(other, (numbers.Number, Tensor)):
+            return TensorDict(
+                {key: value | other for key, value in self.items()},
+                self.batch_size,
+                device=self.device,
+            )
+        raise NotImplementedError(
+            f"Cannot compare objects of type {type(self)} and {type(other)}"
+        )
+
+    def __xor__(self, other):
+        """Computes the key-wise logical XOR of two tensordicts. The two tensordicts must have the same key set.
+
+        Returns:
+            a new TensorDict instance whose tensors are boolean
+            tensors of the same shape as the original tensors.
+
+        """
+        if is_tensorclass(other):
+            return other ^ self
+        if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
+            keys1 = set(self.keys())
+            keys2 = set(other.keys())
+            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
+                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
+            d = {}
+            for key, item1 in self.items():
+                d[key] = item1 ^ other.get(key)
+            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
+        if isinstance(other, (numbers.Number, Tensor)):
+            return TensorDict(
+                {key: value ^ other for key, value in self.items()},
+                self.batch_size,
+                device=self.device,
+            )
+        raise NotImplementedError(
+            f"Cannot compare objects of type {type(self)} and {type(other)}"
+        )
+
     def __eq__(self, other: object) -> TensorDictBase:
         """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set.
 
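For orientation, a minimal usage sketch of the new operators; this is a hedged example assuming only that `TensorDict` is importable from the `tensordict` package (the key name `mask` and the values are illustrative):

import torch
from tensordict import TensorDict

a = TensorDict({"mask": torch.tensor([True, False, True])}, batch_size=[3])
b = TensorDict({"mask": torch.tensor([True, True, False])}, batch_size=[3])

# Each operator is applied key by key and returns a new TensorDict
# whose entries are boolean tensors of the same shape.
print((a & b)["mask"])  # tensor([ True, False, False])
print((a | b)["mask"])  # tensor([True, True, True])
print((a ^ b)["mask"])  # tensor([False,  True,  True])

Mismatched key sets raise a KeyError, while a plain number or Tensor on the right-hand side is combined with every entry instead.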
@@ -7075,14 +7163,28 @@ def all(self, dim: int = None) -> bool | TensorDictBase:
"smaller than tensordict.batch_dims"
)
if dim is not None:
# TODO: we need to adapt this to LazyStackedTensorDict too
if dim < 0:
dim = self.batch_dims + dim
return TensorDict(
source={key: value.all(dim=dim) for key, value in self.items()},
batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
device=self.device,
if dim > self.stack_dim:
dim = dim - 1
new_stack_dim = self.stack_dim
elif dim == self.stack_dim:
if len(self.tensordicts) == 1:
return self.tensordicts[0].apply(lambda x: x.bool())

val = functools.reduce(
operator.and_,
[td.apply(lambda x: x.bool()) for td in self.tensordicts],
)
return val
else:
new_stack_dim = self.stack_dim - 1

out = LazyStackedTensorDict(
*[td.all(dim) for td in self.tensordicts], stack_dim=new_stack_dim
)
out._td_dim_name = self._td_dim_name
return out
return all(value.all() for value in self.tensordicts)
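When `dim` equals the stack dimension, the new branch folds the stacked tensordicts together with `operator.and_` instead of materializing a stack. On plain tensors this reduction is equivalent to stacking and then calling `.all()` over the stack dimension, which this self-contained sketch checks (tensor values are illustrative):

import functools
import operator

import torch

parts = [torch.tensor([True, True]), torch.tensor([True, False])]
# Pairwise & across the list matches stacking, then reducing over dim 0.
reduced = functools.reduce(operator.and_, [p.bool() for p in parts])
assert torch.equal(reduced, torch.stack(parts).all(dim=0))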

     def any(self, dim: int = None) -> bool | TensorDictBase:
@@ -7092,14 +7194,28 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
"smaller than tensordict.batch_dims"
)
if dim is not None:
# TODO: we need to adapt this to LazyStackedTensorDict too
if dim < 0:
dim = self.batch_dims + dim
return TensorDict(
source={key: value.any(dim=dim) for key, value in self.items()},
batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
device=self.device,
if dim > self.stack_dim:
dim = dim - 1
new_stack_dim = self.stack_dim
elif dim == self.stack_dim:
if len(self.tensordicts) == 1:
return self.tensordicts[0].apply(lambda x: x.bool())

val = functools.reduce(
operator.or_,
[td.apply(lambda x: x.bool()) for td in self.tensordicts],
)
return val
else:
new_stack_dim = self.stack_dim - 1

out = LazyStackedTensorDict(
*[td.any(dim) for td in self.tensordicts], stack_dim=new_stack_dim
)
out._td_dim_name = self._td_dim_name
return out
return any(value.any() for value in self.tensordicts)
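To close, a hedged usage sketch of the reworked `any`; it assumes that `torch.stack` over tensordicts builds a LazyStackedTensorDict, as this library does, and the key name and values are illustrative:

import torch
from tensordict import TensorDict

td0 = TensorDict({"done": torch.tensor([False, False])}, batch_size=[2])
td1 = TensorDict({"done": torch.tensor([True, False])}, batch_size=[2])
stacked = torch.stack([td0, td1], dim=0)  # lazy stack with stack_dim=0

# Reducing over the stack dim collapses the lazy stack into a single result.
print(stacked.any(dim=0)["done"])  # tensor([ True, False])
# Reducing over another dim keeps the stack lazy, with stack_dim adjusted.
print(stacked.any(dim=1)["done"])  # tensor([False,  True])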

     def _send(self, dst: int, _tag: int = -1, pseudo_rand: bool = False) -> int: