[BUG] tensordict.pad_sequence with a tensorclass returns TensorDict #771
Describe the bug
The function `tensordict.pad_sequence` is great for collating individual datapoints of varying length, and `tensorclass`es are good for type hints. However, the two do not work together nicely, because `pad_sequence` always returns a `TensorDict` and never the given class.

To Reproduce
Expected behavior
Ideally, the return value should be of type `DataPoint`.

Reason and Possible fixes
In `pad_sequence`, the line that constructs the return object is the natural place for a fix: either modifying it could solve the problem, or, right before returning the `TensorDict`, we can pass it to the `tensorclass._from_tensordict` constructor (in our case, `DataPoint._from_tensordict`) so that the result keeps the original data type.