Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 25, 2024
1 parent 4358e60 commit f70288a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
7 changes: 3 additions & 4 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
_index_preserve_data_ptr,
_infer_size_impl,
_is_shared,
_is_tensorclass,
_KEY_ERROR,
_LOCK_ERROR,
_NON_STR_KEY_ERR,
Expand Down Expand Up @@ -612,7 +611,7 @@ def _quick_set(swap_dict, swap_td):
return TensorDict._new_unsafe(_swap, batch_size=[])

def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
if is_tensorclass(other):
return other != self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
Expand All @@ -636,7 +635,7 @@ def __ne__(self, other: object) -> T | bool:
return True

def __xor__(self, other: object) -> T | bool:
if _is_tensorclass(other):
if is_tensorclass(other):
return other ^ self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
Expand All @@ -660,7 +659,7 @@ def __xor__(self, other: object) -> T | bool:
return True

def __or__(self, other: object) -> T | bool:
if _is_tensorclass(other):
if is_tensorclass(other):
return other | self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
Expand Down
3 changes: 3 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_is_tensorclass,
_LOCK_ERROR,
_td_fields,
_TENSORCLASS_MEMO,
_unravel_key_to_tuple,
_zip_strict,
DeviceType,
Expand Down Expand Up @@ -499,6 +500,8 @@ def __torch_function__(
_is_non_tensor = getattr(cls, "_is_non_tensor", False)

cls = dataclass(cls, frozen=frozen)
_TENSORCLASS_MEMO[cls] = True

expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__)

for attr in expected_keys:
Expand Down
23 changes: 21 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,8 +854,16 @@ def is_tensorclass(obj: type | Any) -> bool:
return _is_tensorclass(cls)


_TENSORCLASS_MEMO = {}


def _is_tensorclass(cls: type) -> bool:
return getattr(cls, "_is_tensorclass", False)
out = _TENSORCLASS_MEMO.get(cls, None)
if out is None:
out = getattr(cls, "_is_tensorclass", False)
if not is_dynamo_compiling():
_TENSORCLASS_MEMO[cls] = out
return out


class implement_for:
Expand Down Expand Up @@ -2353,8 +2361,19 @@ def is_non_tensor(data):
return getattr(type(data), "_is_non_tensor", False)


_NON_TENSOR_MEMO = {}


def _is_non_tensor(cls: type):
return getattr(cls, "_is_non_tensor", False)
out = None
is_dynamo = is_dynamo_compiling()
if not is_dynamo:
out = _NON_TENSOR_MEMO.get(cls)
if out is None:
out = getattr(cls, "_is_non_tensor", False)
if not is_dynamo:
_NON_TENSOR_MEMO[cls] = out
return out


class KeyDependentDefaultDict(collections.defaultdict):
Expand Down

0 comments on commit f70288a

Please sign in to comment.