diff --git a/tensordict/_td.py b/tensordict/_td.py index 4387839b5..255cca40a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -4210,7 +4210,7 @@ def _iter(): if self.leaves_only: for key in self._keys(): target_class = self.tensordict.entry_class(key) - if _is_tensor_collection(target_class): + if not self.is_leaf(target_class): continue yield key else: @@ -4239,10 +4239,12 @@ def _iter_helper( # For lazy stacks value = value[0] cls = type(value) - is_leaf = self.is_leaf(cls) - if self.include_nested and not is_leaf: + is_tc = _is_tensor_collection(cls) + if self.include_nested and is_tc: yield from self._iter_helper(value, prefix=full_key) + is_leaf = self.is_leaf(cls) if not self.leaves_only or is_leaf: + print(key, "is leaf", is_leaf) yield full_key def _combine_keys(self, prefix: tuple | None, key: NestedKey) -> tuple: diff --git a/tensordict/base.py b/tensordict/base.py index 358cae1b1..b65241770 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5570,30 +5570,27 @@ def items( Defaults to ``False``. """ - if is_leaf is None: - is_leaf = _default_is_leaf + if sort: + yield from sorted( + self.items(include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf), + key=lambda item: ( + item[0] if isinstance(item[0], str) else ".".join(item[0]) + ), + ) + else: + + if is_leaf is None: + is_leaf = _default_is_leaf - def _items(): - if include_nested and leaves_only: + + if include_nested: # check the conditions once only for k in self.keys(): val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - else: + cls = type(val) + if not leaves_only or is_leaf(cls): yield k, val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield k, val - if not is_leaf(type(val)): + if _is_tensor_collection(cls): yield from ( (_unravel_key_to_tuple((k, _key)), _val) for _key, _val in val.items( @@ -5611,16 +5608,6 @@ def _items(): for k in self.keys(): yield k, self._get_str(k, NO_DEFAULT) - if sort: - yield from sorted( - _items(), - key=lambda item: ( - item[0] if isinstance(item[0], str) else ".".join(item[0]) - ), - ) - else: - yield from _items() - def non_tensor_items(self, include_nested: bool = False): """Returns all non-tensor leaves, maybe recursively.""" return tuple( @@ -5657,27 +5644,23 @@ def values( Defaults to ``False``. """ - if is_leaf is None: - is_leaf = _default_is_leaf - def _values(): + if sort: + for k, value in self.items(include_nested, leaves_only, is_leaf, sort=sort): + yield value + else: + + if is_leaf is None: + is_leaf = _default_is_leaf + # check the conditions once only - if include_nested and leaves_only: + if include_nested: for k in self.keys(): val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from val.values( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - else: + cls = type(val) + if not leaves_only or is_leaf(cls): yield val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield val - if not is_leaf(type(val)): + if include_nested and _is_tensor_collection(cls): yield from val.values( include_nested=include_nested, leaves_only=leaves_only, @@ -5692,11 +5675,6 @@ def _values(): for k in self.keys(sort=sort): yield self._get_str(k, NO_DEFAULT) - if not sort or not include_nested: - yield from _values() - else: - for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort): - yield value @cache # noqa: B019 def _values_list( diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8906eefd4..0e11a734d 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -569,7 +569,7 @@ def __torch_function__( setattr(cls, method_name, getattr(TensorDict, method_name)) for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): - setattr(cls, method_name, _wrap_td_method(method_name)) + setattr(cls, method_name, _wrap_td_method(method_name, force_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) @@ -857,7 +857,7 @@ def get_parent_locals(cls, localns=localns): cls._type_hints = None -def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 +def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: D417 """Tensor class wrapper to instantiate a new tensor class object. Args: @@ -865,7 +865,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects """ - if not isinstance(tensordict, TensorDictBase): + if safe and not isinstance(tensordict, TensorDictBase): raise RuntimeError( f"Expected a TensorDictBase instance but got {type(tensordict)}" ) @@ -890,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 exp_keys = set(cls.__expected_keys__) if non_tensordict is not None: nontensor_keys = set(non_tensordict.keys()) + total_keys = tensor_keys.union(nontensor_keys) else: nontensor_keys = set() non_tensordict = {} - total_keys = tensor_keys.union(nontensor_keys) + total_keys = tensor_keys for key in nontensor_keys: if key not in tensor_keys: continue @@ -917,11 +918,12 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 # empty tensordict and writing values to it. we can skip this because we already # have a tensordict to use as the underlying tensordict tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - tc.__dict__["_non_tensordict"] = non_tensordict + tc.__dict__.update( + {"_tensordict": tensordict, "_non_tensordict": non_tensordict} + ) # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): + if hasattr(cls, "__post_init__"): tc.__post_init__() return tc else: @@ -1142,7 +1144,28 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): +def _wrap_td_method( + funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False +): + def deliver_result(self, result, kwargs): + if result is None: + return + if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get( + "out" + ) is not result: + if not is_dynamo_compiling(): + non_tensordict = super(type(self), self).__getattribute__( + "_non_tensordict" + ) + else: + non_tensordict = self._non_tensordict + non_tensordict = dict(non_tensordict) + if copy_non_tensor and non_tensordict: + # use tree_map to copy + non_tensordict = tree_map(lambda x: x, non_tensordict) + return self._from_tensordict(result, non_tensordict, safe=False) + return result + def wrapped_func(self, *args, **kwargs): if not is_dynamo_compiling(): td = super(type(self), self).__getattribute__("_tensordict") @@ -1154,34 +1177,12 @@ def wrapped_func(self, *args, **kwargs): if no_wrap: return result - def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False - if result is td: return self - def deliver_result(result): - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if not is_dynamo_compiling(): - non_tensordict = super(type(self), self).__getattribute__( - "_non_tensordict" - ) - else: - non_tensordict = self._non_tensordict - non_tensordict = dict(non_tensordict) - if copy_non_tensor: - # use tree_map to copy - non_tensordict = tree_map(lambda x: x, non_tensordict) - return self._from_tensordict(result, non_tensordict) - return result - if isinstance(result, tuple): - return tuple(deliver_result(r) for r in result) - return deliver_result(result) + return tuple(deliver_result(self, r, kwargs) for r in result) + return deliver_result(self, result, kwargs) return wrapped_func