Skip to content

Commit

Permalink
[BugFix] Fix _make_dtype_promotion backward compat (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 27, 2024
1 parent d89e5c0 commit 0fa000c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2316,10 +2316,14 @@ def is_namedtuple_class(cls):


def _make_dtype_promotion(func):
dtype = getattr(torch, func.__name__)
dtype = getattr(torch, func.__name__, None)

@wraps(func)
def new_func(self):
if dtype is None:
raise NotImplementedError(
f"Your pytorch version {torch.__version__} does not support {dtype}."
)
return self._fast_apply(lambda x: x.to(dtype), propagate_lock=True)

new_func.__doc__ = rf"""Casts all tensors to ``{str(dtype)}``."""
Expand Down

2 comments on commit 0fa000c

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 0fa000c Previous: d89e5c0 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 73280.5744823974 iter/sec (stddev: 0.0000012987925898236784) 234623.96405536495 iter/sec (stddev: 3.2008448572957416e-7) 3.20
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 73400.21816862827 iter/sec (stddev: 0.000014061281227380462) 233388.7435328727 iter/sec (stddev: 3.031576058476586e-7) 3.18

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 0fa000c Previous: d89e5c0 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 103064.35514792756 iter/sec (stddev: 6.814546398409041e-7) 328572.123838064 iter/sec (stddev: 3.276682220464234e-7) 3.19
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 103021.10225791212 iter/sec (stddev: 6.639470686105523e-7) 331211.2932021009 iter/sec (stddev: 3.1183069881043673e-7) 3.21

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.