diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 8f266a119..3990af874 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -10,7 +10,11 @@
 
 from tensordict import PersistentTensorDict, tensorclass, TensorDict
 from tensordict.nn.params import TensorDictParams
-from tensordict.tensordict import _stack as stack_td
+from tensordict.tensordict import (
+    _stack as stack_td,
+    is_tensor_collection,
+    LazyStackedTensorDict,
+)
 
 
 def prod(sequence):
@@ -234,3 +238,15 @@ def expand_list(list_of_tensors, *dims):
     td = TensorDict({str(i): tensor for i, tensor in enumerate(list_of_tensors)}, [])
     td = td.expand(*dims).contiguous()
     return [td[str(i)] for i in range(n)]
+
+
+def decompose(td):
+    if isinstance(td, LazyStackedTensorDict):
+        for inner_td in td.tensordicts:
+            yield from decompose(inner_td)
+    else:
+        for v in td.values():
+            if is_tensor_collection(v):
+                yield from decompose(v)
+            else:
+                yield v
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 3424a280b..626ad5692 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -11,6 +11,7 @@
 import pytest
 import torch
 
+
 try:
     import torchsnapshot
 
@@ -27,7 +28,7 @@
 except ImportError:
     _has_h5py = False
 
-from _utils_internal import get_available_devices, prod, TestTensorDictsBase
+from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase
 
 from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
 from tensordict.tensordict import (
@@ -295,7 +296,10 @@ def test_cat_td(device):
         "key3": {"key4": torch.zeros(4, 15, 10, device=device)},
     }
     td_out = TensorDict(batch_size=(4, 15), source=d, device=device)
+    data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)}
     torch.cat([td1, td2], 1, out=td_out)
+    data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)}
+    assert data_ptr_set_before == data_ptr_set_after
     assert td_out.batch_size == torch.Size([4, 15])
     assert (td_out["key1"] != 0).all()
     assert (td_out["key2"] != 0).all()
@@ -2016,7 +2020,10 @@ def test_stack_onto(self, td_name, device, tmpdir):
             with pytest.raises(RuntimeError, match="out.batch_size and stacked"):
                 torch.stack([td0, td1], 0, out=td_out)
             return
+        data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)}
         torch.stack([td0, td1], 1, out=td_out)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)}
+        assert data_ptr_set_before == data_ptr_set_after
         assert (td_stack == td_out).all()
 
     @pytest.mark.filterwarnings("error")
@@ -2045,7 +2052,11 @@ def test_stack_tds_on_subclass(self, td_name, device):
             with pytest.raises(RuntimeError, match="arguments don't support automatic"):
                 torch.stack(tds_list, 0, out=td)
             return
+        data_ptr_set_before = {val.data_ptr() for val in decompose(td)}
+
         stacked_td = torch.stack(tds_list, 0, out=td)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(td)}
+        assert data_ptr_set_before == data_ptr_set_after
         assert stacked_td.batch_size == td.batch_size
         assert stacked_td is td
         for key in ("a", "b", "c"):
@@ -2061,7 +2072,10 @@ def test_stack_subclasses_on_td(self, td_name, device):
             with pytest.raises(RuntimeError, match="arguments don't support automatic"):
                 torch.stack(tds_list, 0, out=td)
             return
+        data_ptr_set_before = {val.data_ptr() for val in decompose(td)}
         stacked_td = stack_td(tds_list, 0, out=td)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(td)}
+        assert data_ptr_set_before == data_ptr_set_after
         assert stacked_td.batch_size == td.batch_size
         for key in ("a", "b", "c"):
             assert (stacked_td[key] == td[key]).all()
@@ -4274,17 +4288,6 @@ def nested_lazy_het_td(batch_size):
         obs = obs.expand(batch_size).clone()
         return obs
 
-    def decompose(self, td):
-        if isinstance(td, LazyStackedTensorDict):
-            for inner_td in td.tensordicts:
-                yield from self.decompose(inner_td)
-        else:
-            for v in td.values():
-                if is_tensor_collection(v):
-                    yield from self.decompose(v)
-                else:
-                    yield v
-
     @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
     @pytest.mark.parametrize("cat_dim", [0, 1, 2])
     def test_cat_lazy_stack(self, batch_size, cat_dim):
@@ -4296,9 +4299,9 @@ def test_cat_lazy_stack(self, batch_size, cat_dim):
         assert assert_allclose_td(res, td_lazy)
         assert res is not td_lazy
         td_lazy_clone = td_lazy.clone()
-        data_ptr_set_before = {val.data_ptr() for val in self.decompose(td_lazy)}
+        data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy)}
         res = torch.cat([td_lazy_clone], dim=cat_dim, out=td_lazy)
-        data_ptr_set_after = {val.data_ptr() for val in self.decompose(td_lazy)}
+        data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy)}
         assert data_ptr_set_after == data_ptr_set_before
         assert res is td_lazy
         assert assert_allclose_td(res, td_lazy_clone)
@@ -4326,13 +4329,9 @@ def test_cat_lazy_stack(self, batch_size, cat_dim):
             batch_size = list(batch_size)
             batch_size[cat_dim] *= 2
             td_lazy_dest = self.nested_lazy_het_td(batch_size)["lazy"]
-            data_ptr_set_before = {
-                val.data_ptr() for val in self.decompose(td_lazy_dest)
-            }
+            data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy_dest)}
             res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim, out=td_lazy_dest)
-            data_ptr_set_after = {
-                val.data_ptr() for val in self.decompose(td_lazy_dest)
-            }
+            data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy_dest)}
             assert data_ptr_set_after == data_ptr_set_before
             assert res is td_lazy_dest
             index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
@@ -4370,10 +4369,14 @@ def dense_stack_tds_v1(self, td_list, stack_dim: int) -> TensorDictBase:
     def dense_stack_tds_v2(self, td_list, stack_dim: int) -> TensorDictBase:
         shape = list(td_list[0].shape)
         shape.insert(stack_dim, len(td_list))
-        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
-        return torch.stack(td_list, dim=stack_dim, out=out)
+        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
+        data_ptr_set_before = {val.data_ptr() for val in decompose(out)}
+        res = torch.stack(td_list, dim=stack_dim, out=out)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(out)}
+        assert data_ptr_set_before == data_ptr_set_after
+
+        return res
 
     @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
     @pytest.mark.parametrize("stack_dim", [0, 1, 2])
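Note on the pattern introduced above: every `torch.stack`/`torch.cat` call that writes into an `out=` tensordict is now bracketed by `decompose`, which flattens a (possibly lazily stacked, possibly nested) tensordict into its leaf tensors so the test can assert that the destination's storage is reused rather than silently reallocated. Below is a minimal standalone sketch of that invariant using only public `tensordict` APIs; the keys and shapes are illustrative, not taken from the test suite, and the simplified helper omits the `LazyStackedTensorDict` branch of the real one.

```python
import torch
from tensordict import TensorDict


def decompose(td):
    # Yield every leaf tensor of a nested TensorDict (simplified version
    # of the helper added to test/_utils_internal.py above).
    for v in td.values():
        if isinstance(v, TensorDict):
            yield from decompose(v)
        else:
            yield v


td0 = TensorDict({"a": torch.zeros(3), "b": {"c": torch.zeros(3)}}, [3])
td1 = TensorDict({"a": torch.ones(3), "b": {"c": torch.ones(3)}}, [3])
out = TensorDict({"a": torch.empty(2, 3), "b": {"c": torch.empty(2, 3)}}, [2, 3])

data_ptr_set_before = {t.data_ptr() for t in decompose(out)}
torch.stack([td0, td1], 0, out=out)
data_ptr_set_after = {t.data_ptr() for t in decompose(out)}

# Stacking into `out` must write in place: same storage, new values.
assert data_ptr_set_before == data_ptr_set_after
assert (out[0]["a"] == 0).all() and (out[1]["a"] == 1).all()
```

Comparing `data_ptr()` sets is deliberately stricter than comparing values: it fails whenever the operation swaps in freshly allocated tensors, even if the resulting contents happen to be correct.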