Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 18, 2024
1 parent e0e1ddc commit 5d26eed
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,6 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
f"Expected a TensorDictBase instance but got {type(tensordict)}"
)
# Validating keys of tensordict
# tensordict = tensordict.copy()
tensor_keys = tensordict.keys()
# TODO: compile doesn't like set() over an arbitrary object
if is_dynamo_compiling():
Expand All @@ -891,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
exp_keys = set(cls.__expected_keys__)
if non_tensordict is not None:
nontensor_keys = set(non_tensordict.keys())
total_keys = tensor_keys.union(nontensor_keys)
else:
nontensor_keys = set()
non_tensordict = {}
total_keys = tensor_keys.union(nontensor_keys)
total_keys = tensor_keys
for key in nontensor_keys:
if key not in tensor_keys:
continue
Expand Down Expand Up @@ -922,7 +922,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417
tc.__dict__["_non_tensordict"] = non_tensordict
# since we aren't calling the dataclass init method, we need to manually check
# whether a __post_init__ method has been defined and invoke it if so
if hasattr(tc, "__post_init__"):
if hasattr(cls, "__post_init__"):
tc.__post_init__()
return tc
else:
Expand Down

0 comments on commit 5d26eed

Please sign in to comment.