diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index ae6643555..f59b239f4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1145,11 +1145,8 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False + # No need to transform output if True + return kwargs.get("out") is result def deliver_result(self, result, kwargs): if result is None: