From 2ae42b0f699f70ddfdfedae9cc9159ee850e1e8d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 10:21:09 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_tensordict.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 1183bb828..0f1f65b5d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -132,6 +132,17 @@ mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn" +@pytest.fixture +def device_fixture(): + device = torch.get_default_device() + if torch.cuda.is_available(): + torch.set_default_device(torch.device("cuda:0")) + elif torch.backends.mps.is_available(): + torch.set_default_device(torch.device("mps:0")) + yield + torch.set_default_device(device) + + def _compare_tensors_identity(td0, td1): if isinstance(td0, LazyStackedTensorDict): if not isinstance(td1, LazyStackedTensorDict): @@ -243,7 +254,7 @@ def test_batchsize_reset(self): td_u.to_tensordict().batch_size = [1] @pytest.mark.parametrize("count_duplicates", [False, True]) - def test_bytes(self, count_duplicates): + def test_bytes(self, count_duplicates, device_fixture): tensor = torch.zeros(3) tensor_with_grad = torch.ones(3, requires_grad=True) (tensor_with_grad + 1).sum().backward() @@ -267,7 +278,7 @@ def test_bytes(self, count_duplicates): else: assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 0 - def test_depth(ggself): + def test_depth(self): td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_() assert td.depth == 3 with td.unlock_():