From e489ad9fb1d843efa7a63bedcdb454b0d08290cc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 14:37:58 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- benchmarks/common/h2d_test.py | 2 +- benchmarks/compile/compile_td_test.py | 2 +- benchmarks/compile/tensordict_nn_test.py | 2 +- test/test_compile.py | 2 +- test/test_distributed.py | 4 +--- test/test_tensordict.py | 2 +- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index db1f188fa..0e20aae75 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -11,7 +11,7 @@ from tensordict import TensorDict -TORCH_VERSION = version.parse(".".join(torch.__version__.split(".")[:3])) +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @pytest.fixture diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 4b1b1475f..3a1ef0ee1 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -10,7 +10,7 @@ from tensordict import LazyStackedTensorDict, tensorclass, TensorDict from torch.utils._pytree import tree_map -TORCH_VERSION = version.parse(".".join(torch.__version__.split(".")[:3])) +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @tensorclass diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py index 1dc348216..7828c29f6 100644 --- a/benchmarks/compile/tensordict_nn_test.py +++ b/benchmarks/compile/tensordict_nn_test.py @@ -15,7 +15,7 @@ from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq -TORCH_VERSION = version.parse(".".join(torch.__version__.split(".")[:3])) +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) sys.setrecursionlimit(10000) diff --git a/test/test_compile.py b/test/test_compile.py index aea9232cf..4b9ebdca8 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -33,7 +33,7 @@ from torch.utils._pytree import SUPPORTED_NODES, tree_map -TORCH_VERSION = version.parse(torch.__version__).base_version +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None diff --git a/test/test_distributed.py b/test/test_distributed.py index 64e8a5616..2a30b1593 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -108,9 +108,7 @@ def test_fsdp_module(self, tmpdir): # not using TorchVersion to make the comparison work with dev -TORCH_VERSION = version.parse( - ".".join(map(str, version.parse(torch.__version__).release)) -) +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @pytest.mark.skipif( diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 099d94b25..b1775a897 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -92,7 +92,7 @@ _has_h5py = True except ImportError: _has_h5py = False -TORCH_VERSION = version.parse(torch.__version__).base_version +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None