Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 18, 2024
1 parent 9c7ed32 commit 26affa0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions benchmarks/common/h2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def test_to(
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem, set_on_tensor=True)
td = td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

Expand All @@ -127,12 +127,12 @@ def test_to_njt(
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem, set_on_tensor=True)
njt_td = njt_td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

Expand Down

0 comments on commit 26affa0

Please sign in to comment.