diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 39c3f6018..d8ee8c1da 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -253,6 +253,7 @@ class TensorDictBase(Mapping, metaclass=abc.ABCMeta): _lazy = False _inplace_set = False is_meta = False + _sources: Tuple[TensorDictBase, List[INDEX_TYPING]] = None def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -1976,6 +1977,11 @@ def __getitem__(self, idx: INDEX_TYPING) -> TensorDictBase: def __setitem__( self, index: INDEX_TYPING, value: Union[TensorDictBase, dict] ) -> None: + if self._sources: + source, idxx = self._sources + for idx in idxx: + source = source.get_sub_tensordict(idx) + source.__setitem__(index, value) if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) if isinstance(index, list): @@ -2440,6 +2446,10 @@ def _index_tensordict(self, idx: INDEX_TYPING): self_copy._dict_meta = KeyDependentDefaultDict(self_copy._make_meta) self_copy._batch_size = _getitem_batch_size(self_copy.batch_size, idx) self_copy._device = self.device + if not self_copy._sources: + self_copy._sources = (self, [idx]) + else: + self_copy._sources = self_copy._sources[0], self_copy._sources[1] + [idx] return self_copy def pin_memory(self) -> TensorDictBase: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7ef724f39..bd881ccf5 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3436,6 +3436,16 @@ def test_shared_inheritance(): assert td0.is_shared() +def test_in_place_set_to_masked_tensordict(): + # Related to https://github.com/pytorch/rl/issues/298 + td = TensorDict({"a": torch.randn(3, 4, 2), "b": torch.randn(3, 4)}, [3, 4]) + + mask = torch.tensor([True, False, True]) + x = torch.randn(2, 4, 2) + td[mask]["a"] = x + torch.testing.assert_allclose(td[mask]["a"], x) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)