Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 11, 2024
1 parent 8b35ae6 commit 3c22416
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
39 changes: 30 additions & 9 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from torch.utils._pytree import SUPPORTED_NODES, tree_map

try:
from torch.utils._pytree import tree_leaves
from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
except ImportError:
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_unflatten

def tree_leaves(pytree):
"""Torch 2.0 compatible version of tree_leaves."""
Expand Down Expand Up @@ -293,11 +293,12 @@ def check_tensor_id(name, t0, t1):

def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter >= self._warmup:
tree_map(
lambda x, y: x.copy_(y, non_blocking=True),
(self._args, self._kwargs),
(args, kwargs),
)
srcs, dests = [], []
for arg_src, arg_dest in zip(
tree_leaves((args, kwargs)), self._flat_tree
):
self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
torch._foreach_copy_(dests, srcs)
torch.cuda.synchronize()
self.graph.replay()
if self._return_unchanged == "clone":
Expand All @@ -322,8 +323,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
self.counter += self._has_cuda
return out
else:
args, kwargs = self._args, self._kwargs = tree_map(
self._check_device_and_clone, (args, kwargs)
self._flat_tree, self._tree_spec = tree_flatten((args, kwargs))

self._flat_tree = tuple(
self._check_device_and_clone(arg) for arg in self._flat_tree
)
args, kwargs = self._args, self._kwargs = tree_unflatten(
self._flat_tree, self._tree_spec
)

torch.cuda.synchronize()
Expand Down Expand Up @@ -360,6 +366,21 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
_call_func = functools.wraps(self.module)(_call)
self._call_func = _call_func

@staticmethod
def _maybe_copy_onto_(src, dest, srcs, dests):
    """Collect tensor-like (src, dest) pairs for a later batched copy.

    Tensor and tensordict inputs are appended to ``srcs``/``dests`` (to be
    copied in one shot via ``torch._foreach_copy_``). Non-tensor inputs are
    captured by the CUDA graph and therefore must be constant across calls:
    if the new value differs from the recorded one we raise ``ValueError``.

    Args:
        src: the value passed at the current call.
        dest: the value recorded at warmup for the same position.
        srcs: accumulator for tensor-like sources (mutated in place).
        dests: accumulator for tensor-like destinations (mutated in place).

    Raises:
        ValueError: if a non-tensor input changed since the graph was recorded.
        RuntimeError: if the constancy of a non-tensor input cannot be assessed
            (e.g. its ``!=`` result has no unambiguous truth value).
    """
    if isinstance(src, (torch.Tensor, TensorDictBase)):
        # Tensor-like: defer to the batched foreach copy, nothing to validate.
        srcs.append(src)
        dests.append(dest)
        # Early return: falling through to the equality check below would
        # evaluate a boolean tensor in a bool context and crash/misfire.
        return
    try:
        # Only the comparison itself is guarded; the ValueError below must
        # escape to the caller, not be swallowed and rewrapped.
        differs = src != dest
    except Exception as err:
        raise RuntimeError(
            "Couldn't assess input value. Make sure your function only takes tensor inputs or that "
            "the input value can be easily checked and is constant. For a better efficiency, avoid "
            "passing non-tensor inputs to your function."
        ) from err
    if differs:
        raise ValueError("Varying inputs must be torch.Tensor subclasses.")

@classmethod
def _check_device_and_clone(cls, x):
if isinstance(x, torch.Tensor) or is_tensor_collection(x):
Expand Down
13 changes: 13 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled):
if i == 5:
assert not func._is_tensordict_module

def test_td_input_non_tdmodule_nontensor(self, compiled):
    """A plain callable mixing a tensor arg with a constant scalar works,
    but (on CUDA) changing the captured non-tensor input must raise."""

    def add(x, y):
        return x + y

    func = self._make_cudagraph(add, compiled)
    for step in range(10):
        assert func(torch.zeros(()), 1.0) == 1.0
        if step == 5:
            assert not func._is_tensordict_module
    if torch.cuda.is_available():
        # The scalar 1.0 was baked into the graph; a different value is rejected.
        with pytest.raises(
            ValueError, match="Varying inputs must be torch.Tensor subclasses."
        ):
            func(torch.zeros(()), 2.0)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down

0 comments on commit 3c22416

Please sign in to comment.