Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini committed Jul 28, 2023
1 parent d7fce9d commit a25401e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
18 changes: 17 additions & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from tensordict import PersistentTensorDict, tensorclass, TensorDict
from tensordict.nn.params import TensorDictParams
from tensordict.tensordict import _stack as stack_td
from tensordict.tensordict import (
_stack as stack_td,
is_tensor_collection,
LazyStackedTensorDict,
)


def prod(sequence):
Expand Down Expand Up @@ -234,3 +238,15 @@ def expand_list(list_of_tensors, *dims):
td = TensorDict({str(i): tensor for i, tensor in enumerate(list_of_tensors)}, [])
td = td.expand(*dims).contiguous()
return [td[str(i)] for i in range(n)]


def decompose(td):
if isinstance(td, LazyStackedTensorDict):
for inner_td in td.tensordicts:
yield from decompose(inner_td)
else:
for v in td.values():
if is_tensor_collection(v):
yield from decompose(v)
else:
yield v
47 changes: 25 additions & 22 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
import torch


try:
import torchsnapshot

Expand All @@ -27,7 +28,7 @@
except ImportError:
_has_h5py = False

from _utils_internal import get_available_devices, prod, TestTensorDictsBase
from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase

from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
from tensordict.tensordict import (
Expand Down Expand Up @@ -295,7 +296,10 @@ def test_cat_td(device):
"key3": {"key4": torch.zeros(4, 15, 10, device=device)},
}
td_out = TensorDict(batch_size=(4, 15), source=d, device=device)
data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)}
torch.cat([td1, td2], 1, out=td_out)
data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)}
assert data_ptr_set_before == data_ptr_set_after
assert td_out.batch_size == torch.Size([4, 15])
assert (td_out["key1"] != 0).all()
assert (td_out["key2"] != 0).all()
Expand Down Expand Up @@ -2016,7 +2020,10 @@ def test_stack_onto(self, td_name, device, tmpdir):
with pytest.raises(RuntimeError, match="out.batch_size and stacked"):
torch.stack([td0, td1], 0, out=td_out)
return
data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)}
torch.stack([td0, td1], 1, out=td_out)
data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)}
assert data_ptr_set_before == data_ptr_set_after
assert (td_stack == td_out).all()

@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -2045,7 +2052,11 @@ def test_stack_tds_on_subclass(self, td_name, device):
with pytest.raises(RuntimeError, match="arguments don't support automatic"):
torch.stack(tds_list, 0, out=td)
return
data_ptr_set_before = {val.data_ptr() for val in decompose(td)}

stacked_td = torch.stack(tds_list, 0, out=td)
data_ptr_set_after = {val.data_ptr() for val in decompose(td)}
assert data_ptr_set_before == data_ptr_set_after
assert stacked_td.batch_size == td.batch_size
assert stacked_td is td
for key in ("a", "b", "c"):
Expand All @@ -2061,7 +2072,10 @@ def test_stack_subclasses_on_td(self, td_name, device):
with pytest.raises(RuntimeError, match="arguments don't support automatic"):
torch.stack(tds_list, 0, out=td)
return
data_ptr_set_before = {val.data_ptr() for val in decompose(td)}
stacked_td = stack_td(tds_list, 0, out=td)
data_ptr_set_after = {val.data_ptr() for val in decompose(td)}
assert data_ptr_set_before == data_ptr_set_after
assert stacked_td.batch_size == td.batch_size
for key in ("a", "b", "c"):
assert (stacked_td[key] == td[key]).all()
Expand Down Expand Up @@ -4274,17 +4288,6 @@ def nested_lazy_het_td(batch_size):
obs = obs.expand(batch_size).clone()
return obs

def decompose(self, td):
if isinstance(td, LazyStackedTensorDict):
for inner_td in td.tensordicts:
yield from self.decompose(inner_td)
else:
for v in td.values():
if is_tensor_collection(v):
yield from self.decompose(v)
else:
yield v

@pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
@pytest.mark.parametrize("cat_dim", [0, 1, 2])
def test_cat_lazy_stack(self, batch_size, cat_dim):
Expand All @@ -4296,9 +4299,9 @@ def test_cat_lazy_stack(self, batch_size, cat_dim):
assert assert_allclose_td(res, td_lazy)
assert res is not td_lazy
td_lazy_clone = td_lazy.clone()
data_ptr_set_before = {val.data_ptr() for val in self.decompose(td_lazy)}
data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy)}
res = torch.cat([td_lazy_clone], dim=cat_dim, out=td_lazy)
data_ptr_set_after = {val.data_ptr() for val in self.decompose(td_lazy)}
data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy)}
assert data_ptr_set_after == data_ptr_set_before
assert res is td_lazy
assert assert_allclose_td(res, td_lazy_clone)
Expand Down Expand Up @@ -4326,13 +4329,9 @@ def test_cat_lazy_stack(self, batch_size, cat_dim):
batch_size = list(batch_size)
batch_size[cat_dim] *= 2
td_lazy_dest = self.nested_lazy_het_td(batch_size)["lazy"]
data_ptr_set_before = {
val.data_ptr() for val in self.decompose(td_lazy_dest)
}
data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy_dest)}
res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim, out=td_lazy_dest)
data_ptr_set_after = {
val.data_ptr() for val in self.decompose(td_lazy_dest)
}
data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy_dest)}
assert data_ptr_set_after == data_ptr_set_before
assert res is td_lazy_dest
index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
Expand Down Expand Up @@ -4370,10 +4369,14 @@ def dense_stack_tds_v1(self, td_list, stack_dim: int) -> TensorDictBase:
def dense_stack_tds_v2(self, td_list, stack_dim: int) -> TensorDictBase:
shape = list(td_list[0].shape)
shape.insert(stack_dim, len(td_list))

out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()

return torch.stack(td_list, dim=stack_dim, out=out)
data_ptr_set_before = {val.data_ptr() for val in decompose(out)}
res = torch.stack(td_list, dim=stack_dim, out=out)
data_ptr_set_after = {val.data_ptr() for val in decompose(out)}
assert data_ptr_set_before == data_ptr_set_after

return res

@pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
@pytest.mark.parametrize("stack_dim", [0, 1, 2])
Expand Down

0 comments on commit a25401e

Please sign in to comment.