[Feature] dense_stack_tds #506
Changes from 2 commits
@@ -66,3 +66,4 @@ Utils
     merge_tensordicts
     pad
     pad_sequence
+    dense_stack_tds
@@ -8646,6 +8646,72 @@ def make_tensordict(
     return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)


+def dense_stack_tds(
+    td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
+    dim: int = None,
+) -> TensorDictBase:
+    """Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure.
+
+    This must be used when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked
+    can have :class:`tensordict.LazyStackedTensorDict` among their entries (or nested entries).
+    In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
+    Thus, this function provides an alternative for densely stacking the provided list.
+
+    Args:
+        td_list (List of TensorDictBase or LazyStackedTensorDict): the tds to stack.
+        dim (int, optional): the dimension along which to stack them.
+            If ``td_list`` is a LazyStackedTensorDict, it will be retrieved automatically.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from tensordict import dense_stack_tds
+        >>> from tensordict.tensordict import assert_allclose_td
+        >>> a = TensorDict({"a": torch.zeros(3)}, [])
+        >>> b = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)}, [])
+        >>> td_lazy = torch.stack([a, b], dim=0)
+        >>> td_lazy_clone = td_lazy.clone()
+        >>> td_stack = torch.stack([td_lazy, td_lazy_clone], dim=0)
+        >>> td_stack
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([2, 2]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+        >>> dense_td_stack = dense_stack_tds(td_stack)  # Automatically uses the LazyStackedTensorDict stack_dim
+        >>> dense_td_stack
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
+            exclusive_fields={
+                1 ->
+                    b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([2, 2]),
+            device=None,
+            is_shared=False,
+            stack_dim=1)
+        >>> # Note that this has pushed the stack_dim (0 -> 1) and revealed the exclusive keys.
+        >>> assert_allclose_td(dense_td_stack, dense_stack_tds([td_lazy, td_lazy_clone], dim=0))
+        >>> # This shows that passing a list or a LazyStackedTensorDict is equivalent.
+    """
+    if isinstance(td_list, LazyStackedTensorDict):
+        dim = td_list.stack_dim
+        td_list = td_list.tensordicts
+    elif dim is None:
+        raise ValueError(
+            "If a list of tensordicts is provided, stack_dim must not be None"
+        )
+    shape = list(td_list[0].shape)
+    shape.insert(dim, len(td_list))
+
+    out = td_list[0].unsqueeze(dim).expand(shape).clone()
+    return torch.stack(td_list, dim=dim, out=out)
+
+
 def _set_max_batch_size(source: TensorDictBase, batch_dims=None):
     """Updates a tensordict with its maximium batch size."""
     tensor_data = list(source.values())

Review comment on the example: the appearance of …
Author reply: we need 2 levels of nesting because the function is used to stack tds that contain lazy stacks, so it is only useful when there are at least two levels of nesting. The change of the stack_dim from 0 to 1 is the core difference here. "b" is just there to highlight that …
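The implementation above hinges on the ``out=`` argument of ``torch.stack``: pre-building ``out`` as the first tensordict unsqueezed and expanded along ``dim`` makes the stack write into that structure instead of wrapping the inputs in yet another lazy stack, which is what removes one level of laziness. A minimal sketch of the same trick outside the function (reusing only calls that appear in this diff; the variable names are illustrative):

    import torch
    from tensordict import TensorDict

    # Two heterogeneous tds: "a" has different trailing shapes, so stacking is lazy.
    a = TensorDict({"a": torch.zeros(3)}, [])
    b = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)}, [])
    td_lazy = torch.stack([a, b], dim=0)  # LazyStackedTensorDict, stack_dim=0
    td_list = [td_lazy, td_lazy.clone()]

    dim = 0
    shape = list(td_list[0].shape)   # [2]
    shape.insert(dim, len(td_list))  # [2, 2]
    # Dense along `dim`, still lazy inside: the inner stack_dim is pushed 0 -> 1.
    out = td_list[0].unsqueeze(dim).expand(shape).clone()
    result = torch.stack(td_list, dim=dim, out=out)
    assert result.stack_dim == 1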
@@ -35,6 +35,7 @@
     _CustomOpTensorDict,
     _stack as stack_td,
     assert_allclose_td,
+    dense_stack_tds,
     is_tensor_collection,
     make_tensordict,
     pad,
@@ -5906,6 +5907,46 @@ def test_empty():
     assert len(list(td_empty.get("b").keys())) == 1


+@pytest.mark.parametrize(
+    "stack_dim",
+    [0, 1, 2, 3],
+)
+@pytest.mark.parametrize(
+    "nested_stack_dim",
+    [0, 1, 2],
+)
+def test_dense_stack_tds(stack_dim, nested_stack_dim):
+    batch_size = (5, 6)
+    a = TensorDict(
+        {"a": torch.zeros(*batch_size, 3)},
+        batch_size,
+    )
+    b = TensorDict(
+        {"a": torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)},
+        batch_size,
+    )
+
+    td_lazy = torch.stack([a, b], dim=nested_stack_dim)
+    td_lazy_clone = td_lazy.clone()
+    td_lazy_clone.apply_(lambda x: x + 1)
+
+    assert td_lazy.stack_dim == nested_stack_dim
+    td_stack = torch.stack([td_lazy, td_lazy_clone], dim=stack_dim)
+    assert td_stack.stack_dim == stack_dim
+
+    dense_td_stack = dense_stack_tds(td_stack)
+    assert assert_allclose_td(
+        dense_td_stack, dense_stack_tds([td_lazy, td_lazy_clone], dim=stack_dim)
+    )
+    for i in [0, 1]:
+        index = (slice(None),) * stack_dim + (i,)
+        assert (dense_td_stack[index] == i).all()
+
+    if stack_dim > nested_stack_dim:
+        assert dense_td_stack.stack_dim == nested_stack_dim
+    else:
+        assert dense_td_stack.stack_dim == nested_stack_dim + 1
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

Review comment on the test: I don't see what we're testing specific to this method here. All the assertions would hold with a regular tensordict, no?
Author reply: the idea is that this function should behave exactly the same with regular tds or lazy stacks, so the fact that a regular td passes is ok to me. The main thing we are testing is that, when stacking tds that contain lazy stacks, we can call this function to remove one level of laziness and densify one stack_dim.
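The final pair of assertions encodes the dim-shift rule this test checks: densifying along ``stack_dim`` inserts a dense dimension, so an inner lazy dim at or after that position is pushed right by one. A standalone illustration with one concrete parametrization (a sketch assuming the same ``tensordict`` imports as the test):

    import torch
    from tensordict import TensorDict, dense_stack_tds

    # Inner lazy stack at dim 2 (heterogeneous trailing shapes keep it lazy).
    a = TensorDict({"a": torch.zeros(5, 6, 3)}, (5, 6))
    b = TensorDict({"a": torch.zeros(5, 6, 4)}, (5, 6))
    inner = torch.stack([a, b], dim=2)                  # nested_stack_dim = 2
    outer = torch.stack([inner, inner.clone()], dim=0)  # stack_dim = 0

    dense = dense_stack_tds(outer)
    # stack_dim (0) <= nested_stack_dim (2): the inner lazy dim shifts to 2 + 1 = 3.
    assert dense.stack_dim == 3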
Review comment on the docstring: why "can have"? why not just "have"?
Author reply: because this method can be used even when no lazy stacks are involved.

Review comment on the docstring: among their entries
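A short sketch of that reply (hypothetical key names, assuming the ``tensordict`` package): with no lazy stacks among the entries, ``dense_stack_tds`` simply performs an ordinary dense stack.

    import torch
    from tensordict import TensorDict, dense_stack_tds

    # No lazy stacks anywhere: the result is a plain dense stack.
    a = TensorDict({"x": torch.zeros(3)}, [])
    b = TensorDict({"x": torch.ones(3)}, [])
    td = dense_stack_tds([a, b], dim=0)  # dim is required when passing a list
    assert td["x"].shape == torch.Size([2, 3])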