Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support propagation of in-place modifications to masked tensordict #132

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class TensorDictBase(Mapping, metaclass=abc.ABCMeta):
_lazy = False
_inplace_set = False
is_meta = False
_sources: Tuple[TensorDictBase, List[INDEX_TYPING]] = None

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
Expand Down Expand Up @@ -1976,6 +1977,11 @@ def __getitem__(self, idx: INDEX_TYPING) -> TensorDictBase:
def __setitem__(
self, index: INDEX_TYPING, value: Union[TensorDictBase, dict]
) -> None:
if self._sources:
source, idxx = self._sources
for idx in idxx:
source = source.get_sub_tensordict(idx)
source.__setitem__(index, value)
if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index):
index = convert_ellipsis_to_idx(index, self.batch_size)
if isinstance(index, list):
Expand Down Expand Up @@ -2440,6 +2446,10 @@ def _index_tensordict(self, idx: INDEX_TYPING):
self_copy._dict_meta = KeyDependentDefaultDict(self_copy._make_meta)
self_copy._batch_size = _getitem_batch_size(self_copy.batch_size, idx)
self_copy._device = self.device
if not self_copy._sources:
self_copy._sources = (self, [idx])
else:
self_copy._sources = self_copy._sources[0], self_copy._sources[1] + [idx]
return self_copy

def pin_memory(self) -> TensorDictBase:
Expand Down
10 changes: 10 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3436,6 +3436,16 @@ def test_shared_inheritance():
assert td0.is_shared()


def test_in_place_set_to_masked_tensordict():
    # Regression test for https://github.com/pytorch/rl/issues/298:
    # an assignment into a masked (indexed) view of a TensorDict must
    # propagate back to the parent TensorDict.
    td = TensorDict({"a": torch.randn(3, 4, 2), "b": torch.randn(3, 4)}, [3, 4])

    mask = torch.tensor([True, False, True])
    x = torch.randn(2, 4, 2)
    td[mask]["a"] = x
    # assert_allclose is deprecated/removed; assert_close is the supported API.
    torch.testing.assert_close(td[mask]["a"], x)

    # A second, non-overlapping mask selecting a single batch element,
    # to make sure propagation is not specific to one mask pattern.
    mask2 = torch.tensor([False, True, False])
    y = torch.randn(1, 4, 2)
    td[mask2]["a"] = y
    torch.testing.assert_close(td[mask2]["a"], y)
    # The earlier masked write must remain intact.
    torch.testing.assert_close(td[mask]["a"], x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool but it's the simplest use case. Maybe let's try with a couple of other masks?



if __name__ == "__main__":
    # Forward any unrecognized CLI flags straight through to pytest.
    parser = argparse.ArgumentParser()
    args, unknown = parser.parse_known_args()
    pytest_args = [__file__, "--capture", "no", "--exitfirst"]
    pytest.main(pytest_args + unknown)