amend
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini committed Aug 2, 2023
1 parent a7be2f4 commit 7463384
Showing 3 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/tensordict.rst
@@ -66,3 +66,4 @@ Utils
merge_tensordicts
pad
pad_sequence
dense_stack_tds
66 changes: 66 additions & 0 deletions tensordict/tensordict.py
@@ -8646,6 +8646,72 @@ def make_tensordict(
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)


def dense_stack_tds(
td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
dim: int | None = None,
) -> TensorDictBase:
"""Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure.
This must be used when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked
can have :class:`tensordict.LazyStackedTensorDict` among entries (or nested entries).
In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
Thus, this function provides an alternative for densely stacking the list provided.
Args:
td_list (List of TensorDictBase or LazyStackedTensorDict): the tensordicts to stack.
dim (int, optional): the dimension along which to stack them.
If ``td_list`` is a LazyStackedTensorDict, it will be retrieved automatically.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> a = TensorDict({"a": torch.zeros(3)},[])
>>> b = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
>>> td_lazy = torch.stack([a,b], dim=0)
>>> td_lazy_clone = td_lazy.clone()
>>> td_stack = torch.stack([td_lazy,td_lazy_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=0)
>>> dense_td_stack = dense_stack_tds(td_stack)  # Automatically uses the LazyStackedTensorDict stack_dim
>>> dense_td_stack
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
1 ->
b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=1)
# Note that this has pushed the stack_dim (0 -> 1) and revealed the exclusive keys.
>>> assert_allclose_td(dense_td_stack, dense_stack_tds([td_lazy, td_lazy_clone], dim=0))
# This shows that passing a list or a LazyStackedTensorDict is equivalent.
"""
if isinstance(td_list, LazyStackedTensorDict):
dim = td_list.stack_dim
td_list = td_list.tensordicts
elif dim is None:
raise ValueError(
"If a list of tensordicts is provided, stack_dim must not be None"
)
shape = list(td_list[0].shape)
shape.insert(dim, len(td_list))

out = td_list[0].unsqueeze(dim).expand(shape).clone()
return torch.stack(td_list, dim=dim, out=out)


def _set_max_batch_size(source: TensorDictBase, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
tensor_data = list(source.values())
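For readers of this diff, the core of ``dense_stack_tds`` above is a buffer-then-stack pattern: pre-allocate a dense tensordict whose batch size already includes the new stack dimension, then let ``torch.stack`` write the inputs into it via ``out=``. A minimal sketch of that pattern on homogeneous tensordicts (the names, shapes and values below are illustrative only, not part of this commit):

import torch
from tensordict import TensorDict

# Four tensordicts with identical structure and batch size [3].
td_list = [TensorDict({"a": torch.full((3,), float(i))}, [3]) for i in range(4)]
dim = 0

# Target batch size once stacked: insert the stack length at position `dim`.
shape = list(td_list[0].shape)
shape.insert(dim, len(td_list))  # [4, 3]

# Pre-allocate a dense buffer with that shape, then stack into it in place.
out = td_list[0].unsqueeze(dim).expand(shape).clone()
dense = torch.stack(td_list, dim=dim, out=out)
print(dense.batch_size)  # torch.Size([4, 3])
print(dense["a"][2, 0])  # tensor(2.)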
41 changes: 41 additions & 0 deletions test/test_tensordict.py
@@ -35,6 +35,7 @@
_CustomOpTensorDict,
_stack as stack_td,
assert_allclose_td,
dense_stack_tds,
is_tensor_collection,
make_tensordict,
pad,
@@ -5906,6 +5907,46 @@ def test_empty():
assert len(list(td_empty.get("b").keys())) == 1


@pytest.mark.parametrize(
"stack_dim",
[0, 1, 2, 3],
)
@pytest.mark.parametrize(
"nested_stack_dim",
[0, 1, 2],
)
def test_dense_stack_tds(stack_dim, nested_stack_dim):
batch_size = (5, 6)
a = TensorDict(
{"a": torch.zeros(*batch_size, 3)},
batch_size,
)
b = TensorDict(
{"a": torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)},
batch_size,
)
td_lazy = torch.stack([a, b], dim=nested_stack_dim)
td_lazy_clone = td_lazy.clone()
td_lazy_clone.apply_(lambda x: x + 1)

assert td_lazy.stack_dim == nested_stack_dim
td_stack = torch.stack([td_lazy, td_lazy_clone], dim=stack_dim)
assert td_stack.stack_dim == stack_dim

dense_td_stack = dense_stack_tds(td_stack)
assert assert_allclose_td(
dense_td_stack, dense_stack_tds([td_lazy, td_lazy_clone], dim=stack_dim)
)
for i in [0, 1]:
index = (slice(None),) * stack_dim + (i,)
assert (dense_td_stack[index] == i).all()

if stack_dim > nested_stack_dim:
assert dense_td_stack.stack_dim == nested_stack_dim
else:
assert dense_td_stack.stack_dim == nested_stack_dim + 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
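Two details of the new test are worth spelling out. First, `index = (slice(None),) * stack_dim + (i,)` builds an index tuple that selects the i-th slice along `stack_dim` without hard-coding the number of leading dimensions; a small standalone illustration on a plain tensor (the shape below is arbitrary):

import torch

x = torch.arange(24).reshape(2, 3, 4)
stack_dim, i = 1, 2
index = (slice(None),) * stack_dim + (i,)  # (slice(None), 2), i.e. x[:, 2]
assert torch.equal(x[index], x[:, 2])

Second, the final assertions encode the stack_dim push mentioned in the docstring: stacking along a dimension at or before the nested stack_dim shifts the inner lazy dimension one position to the right (nested_stack_dim + 1), while stacking strictly after it leaves it unchanged.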
