[Feature] dense_stack_tds #506

Merged · 4 commits · Aug 3, 2023
docs/source/reference/tensordict.rst (1 addition, 0 deletions)

@@ -66,3 +66,4 @@ Utils
     merge_tensordicts
     pad
     pad_sequence
+    dense_stack_tds
tensordict/__init__.py (2 additions, 0 deletions)

@@ -7,6 +7,7 @@
 from tensordict.persistent import PersistentTensorDict
 from tensordict.tensorclass import tensorclass
 from tensordict.tensordict import (
+    dense_stack_tds,
     is_batchedtensor,
     is_memmap,
     is_tensor_collection,
@@ -46,6 +47,7 @@
     "pad",
     "PersistentTensorDict",
     "tensorclass",
+    "dense_stack_tds",
 ]

 # from tensordict._pytree import *
tensordict/tensordict.py (66 additions, 0 deletions)

@@ -8646,6 +8646,72 @@ def make_tensordict(
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)


def dense_stack_tds(
td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
    dim: int | None = None,
) -> TensorDictBase:
"""Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure.

This must be used when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked
can have :class:`tensordict.LazyStackedTensorDict` among entries (or nested entries).
Review comment (Contributor): why "can have"? why not just "have"?

Review comment (Contributor): among their entries

Author reply (Contributor): because this method can be used even when no lazy stacks are involved

In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
Thus, this function provides an alternative for densely stacking the list provided.

Args:
td_list (List of TensorDictBase or LazyStackedTensorDict): the tds to stack.
        dim (int, optional): the dimension along which to stack.
            If ``td_list`` is a ``LazyStackedTensorDict``, it will be retrieved automatically.

Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> a = TensorDict({"a": torch.zeros(3)},[])
>>> b = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
>>> td_lazy = torch.stack([a,b], dim=0)
>>> td_lazy_clone = td_lazy.clone()
>>> td_stack = torch.stack([td_lazy,td_lazy_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=0)
        >>> dense_td_stack = dense_stack_tds(td_stack)  # Automatically use the LazyStackedTensorDict stack_dim
        >>> dense_td_stack
        LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
1 ->
b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=1)
Review comment (Contributor): the appearance of b seems to be the only difference but it's still a lazy key. Why do we need 2 levels of nesting? This example isn't super clear to me

Author reply (Contributor): we need 2 levels of nesting because the function is used to stack tds that contain lazy stacks, so it is only useful when there are at least two levels of nesting. The change of the stack dim from 0 to 1 is the core difference here. "b" is just there to highlight that

# Note that this has pushed the stack_dim (0 -> 1) and revealed the exclusive keys.
>>> assert_allclose_td(dense_td_stack, dense_stack_tds([td_lazy,td_lazy_clone], dim=0))
        # Passing a list or a LazyStackedTensorDict is equivalent

"""
if isinstance(td_list, LazyStackedTensorDict):
dim = td_list.stack_dim
td_list = td_list.tensordicts
elif dim is None:
raise ValueError(
            "If a list of tensordicts is provided, dim must not be None"
)
shape = list(td_list[0].shape)
shape.insert(dim, len(td_list))

out = td_list[0].unsqueeze(dim).expand(shape).clone()
return torch.stack(td_list, dim=dim, out=out)
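The shape bookkeeping in the body above (``shape.insert(dim, len(td_list))``) can be sketched in isolation with plain Python, no tensordict required. ``stacked_shape`` is a hypothetical helper for illustration, not part of the tensordict API:

```python
# Sketch of the batch-size computation dense_stack_tds performs before
# expanding and stacking: inserting the number of stacked items at `dim`.
def stacked_shape(batch_shape, n, dim):
    """Return the batch size obtained by stacking n items of `batch_shape` along `dim`."""
    shape = list(batch_shape)
    shape.insert(dim, n)  # mirrors `shape.insert(dim, len(td_list))` above
    return tuple(shape)

print(stacked_shape((5, 6), 2, 0))  # (2, 5, 6)
print(stacked_shape((5, 6), 2, 1))  # (5, 2, 6)
```

This is why the function pre-allocates `out` as `td_list[0].unsqueeze(dim).expand(shape).clone()`: the expanded clone already has the target batch size, so `torch.stack` can write into it densely.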


def _set_max_batch_size(source: TensorDictBase, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
tensor_data = list(source.values())
Expand Down
test/test_tensordict.py (41 additions, 0 deletions)

@@ -35,6 +35,7 @@
     _CustomOpTensorDict,
     _stack as stack_td,
     assert_allclose_td,
+    dense_stack_tds,
     is_tensor_collection,
     make_tensordict,
     pad,
@@ -5906,6 +5907,46 @@ def test_empty():
assert len(list(td_empty.get("b").keys())) == 1


@pytest.mark.parametrize(
"stack_dim",
[0, 1, 2, 3],
)
@pytest.mark.parametrize(
"nested_stack_dim",
[0, 1, 2],
)
def test_dense_stack_tds(stack_dim, nested_stack_dim):
Review comment (Contributor): I don't see what we're testing specific to this method here. All the assertions would hold with a regular tensordict no? Why aren't we testing anything specific to the keys we provide?

Author reply (matteobettini, Aug 3, 2023): the idea is that this function should behave exactly the same with regular tds or lazy stacks, so the fact that a regular td passes is ok to me. the main thing we are testing is that, when stacking tds that contain lazy stacks, we can call this function to remove one level of laziness and densify one stack_dim

batch_size = (5, 6)
a = TensorDict(
{"a": torch.zeros(*batch_size, 3)},
batch_size,
)
b = TensorDict(
{"a": torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)},
batch_size,
)
td_lazy = torch.stack([a, b], dim=nested_stack_dim)
td_lazy_clone = td_lazy.clone()
td_lazy_clone.apply_(lambda x: x + 1)

assert td_lazy.stack_dim == nested_stack_dim
td_stack = torch.stack([td_lazy, td_lazy_clone], dim=stack_dim)
assert td_stack.stack_dim == stack_dim

dense_td_stack = dense_stack_tds(td_stack)
assert assert_allclose_td(
dense_td_stack, dense_stack_tds([td_lazy, td_lazy_clone], dim=stack_dim)
)
for i in [0, 1]:
index = (slice(None),) * stack_dim + (i,)
assert (dense_td_stack[index] == i).all()

if stack_dim > nested_stack_dim:
assert dense_td_stack.stack_dim == nested_stack_dim
else:
assert dense_td_stack.stack_dim == nested_stack_dim + 1
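The final pair of assertions encodes how the surviving (nested) stack dimension shifts once the outer stack is densified: it moves right by one exactly when the outer stack dim lands at or before it. A minimal sketch of that rule, assuming a hypothetical helper ``residual_stack_dim`` (not a tensordict API):

```python
# Sketch of the stack_dim rule the test asserts: after dense_stack_tds
# removes the outer lazy stack, the nested stack dim is unchanged if the
# outer stack was inserted after it, and shifted right by one otherwise.
def residual_stack_dim(stack_dim, nested_stack_dim):
    if stack_dim > nested_stack_dim:
        return nested_stack_dim
    return nested_stack_dim + 1

print(residual_stack_dim(3, 0))  # 0: outer dim inserted after the nested dim
print(residual_stack_dim(0, 0))  # 1: outer dim pushes the nested dim right
```

This also matches the docstring example, where stacking two lazy stacks (both built at dim 0) yields a result with ``stack_dim=1``.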


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)