[BUG] tensordict.pad_sequence with a tensorclass returns TensorDict #771

Closed
3 tasks done
egaznep opened this issue May 3, 2024 · 5 comments
Labels: bug (Something isn't working)

Comments


egaznep commented May 3, 2024

Describe the bug

The function tensordict.pad_sequence is great for collating individual datapoints of varying length, and tensorclasses are good for type hints. However, the two do not work well together, because pad_sequence always returns a TensorDict rather than an instance of the given class.

To Reproduce

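import torch
import tensordict


@tensordict.tensorclass
class DataPoint:
    a: torch.Tensor
    b: torch.Tensor


# Two samples whose fields differ in length along the first dimension.
s1 = DataPoint(a=torch.ones(10, 1), b=torch.ones(10, 1), batch_size=[])
s2 = DataPoint(a=torch.ones(20, 1), b=torch.ones(30, 1), batch_size=[])
# On the affected version this returns a TensorDict, not a DataPoint.
out = tensordict.pad_sequence([s1, s2], batch_first=True)
print(out)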

Output: 
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 30, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Expected behavior

Ideally, the return value should be of type DataPoint.

Reason and Possible fixes

In pad_sequence, the return object is constructed by

    if out is None:
        out = TensorDict(
            {}, batch_size=torch.Size(shape), device=device, _run_checks=False
        )

which hard-codes TensorDict as the output type. Either this construction could be modified, or, right before returning the TensorDict, we could use the tensorclass._from_tensordict constructor (in our case, DataPoint._from_tensordict) to restore the original data type.
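
To illustrate the first option (changing how the output object is constructed), here is a toy sketch in plain Python. It is not tensordict's implementation: Record, DataPoint and pad_lists are hypothetical stand-ins, and plain lists play the role of tensors. The only point is that building the result from the type of the inputs, rather than from a hard-coded base class, keeps the caller's type.

```python
class Record:
    """Hypothetical generic container (plays the role of TensorDict)."""

    def __init__(self, **data):
        self.data = dict(data)


class DataPoint(Record):
    """Hypothetical typed subclass (plays the role of a tensorclass)."""


def pad_lists(items, padding_value=0.0):
    """Pad every field to the longest length found across items."""
    first = items[0]
    padded = {}
    for key in first.data:
        longest = max(len(item.data[key]) for item in items)
        padded[key] = [
            list(item.data[key]) + [padding_value] * (longest - len(item.data[key]))
            for item in items
        ]
    # Build the output from the inputs' own class. Writing Record(**padded)
    # here would reproduce the reported behaviour and drop the subclass.
    return type(first)(**padded)


s1 = DataPoint(a=[1.0] * 2, b=[1.0] * 2)
s2 = DataPoint(a=[1.0] * 3, b=[1.0] * 4)
out = pad_lists([s1, s2])
print(type(out).__name__)
```

The second option mentioned above (converting just before returning) would instead leave the construction untouched and wrap the finished result in the input class as a last step.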

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

egaznep added the bug label on May 3, 2024

vmoens commented May 3, 2024

Yeah we should call empty there! Good catch


vmoens commented May 3, 2024

Edit: On nightly I get a DataPoint object. Which version are you using?
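
For anyone comparing versions, a minimal way to print the installed release from Python is sketched below. It uses only the standard library's importlib.metadata; installed_version is a hypothetical helper name, and the sketch assumes the package is installed under the distribution name "tensordict".

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version string, or None when tensordict is not installed.
print(installed_version("tensordict"))
```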


egaznep commented May 3, 2024

> Edit: On nightly I get a Datapoint object, which version are you using?

Oh really? I had tensordict 0.1.2+6de2db2, as available on conda-forge: https://anaconda.org/conda-forge/tensordict

There's 0.4.0 out?!


vmoens commented May 3, 2024

Yeaaah not very good at upgrading conda :/ Let me do that!


egaznep commented May 3, 2024

Thank you! I guess this issue could be closed :)

vmoens closed this as completed on May 3, 2024