Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 17, 2024
1 parent 0f5ea4c commit 2ae42b0
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@
mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"


@pytest.fixture
def device_fixture():
device = torch.get_default_device()
if torch.cuda.is_available():
torch.set_default_device(torch.device("cuda:0"))
elif torch.backends.mps.is_available():
torch.set_default_device(torch.device("mps:0"))
yield
torch.set_default_device(device)


def _compare_tensors_identity(td0, td1):
if isinstance(td0, LazyStackedTensorDict):
if not isinstance(td1, LazyStackedTensorDict):
Expand Down Expand Up @@ -243,7 +254,7 @@ def test_batchsize_reset(self):
td_u.to_tensordict().batch_size = [1]

@pytest.mark.parametrize("count_duplicates", [False, True])
def test_bytes(self, count_duplicates):
def test_bytes(self, count_duplicates, device_fixture):
tensor = torch.zeros(3)
tensor_with_grad = torch.ones(3, requires_grad=True)
(tensor_with_grad + 1).sum().backward()
Expand All @@ -267,7 +278,7 @@ def test_bytes(self, count_duplicates):
else:
assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 0

def test_depth(ggself):
def test_depth(self):
td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_()
assert td.depth == 3
with td.unlock_():
Expand Down

0 comments on commit 2ae42b0

Please sign in to comment.