Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 24, 2024
1 parent 9b64bcd commit 43bad86
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 85 deletions.
8 changes: 5 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -4210,7 +4210,7 @@ def _iter():
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
if not self.is_leaf(target_class):
continue
yield key
else:
Expand Down Expand Up @@ -4239,10 +4239,12 @@ def _iter_helper(
# For lazy stacks
value = value[0]
cls = type(value)
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
is_tc = _is_tensor_collection(cls)
if self.include_nested and is_tc:
yield from self._iter_helper(value, prefix=full_key)
is_leaf = self.is_leaf(cls)
if not self.leaves_only or is_leaf:
print(key, "is leaf", is_leaf)
yield full_key

def _combine_keys(self, prefix: tuple | None, key: NestedKey) -> tuple:
Expand Down
78 changes: 28 additions & 50 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5570,30 +5570,27 @@ def items(
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf
if sort:
yield from sorted(
self.items(include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:

if is_leaf is None:
is_leaf = _default_is_leaf

def _items():
if include_nested and leaves_only:

if include_nested:
# check the conditions once only
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield k, val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield k, val
if not is_leaf(type(val)):
if _is_tensor_collection(cls):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
Expand All @@ -5611,16 +5608,6 @@ def _items():
for k in self.keys():
yield k, self._get_str(k, NO_DEFAULT)

if sort:
yield from sorted(
_items(),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:
yield from _items()

def non_tensor_items(self, include_nested: bool = False):
"""Returns all non-tensor leaves, maybe recursively."""
return tuple(
Expand Down Expand Up @@ -5657,27 +5644,23 @@ def values(
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf

def _values():
if sort:
for k, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value
else:

if is_leaf is None:
is_leaf = _default_is_leaf

# check the conditions once only
if include_nested and leaves_only:
if include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield val
if not is_leaf(type(val)):
if include_nested and _is_tensor_collection(cls):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
Expand All @@ -5692,11 +5675,6 @@ def _values():
for k in self.keys(sort=sort):
yield self._get_str(k, NO_DEFAULT)

if not sort or not include_nested:
yield from _values()
else:
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value

@cache # noqa: B019
def _values_list(
Expand Down
65 changes: 33 additions & 32 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def __torch_function__(
setattr(cls, method_name, getattr(TensorDict, method_name))
for method_name in _FALLBACK_METHOD_FROM_TD:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name))
setattr(cls, method_name, _wrap_td_method(method_name, force_wrap=True))
for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True))
Expand Down Expand Up @@ -857,15 +857,15 @@ def get_parent_locals(cls, localns=localns):
cls._type_hints = None


def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: D417
"""Tensor class wrapper to instantiate a new tensor class object.
Args:
tensordict (TensorDict): Dictionary of tensor types
non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects
"""
if not isinstance(tensordict, TensorDictBase):
if safe and not isinstance(tensordict, TensorDictBase):
raise RuntimeError(
f"Expected a TensorDictBase instance but got {type(tensordict)}"
)
Expand All @@ -890,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
exp_keys = set(cls.__expected_keys__)
if non_tensordict is not None:
nontensor_keys = set(non_tensordict.keys())
total_keys = tensor_keys.union(nontensor_keys)
else:
nontensor_keys = set()
non_tensordict = {}
total_keys = tensor_keys.union(nontensor_keys)
total_keys = tensor_keys
for key in nontensor_keys:
if key not in tensor_keys:
continue
Expand All @@ -917,11 +918,12 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
# empty tensordict and writing values to it. we can skip this because we already
# have a tensordict to use as the underlying tensordict
tc = cls.__new__(cls)
tc.__dict__["_tensordict"] = tensordict
tc.__dict__["_non_tensordict"] = non_tensordict
tc.__dict__.update(
{"_tensordict": tensordict, "_non_tensordict": non_tensordict}
)
# since we aren't calling the dataclass init method, we need to manually check
# whether a __post_init__ method has been defined and invoke it if so
if hasattr(tc, "__post_init__"):
if hasattr(cls, "__post_init__"):
tc.__post_init__()
return tc
else:
Expand Down Expand Up @@ -1142,7 +1144,28 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
return wrapper


def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False):
def _wrap_td_method(
funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False
):
def deliver_result(self, result, kwargs):
if result is None:
return
if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get(
"out"
) is not result:
if not is_dynamo_compiling():
non_tensordict = super(type(self), self).__getattribute__(
"_non_tensordict"
)
else:
non_tensordict = self._non_tensordict
non_tensordict = dict(non_tensordict)
if copy_non_tensor and non_tensordict:
# use tree_map to copy
non_tensordict = tree_map(lambda x: x, non_tensordict)
return self._from_tensordict(result, non_tensordict, safe=False)
return result

def wrapped_func(self, *args, **kwargs):
if not is_dynamo_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
Expand All @@ -1154,34 +1177,12 @@ def wrapped_func(self, *args, **kwargs):
if no_wrap:
return result

def check_out(kwargs, result):
out = kwargs.get("out")
if out is result:
# No need to transform output
return True
return False

if result is td:
return self

def deliver_result(result):
if isinstance(result, TensorDictBase) and not check_out(kwargs, result):
if not is_dynamo_compiling():
non_tensordict = super(type(self), self).__getattribute__(
"_non_tensordict"
)
else:
non_tensordict = self._non_tensordict
non_tensordict = dict(non_tensordict)
if copy_non_tensor:
# use tree_map to copy
non_tensordict = tree_map(lambda x: x, non_tensordict)
return self._from_tensordict(result, non_tensordict)
return result

if isinstance(result, tuple):
return tuple(deliver_result(r) for r in result)
return deliver_result(result)
return tuple(deliver_result(self, r, kwargs) for r in result)
return deliver_result(self, result, kwargs)

return wrapped_func

Expand Down

0 comments on commit 43bad86

Please sign in to comment.