[BUG] NonTensorData behavior with equal data is not transparent to the rest of the library #831

Closed · Fixed by #837
jkrude opened this issue Jun 24, 2024 · 2 comments
Labels: bug (Something isn't working)

jkrude commented Jun 24, 2024

Describe the bug

NonTensorData._stack_non_tensor(...) creates a NonTensorData object instead of a NonTensorStack if all elements passed to the method are equal.
There is nothing wrong with the idea in itself; however, even though .tolist() produces the same output in both cases, many parts of the library (and especially torchrl) can't handle the resulting object.
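A minimal sketch of the collapsing behavior (hedged: _stack_non_tensor is a private helper, so its exact signature may differ slightly from what is shown here):

from tensordict import NonTensorData, NonTensorStack

# Equal leaves collapse into a single NonTensorData with batch_size (2,)
equal = NonTensorData._stack_non_tensor([NonTensorData("a"), NonTensorData("a")])
print(type(equal))    # NonTensorData

# Unequal leaves produce a proper NonTensorStack
mixed = NonTensorData._stack_non_tensor([NonTensorData("a"), NonTensorData("b")])
print(type(mixed))    # NonTensorStack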

torch.stack is not defined for NonTensorStack and NonTensorData

import torch
from tensordict import NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(NonTensorData("b"), NonTensorData("b"))
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)

The code will throw an exception:

TypeError: Multiple dispatch failed for 'torch.stack'; all __torch_function__ handlers returned NotImplemented:
       
         - tensor subclass <class 'tensordict.tensorclass.NonTensorData'>
         - tensor subclass <class 'tensordict.tensorclass.NonTensorStack'>
       

Even though both data and stack have the same batch_size and .tolist() returns a list of two elements for each of them, the two cannot be stacked together. In practice this happens with torchrl when the tensordicts of each time step are stacked along the time axis: if my observation is a list of two non-tensor items and, by chance, the two items are equal in any time step, the tensordict stores a NonTensorData object for that step, which triggers the above-mentioned torch.stack exception at the end of the rollout.
(The same problem occurs for torch.cat.)
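The collapse can be reproduced outside of torchrl by stacking per-step tensordicts directly (a hedged sketch; it assumes that torch.stack over tensordicts returns a lazy stack whose get() re-stacks the non-tensor leaves):

import torch
from tensordict import TensorDict, NonTensorData

# Two "time steps" whose non-tensor observations happen to be equal
steps_equal = [TensorDict({"obs": NonTensorData("x")}, []) for _ in range(2)]
# Two time steps with different observations
steps_mixed = [TensorDict({"obs": NonTensorData(s)}, []) for s in ("x", "y")]

print(type(torch.stack(steps_equal).get("obs")))  # NonTensorData (collapsed)
print(type(torch.stack(steps_mixed).get("obs")))  # NonTensorStack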

LazyStackedTensorDict.where fails with equal NonTensorData elements even though the inputs are passed in as NonTensorStacks

import torch
from tensordict import NonTensorData, NonTensorStack

condition = torch.tensor([True, False])
tensor = NonTensorStack(
    NonTensorData(["a"]), NonTensorData(["a"]), batch_size=(2,)
)
other = NonTensorStack(
    NonTensorData(["a"]), NonTensorData(["a"]), batch_size=(2,)
)
out = NonTensorStack(NonTensorData(["a"]), NonTensorData(["a"]), batch_size=(2,))
result = tensor.where(condition=condition, other=other, out=out, pad=0)

The code will produce an exception:

../../../venv/lib/python3.10/site-packages/tensordict/_lazy.py:2255: in where
    out.update(result)
../../../venv/lib/python3.10/site-packages/tensordict/tensorclass.py:2430: in update
    return self._update(
../../../venv/lib/python3.10/site-packages/tensordict/tensorclass.py:2470: in _update
    return self.update(
../../../venv/lib/python3.10/site-packages/tensordict/tensorclass.py:2430: in update
    return self._update(
../../../venv/lib/python3.10/site-packages/tensordict/tensorclass.py:2498: in _update
    leaf_dest._update(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = NonTensorData(data=['a'], batch_size=torch.Size([]), device=None)
input_dict_or_td = NonTensorStack(
    ['a'],
    batch_size=torch.Size([1]),
    device=None)
clone = False, inplace = False

    def _update(
        self,
        input_dict_or_td: dict[str, CompatibleType] | T,
        clone: bool = False,
        inplace: bool = False,
        *,
        keys_to_update: Sequence[NestedKey] | None = None,
        break_on_memmap: bool = None,
    ) -> T:
        if isinstance(input_dict_or_td, NonTensorData):
            data = input_dict_or_td.data
            if inplace and self._tensordict._is_shared:
                _update_shared_nontensor(self._non_tensordict["data"], data)
                return self
            elif inplace and self._is_memmap:
                _is_memmaped_from_above = self._is_memmaped_from_above()
                if break_on_memmap is None:
                    global _BREAK_ON_MEMMAP
                    break_on_memmap = _BREAK_ON_MEMMAP
                if _is_memmaped_from_above and break_on_memmap:
                    raise RuntimeError(
                        "Cannot update a leaf NonTensorData from a memmaped parent NonTensorStack. "
                        "To update this leaf node, please update the NonTensorStack with the proper index."
                    )
                share_non_tensor = self._metadata["_share_non_tensor"]
                if share_non_tensor:
                    _update_shared_nontensor(self._non_tensordict["data"], data)
                else:
                    self._non_tensordict["data"] = data
                # Force json update by setting is memmap to False
                if not _is_memmaped_from_above and "memmap_prefix" in self._metadata:
                    self._tensordict._is_memmap = False
                    self._memmap_(
                        prefix=self._metadata["memmap_prefix"],
                        copy_existing=False,
                        executor=None,
                        futures=None,
                        inplace=True,
                        like=False,
                        share_non_tensor=share_non_tensor,
                    )
                return self
            elif not inplace and self.is_locked:
                raise RuntimeError(_LOCK_ERROR)
            if clone:
                data = deepcopy(data)
            self.data = data
        elif isinstance(input_dict_or_td, NonTensorStack):
>           raise ValueError(
                "Cannot update a NonTensorData object with a NonTensorStack. Call `non_tensor_data.maybe_to_stack()` "
                "before calling update()."
E               ValueError: Cannot update a NonTensorData object with a NonTensorStack. Call `non_tensor_data.maybe_to_stack()` before calling update().

../../../venv/lib/python3.10/site-packages/tensordict/tensorclass.py:1980: ValueError

Again, the problem comes down to mixing a NonTensorData and a NonTensorStack that have the same batch_size and contain equal elements.
Note that the inputs to where(...) are all NonTensorStacks. However, within LazyStackedTensorDict.where, after the condition has been applied to the individual tensordicts, the results are reconstructed using maybe_dense_stack, which returns a NonTensorData:

result = LazyStackedTensorDict.maybe_dense_stack(
    [
        td.where(cond, other, pad=pad)
        for td, cond in zip(self.tensordicts, condition)
    ],
    self.stack_dim,
)
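A hedged sketch of that reconstruction step in isolation (the exact arguments maybe_dense_stack accepts may differ; this only shows the collapsing result):

from tensordict import TensorDict, NonTensorData, LazyStackedTensorDict

# The per-slice results of where() all carry the same non-tensor value here
slices = [TensorDict({"obs": NonTensorData(["a"])}, []) for _ in range(2)]
result = LazyStackedTensorDict.maybe_dense_stack(slices, 0)
print(type(result.get("obs")))  # NonTensorData, not NonTensorStack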

Again, I stumbled over this unexpected behavior while using torchrl. In particular, calling environment.rollout(5, break_when_any_done=False), in which at some point one but not all batch entries is done. In EnvBase._update_during_reset, the call to node.where(~reset, other=node_reset, out=node, pad=0) will trigger the above-mentioned exception.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
jkrude added the bug label on Jun 24, 2024

vmoens commented Jun 25, 2024

Thanks for this

Can I ask you to open separate issues for each of these with a reprod example, and keep this issue as a tracker?


jkrude commented Jun 25, 2024

With my limited knowledge of the library, I would guess that the root cause of all mentioned problems is the behavior of NonTensorData._stack_non_tensor for equal elements. The returned NonTensorData is simply not interchangeably usable with a NonTensorStack. This is especially problematic because NonTensorData and NonTensorStack are not of the same type, yet many points in the codebase have type-dependent behavior 🤔.
I will try to open issues for the other mentioned problems throughout the day.
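A small illustration of the type mismatch (maybe_to_stack is the conversion hinted at by the library's own error message; whether it is the intended fix here is an assumption):

import torch
from tensordict import NonTensorData, NonTensorStack

collapsed = torch.stack([NonTensorData("a"), NonTensorData("a")])
print(isinstance(collapsed, NonTensorStack))  # False: type-dependent code paths diverge

# Conversion suggested by the error message; assumed to yield a NonTensorStack
print(type(collapsed.maybe_to_stack()))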
