[BUG] `tensordict.pad_sequence` silently ignores non-tensor attributes in tensorclasses or TensorDicts #783
If you pass a list it will be cast to a numpy ndarray (this is something we can re-consider in the future):

```python
from tensordict import pad_sequence, TensorDict
import torch

d1 = TensorDict({'a': torch.tensor([1, 1]), 'b': 'asd'})
d2 = TensorDict({'a': torch.tensor([2]), 'b': 'efg'})
print(d1['b'])
print(pad_sequence([d1, d2]))
print(pad_sequence([d1, d2])['b'])
```
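For background, `pad_sequence` right-pads every sequence in the batch to the longest length before stacking. A minimal pure-Python sketch of that rule (illustrative names only, not the tensordict implementation):

```python
def pad_to_max(seqs, pad_value=0):
    """Right-pad each sequence to the length of the longest one (illustrative only)."""
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]

print(pad_to_max([[1, 1], [2]]))  # [[1, 1], [2, 0]]
```

The bug discussed in this issue concerns what happens to the entries that are *not* sequences of tensors, such as the string field `'b'` above.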
Tested this and indeed, it works! Thank you for the quick and neat fix 🙂 Would this change make its way into the next nightly release?
Sorry I dropped the ball on this :(
Hi again, I noticed that this doesn't work for tensorclasses. MWE:

```python
from tensordict import pad_sequence, tensorclass
import torch

@tensorclass
class Sample:
    a: torch.Tensor
    b: str

d1 = Sample(**{'a': torch.tensor([1, 1]), 'b': 'asd'}, batch_size=[])
d2 = Sample(**{'a': torch.tensor([2]), 'b': 'efg'}, batch_size=[])
print(pad_sequence([d1, d2])[1].b)  # gives you 'asd' and not 'efg'
```
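Until this is fixed upstream, one workaround is to collate the non-tensor fields manually alongside the padded fields. A hedged pure-Python sketch of that idea (`collate_with_metadata` is a hypothetical helper, not tensordict API; plain lists stand in for tensors):

```python
def collate_with_metadata(items, tensor_keys, meta_keys, pad_value=0):
    """Pad the sequence-valued fields; keep one metadata entry per item.

    Illustrative workaround sketch, not tensordict API: `items` are plain
    dicts, and sequence fields are Python lists standing in for tensors.
    """
    batch = {}
    for key in tensor_keys:
        seqs = [item[key] for item in items]
        max_len = max(len(s) for s in seqs)
        batch[key] = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    for key in meta_keys:
        # Metadata is preserved per item instead of silently discarded.
        batch[key] = [item[key] for item in items]
    return batch

batch = collate_with_metadata(
    [{'a': [1, 1], 'b': 'asd'}, {'a': [2], 'b': 'efg'}],
    tensor_keys=['a'],
    meta_keys=['b'],
)
print(batch['b'][1])  # 'efg', not 'asd'
```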
Part of the issue is that non-tensor attributes are wrapped in `NonTensorData` when a tensorclass is converted to a `TensorDict`. With this `example.py`:

```python
# example.py
import torch
import tensordict
from tensordict import tensorclass

@tensorclass
class Sample:
    a: torch.Tensor
    b: str
```

```
>>> from example import *
>>> x = Sample(**{'a': torch.randn(10), 'b': 'test'}, batch_size=[])
>>> x.get('b')
'test'
>>> x_td = x.to_tensordict()
>>> x_td.get('b')
NonTensorData(data=test, batch_size=torch.Size([]), device=None)
```

The relevant code is here: tensordict/tensordict/functional.py, lines 166 to 169 in df9c196, and here: tensordict/tensordict/functional.py, lines 208 to 211 in df9c196.
One option is to change the first of those; another option is to change the second. With such a change:

```
>>> from example import *
>>> x = Sample(**{'a': torch.randn(10), 'b': 'test'}, batch_size=[])
>>> y = Sample(**{'a': torch.randn(8), 'b': 'another test'}, batch_size=[])
>>> p = tensordict.pad_sequence([x, y])
>>> p[0].b
'test'
>>> p[1].b
'another test'
```

But then I noticed another problem with `return_mask=True` (tensordict/tensordict/functional.py, lines 218 to 219 in df9c196):
```
>>> from example import *
>>> x = Sample(**{'a': torch.randn(10), 'b': 'test'}, batch_size=[])
>>> y = Sample(**{'a': torch.randn(8), 'b': 'another test'}, batch_size=[])
>>> p = tensordict.pad_sequence([x, y], return_mask=True)
Traceback (most recent call last):
  File "/home/endoplasm/develop/tensordict-0/tensordict/functional.py", line 220, in pad_sequence
    out.set(
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 1438, in _set
    return self.set(key[0], getattr(self, key[0]).set(key[1:], value))
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 906, in _getattr
    raise AttributeError(item)
AttributeError: masks
```

I considered just using the same trick as before, and the call then succeeds:

```
>>> p = tensordict.pad_sequence([x, y], return_mask=True)
>>> p
Sample(
    a=Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.float32, is_shared=False),
    b='test',
    masks=TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> p[0]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 1126, in _getitem
    return _from_tensordict_with_copy(self, tensor_res)  # device=res.device)
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 538, in _from_tensordict_with_copy
    return tc._from_tensordict(
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 722, in wrapper
    raise ValueError(
ValueError: Keys from the tensordict ({'a', 'b', 'masks'}) must correspond to the class attributes ({'a', 'b'}).
```

So I have a few questions, @vmoens:
I opted to just patch …
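For reference, the mask that `return_mask=True` is meant to produce marks which positions hold real data and which hold padding. A minimal pure-Python sketch of that idea (illustrative only, not the tensordict code):

```python
def padding_masks(lengths, max_len=None):
    """Boolean masks per sequence: True for real entries, False for padding."""
    if max_len is None:
        max_len = max(lengths)
    return [[i < n for i in range(max_len)] for n in lengths]

print(padding_masks([2, 1]))  # [[True, True], [True, False]]
```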
Describe the bug

I have some tensorclasses that store an audio file, with some metadata including speaker id and utterance id. I would like to collate these tensorclasses to form a batch; however, when I do so, the metadata is lost (the metadata from the first tensordict is kept for every item in the batch) and the user is not warned about this either.

To Reproduce

Steps to reproduce the behavior.

Expected behavior

I should either get a properly joined tensordict, e.g., … or `tensordict.pad_sequence` should warn the user that the metadata is being discarded.

Screenshots
System info
Additional context
Add any other context about the problem here.
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
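One low-cost fix matching the "warn the user" option above would be a pre-collation check on the non-tensor fields. A hedged pure-Python sketch (`check_metadata` is a hypothetical helper, not part of tensordict):

```python
import warnings

def check_metadata(items, meta_keys):
    """Warn if per-item metadata differs and would be lost by naive collation.

    Hypothetical helper sketching the 'warn the user' option from this issue;
    not part of the tensordict API. `items` are plain dicts here.
    """
    for key in meta_keys:
        values = [item[key] for item in items]
        if any(v != values[0] for v in values[1:]):
            warnings.warn(
                f"Non-tensor field {key!r} differs across items and would be "
                f"discarded by padding; only {values[0]!r} would be kept."
            )

# Emits a UserWarning because 'b' differs between the two items.
check_metadata([{'b': 'asd'}, {'b': 'efg'}], meta_keys=['b'])
```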
Checklist