Describe the bug

This is a follow-up to #831, going into more detail about `torch.stack`. `NonTensorData` and `NonTensorStack` cannot be stacked, even though they hold data of compatible shapes. This is especially unpredictable because of the behavior of `NonTensorData._stack_non_tensor`, which may return a `NonTensorData` rather than a `NonTensorStack` when all elements are equal.
To Reproduce
Both `data` and `stack` hold strings with `batch_size` `torch.Size([2])`.
Because `torch.stack` (implemented in `_torch_func`) calls `NonTensorData._stack_non_tensor`, the `data` variable becomes a `NonTensorData` instead of a `NonTensorStack`, since all of its elements are equal. The last line then raises a `NotImplemented` error in `__torch_function__`, because `NonTensorData` fails the check `issubclass(t, (Tensor, cls, TensorDictBase))`:
```
E  TypeError: Multiple dispatch failed for 'torch.stack'; all __torch_function__ handlers returned NotImplemented:
E
E    - tensor subclass <class 'tensordict.tensorclass.NonTensorData'>
E    - tensor subclass <class 'tensordict.tensorclass.NonTensorStack'>
E
E  For more information, try re-running with TORCH_LOGS=not_implemented
```
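The failure mechanism can be sketched without tensordict at all: PyTorch's multiple dispatch asks each distinct `__torch_function__` handler in turn and raises once every handler has returned `NotImplemented`. A self-contained stand-in (all class names here are illustrative, not the real torch/tensordict classes):

```python
class FakeTensor:  # stand-in for torch.Tensor
    pass

class FakeTDBase:  # stand-in for TensorDictBase
    pass

class FakeNonTensorStack(FakeTDBase):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Mirrors the guard described above: decline if any participating
        # type is not a Tensor, this class, or a TensorDictBase.
        if not all(issubclass(t, (FakeTensor, cls, FakeTDBase)) for t in types):
            return NotImplemented
        return "handled"

class FakeNonTensorData:  # a tensorclass, NOT a TensorDictBase subclass
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # For illustration, this handler also declines the mixed case.
        return NotImplemented

def dispatch(func, types):
    """Simplified version of torch's __torch_function__ dispatch loop."""
    for t in types:
        result = t.__torch_function__(func, types)
        if result is not NotImplemented:
            return result
    raise TypeError(f"Multiple dispatch failed for {func!r}; "
                    "all __torch_function__ handlers returned NotImplemented")

# Mixed inputs: every handler declines, so dispatch raises TypeError.
try:
    dispatch("torch.stack", (FakeNonTensorData, FakeNonTensorStack))
except TypeError as exc:
    print(exc)

# Homogeneous inputs pass the issubclass check and are handled.
print(dispatch("torch.stack", (FakeNonTensorStack,)))  # prints "handled"
```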
Expected behavior
`torch.stack` returns a `(2, 2)` `NonTensorStack` in which `data` and `stack` are stacked together.
Reasons
The behavior of `NonTensorData._stack_non_tensor` should be transparent to all other functionality. In particular, `torch.stack` should work with any mix of `NonTensorData` and `NonTensorStack` inputs as long as their batch sizes are compatible.
System info
tensordict: 0.4.0
numpy: 1.26.4
python: 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6]
platform: darwin
torch: 2.3.0
Installed using pip.