
[BUG] tensordict.pad_sequence silently ignores non-tensor attributes in tensorclasses or TensorDicts #783

Closed
egaznep opened this issue May 18, 2024 · 6 comments · Fixed by #784 or #884
Labels: bug (Something isn't working)

egaznep (Contributor) commented May 18, 2024

Describe the bug

I have some tensorclasses that store an audio file along with some metadata, including a speaker id and an utterance id. I would like to collate these tensorclasses to form a batch; however, when I do so, the metadata is lost (the metadata from the first tensordict is kept for every item in the batch), and the user is not warned about this.

To Reproduce

Steps to reproduce the behavior.

import torch
from tensordict import pad_sequence, TensorDict

d1 = TensorDict({'a': torch.tensor([0]), 'b': ['asd']})
d2 = TensorDict({'a': torch.tensor([0]), 'b': ['efg']})

print(pad_sequence([d1, d2]))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        b: NonTensorData(data=asd, batch_size=torch.Size([2]), device=None)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Expected behavior

I should either get a properly joined tensordict, e.g.,

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        b: NonTensorData(data=['asd', 'efg'], batch_size=torch.Size([2]), device=None)},  # a list with the same shape as the batch_size
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

or tensordict.pad_sequence should warn the user that the metadata is being discarded.
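
A minimal sketch of what such a warning might look like (hypothetical; this is not current library behavior, and the message text is illustrative):

import warnings

warnings.warn(
    "pad_sequence: key 'b' holds non-tensor data that cannot be stacked; "
    "keeping the value from the first tensordict and discarding the rest.",
    UserWarning,
)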

System info

tensordict-nightly              2024.5.18


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
egaznep added the bug label on May 18, 2024
vmoens (Contributor) commented May 20, 2024

If you pass a list, it will be cast to a numpy ndarray (something we can reconsider in the future). But if you use a plain string, I think the following code will do what you want (given #784):

from tensordict import pad_sequence, TensorDict
import torch

d1 = TensorDict({'a': torch.tensor([1, 1]), 'b': 'asd'})
d2 = TensorDict({'a': torch.tensor([2]), 'b': 'efg'})

print(d1['b'])
print(pad_sequence([d1, d2]))
print(pad_sequence([d1, d2])['b'])

vmoens linked a pull request on May 20, 2024 that will close this issue
egaznep (Contributor, Author) commented May 20, 2024

> If you pass a list, it will be cast to a numpy ndarray ... But if you use a plain string, I think the following code will do what you want (given #784)

Tested this and indeed, it works! Thank you for the quick and neat fix 🙂 Would this change make its way into the next nightly release?

vmoens (Contributor) commented May 22, 2024

Sorry, I dropped the ball on this :(
The PR is almost ready, but there's a non-trivial issue with Persistent (H5) tensordicts that needs to be solved before merging. I'll do my best to do it today!

egaznep (Contributor, Author) commented May 23, 2024

Hi again,

I noticed that this doesn't work for tensorclasses.

MWE:

import torch
from tensordict import pad_sequence, tensorclass

@tensorclass
class Sample:
    a: torch.Tensor
    b: str

d1 = Sample(**{'a': torch.tensor([1, 1]), 'b': 'asd'}, batch_size=[])
d2 = Sample(**{'a': torch.tensor([2]), 'b': 'efg'}, batch_size=[])
print(pad_sequence([d1, d2])[1].b)  # gives you 'asd' and not 'efg'

vmoens reopened this on May 25, 2024
kurtamohler assigned kurtamohler and unassigned vmoens on Jul 12, 2024
kurtamohler (Collaborator) commented Jul 12, 2024

Part of the issue is that TensorDict.get() and <tensorclass>.get() return different types. Here's an example:

# example.py
import torch
import tensordict
from tensordict import tensorclass

@tensorclass
class Sample:
    a: torch.Tensor
    b: str

>>> from example import *
>>> x = Sample(**{'a': torch.randn(10), 'b': 'test'}, batch_size=[])
>>> x.get('b')
'test'
>>> x_td = x.to_tensordict()
>>> x_td.get('b')
NonTensorData(data=test, batch_size=torch.Size([]), device=None)

For the TensorDict, .get('b') returns a NonTensorData object, but for the Sample, .get('b') returns a str. pad_sequence hands the output of the .get() call over to is_non_tensor() here:

item = td.get(key)
list_of_dicts[i][key] = item
if is_non_tensor(item):
    continue

and here:

item0 = list_of_dicts[0][key]
if is_non_tensor(item0):
    out.set(key, torch.stack([d[key] for d in list_of_dicts]))
    continue

is_non_tensor only returns True if it is given a NonTensorData or any other object that has the attribute _is_non_tensor=True. It returns False if it is given a str.
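
A quick way to see the difference (a minimal sketch, relying on the auto-wrapping of non-tensor values shown in the example above):

from tensordict import TensorDict

td = TensorDict({'b': 'test'}, batch_size=[])
item = td.get('b')                              # NonTensorData, per the example above
print(getattr(item, '_is_non_tensor', False))   # True -> pad_sequence stacks it

plain = 'test'                                  # what <tensorclass>.get() returns
print(getattr(plain, '_is_non_tensor', False))  # False -> the key is silently skipped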

One option is to change <tensorclass>.get()'s return type to match TensorDict.get(). But I'm not sure it's okay to make a change like that; maybe the current behavior of <tensorclass>.get() is intentional.

Another option is to change pad_sequence to check whether each tensordict given to it is a tensorclass, and to call <tensorclass>._tensordict.get() if so. I tried that, and it fixed the problem from the MWE in comment #783 (comment):

>>> from example import *
>>> x = Sample(**{'a':torch.randn(10), 'b':'test'}, batch_size=[])
>>> y = Sample(**{'a':torch.randn(8), 'b':'another test'}, batch_size=[])
>>> p = tensordict.pad_sequence([x, y])
>>> p[0].b
'test'
>>> p[1].b
'another test'
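
A rough sketch of the shape of that patch (illustrative only; is_tensorclass is exported by tensordict, but the actual fix may be structured differently):

from tensordict import is_tensorclass

def _get_item(td, key):
    # Hypothetical helper: for tensorclasses, read through the underlying
    # TensorDict so that non-tensor fields come back as NonTensorData
    # (which is_non_tensor recognizes) instead of plain Python objects.
    if is_tensorclass(td):
        return td._tensordict.get(key)
    return td.get(key)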

But then I noticed another problem. When pad_sequence is called with return_mask=True, a new "masks" entry is added to the output tensordict. If the input is a tensorclass, however, the following .set() call fails because "masks" is not part of Sample's signature:

out.set(
    key,

>>> from example import *
>>> x = Sample(**{'a':torch.randn(10), 'b':'test'}, batch_size=[])
>>> y = Sample(**{'a':torch.randn(8), 'b':'another test'}, batch_size=[])
>>> p = tensordict.pad_sequence([x, y], return_mask=True)
Traceback (most recent call last):
  File "/home/endoplasm/develop/tensordict-0/tensordict/functional.py", line 220, in pad_sequence
    out.set(
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 1438, in _set
    return self.set(key[0], getattr(self, key[0]).set(key[1:], value))
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 906, in _getattr
    raise AttributeError(item)
AttributeError: masks

I considered using the same trick as before and calling out._tensordict.set() in the tensorclass case. That does prevent any errors from being raised within the pad_sequence call, but then accessing one of the items from the output raises an error, again because Sample's signature doesn't have the masks attribute.

>>> p = tensordict.pad_sequence([x, y], return_mask=True)
>>> p
Sample(
    a=Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.float32, is_shared=False),
    b='test',
    masks=TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> p[0]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 1126, in _getitem
    return _from_tensordict_with_copy(self, tensor_res)  # device=res.device)
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 538, in _from_tensordict_with_copy
    return tc._from_tensordict(
  File "/home/endoplasm/develop/tensordict-0/tensordict/tensorclass.py", line 722, in wrapper
    raise ValueError(
ValueError: Keys from the tensordict ({'a', 'b', 'masks'}) must correspond to the class attributes ({'a', 'b'}).

So I have a few questions, @vmoens:

  • Should we update <tensorclass>.get() to return the same types that TensorDict.get() returns? Or would it be better to avoid changing the tensorclass interface and instead make pad_sequence handle tensorclasses and TensorDict slightly differently, as described above?

  • What should we do about the return_mask=True <tensorclass>.set() issue? A few potential options:

    1. Make tensorclass._getitem() just work when it finds attributes that aren't in the class signature. I suppose it would have to set those attributes on the output using output._tensordict.set().
    2. Make <tensorclass>.set() work for keys that are not in the class signature.
    3. Don't support return_mask=True for tensorclass; raise an error (see the sketch after this list).
    4. If return_mask=True, make the return type of pad_sequence always be a TensorDict so that we can include the "masks" attribute.
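
A minimal sketch of option 3, which is the approach ultimately taken (see the next comment); the error type and message here are illustrative:

from tensordict import is_tensorclass

def _check_return_mask(list_of_tensordicts, return_mask):
    # Hypothetical guard: tensorclasses have a fixed signature, so the
    # extra "masks" entry cannot be set on (or read back from) them.
    if return_mask and any(is_tensorclass(td) for td in list_of_tensordicts):
        raise RuntimeError(
            "return_mask=True is not supported when padding tensorclasses, "
            "because the output class has no 'masks' attribute."
        )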

kurtamohler (Collaborator) commented

I opted to just patch pad_sequence, without changing tensorclass._getitem(), and to raise an error if return_mask=True is passed with a list of tensorclasses.
