[BUG] torch.stack is not implemented for NonTensorStack and NonTensorData #836

Closed · 3 tasks done
jkrude opened this issue Jun 25, 2024 · 1 comment
Labels: bug (Something isn't working)

jkrude commented Jun 25, 2024

Describe the bug

This is a follow-up to #831, going into more detail about torch.stack.
NonTensorData and NonTensorStack cannot be stacked together, even though they hold their data in compatible shapes. This is made especially unpredictable by NonTensorData._stack_non_tensor, which returns a plain NonTensorData rather than a NonTensorStack when all stacked elements are equal.
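
A minimal illustration of that collapse (a sketch based on the description above; the return types are what the report describes, not verified here):

import torch
from tensordict import NonTensorData, NonTensorStack

# Equal elements: NonTensorData._stack_non_tensor collapses the result
# into a single NonTensorData with batch_size torch.Size([2]).
same = torch.stack([NonTensorData("a"), NonTensorData("a")])
print(type(same))  # expected: NonTensorData

# Differing elements: the result stays a lazy NonTensorStack.
different = torch.stack([NonTensorData("a"), NonTensorData("b")])
print(isinstance(different, NonTensorStack))  # expected: True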

To Reproduce

Both data and stack hold strings with batch_size torch.Size([2]).
Because torch.stack (implemented in __torch_function__) calls NonTensorData._stack_non_tensor, the data variable becomes a NonTensorData instead of a NonTensorStack, since all of its elements are equal.

import torch
from tensordict import NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(
    *(NonTensorData("b"), NonTensorData("b")),
)
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)

The last line raises a TypeError from __torch_function__ because NonTensorData fails the check issubclass(t, (Tensor, cls, TensorDictBase)):

E       TypeError: Multiple dispatch failed for 'torch.stack'; all __torch_function__ handlers returned NotImplemented:
E       
E         - tensor subclass <class 'tensordict.tensorclass.NonTensorData'>
E         - tensor subclass <class 'tensordict.tensorclass.NonTensorStack'>
E       
E       For more information, try re-running with TORCH_LOGS=not_implemented
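
A simplified sketch of the failing type check (the issubclass tuple comes from the description above; running it standalone like this is an assumption, not part of the original report):

import torch
from tensordict import NonTensorData, NonTensorStack, TensorDictBase

# In the __torch_function__ handler, every argument type t must pass this
# check; here cls stands in for the class owning the handler.
cls = NonTensorStack
for t in (NonTensorData, NonTensorStack):
    print(t.__name__, issubclass(t, (torch.Tensor, cls, TensorDictBase)))
# NonTensorData is a tensorclass rather than a Tensor/TensorDictBase subclass,
# so it presumably fails the check and the handler returns NotImplemented.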

Expected behavior

torch.stack should return a (2, 2) NonTensorStack in which data and stack are stacked together.

Reasons

The behavior of NonTensorData._stack_non_tensor should be transparent to the rest of the library; in particular, torch.stack should work on any mix of NonTensorData and NonTensorStack as long as their batch sizes are compatible. A possible workaround is sketched below.
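
A possible workaround sketch (not from the original report, and untested on 0.4.0): build both operands explicitly as NonTensorStack so that neither side collapses into a NonTensorData before the outer stack:

import torch
from tensordict import NonTensorData, NonTensorStack

# Constructing the stacks directly sidesteps the equal-elements collapse
# performed by NonTensorData._stack_non_tensor for the first operand.
data = NonTensorStack(NonTensorData("a"), NonTensorData("a"))
stack = NonTensorStack(NonTensorData("b"), NonTensorData("b"))
result = torch.stack([data, stack], dim=1)
print(result.batch_size)  # expected: torch.Size([2, 2])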

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0
Installed using pip.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
jkrude added the bug (Something isn't working) label on Jun 25, 2024
vmoens (Contributor) commented Jun 25, 2024

This code runs fine on nightlies

import torch
from tensordict import tensorclass, NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(
    *(NonTensorData("b"), NonTensorData("b")),
)
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)

Feel free to reopen if needed

vmoens closed this as completed on Jun 25, 2024