Commit

[Refactor, Tests] Move TestCudagraphs
ghstack-source-id: ee69a57d5783c0e79d84354dacb14824e0e06418
Pull Request resolved: #1007
vmoens committed Sep 23, 2024
1 parent 31785fa commit d9fece7
Showing 2 changed files with 196 additions and 191 deletions.
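
The diff below centers on tensordict.nn.CudaGraphModule. As a quick orientation, here is a minimal sketch of the usage pattern the moved tests repeat, based only on what the diff itself shows (warm-up calls, the `warmup` keyword argument, and a UserWarning fallback when CUDA is unavailable); it is illustrative only and not part of the commit:

    import torch
    from tensordict.nn import CudaGraphModule

    def add_noise(x):
        return x + torch.randn_like(x)

    # Wrap the callable; on a machine without CUDA this emits a UserWarning and
    # runs eagerly (the tests below assert exactly that warning).
    module = CudaGraphModule(add_noise, warmup=4)

    x = torch.randn(10, device="cuda" if torch.cuda.is_available() else "cpu")
    for _ in range(10):
        # early calls warm up / capture the graph; later calls replay it
        y = module(x)
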
198 changes: 194 additions & 4 deletions test/test_compile.py
@@ -5,18 +5,34 @@
 import argparse
 import contextlib
 import importlib.util
+import inspect
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable

 import pytest

 import torch
 from packaging import version

-from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
-from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
-from torch.utils._pytree import tree_map
+from tensordict import (
+    assert_close,
+    PYTREE_REGISTERED_LAZY_TDS,
+    PYTREE_REGISTERED_TDS,
+    tensorclass,
+    TensorDict,
+    TensorDictParams,
+)
+from tensordict.nn import (
+    CudaGraphModule,
+    TensorDictModule,
+    TensorDictModule as Mod,
+    TensorDictSequential as Seq,
+)
+
+from tensordict.nn.functional_modules import _exclude_td_from_pytree
+
+from torch.utils._pytree import SUPPORTED_NODES, tree_map

 TORCH_VERSION = version.parse(torch.__version__).base_version

@@ -871,3 +887,177 @@ def to_numpy(tensor):
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
+
+
+@pytest.mark.skipif(TORCH_VERSION <= "2.4.1", reason="requires torch>=2.5")
+@pytest.mark.parametrize("compiled", [False, True])
+class TestCudaGraphs:
+    @pytest.fixture(scope="class", autouse=True)
+    def _set_cuda_device(self):
+        device = torch.get_default_device()
+        do_unset = False
+        for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS:
+            if tdtype in SUPPORTED_NODES:
+                do_unset = True
+                excluder = _exclude_td_from_pytree()
+                excluder.set()
+                break
+        if torch.cuda.is_available():
+            torch.set_default_device("cuda:0")
+        yield
+        if do_unset:
+            excluder.unset()
+        torch.set_default_device(device)
+
+    def test_cudagraphs_random(self, compiled):
+        def func(x):
+            return x + torch.randn_like(x)
+
+        if compiled:
+            func = torch.compile(func)
+
+        with (
+            pytest.warns(UserWarning)
+            if not torch.cuda.is_available()
+            else contextlib.nullcontext()
+        ):
+            func = CudaGraphModule(func)
+
+        x = torch.randn(10)
+        for _ in range(10):
+            func(x)
+        assert isinstance(func(torch.zeros(10)), torch.Tensor)
+        assert (func(torch.zeros(10)) != 0).any()
+        y0 = func(x)
+        y1 = func(x + 1)
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(y0, y1 + 1)
+
+    @staticmethod
+    def _make_cudagraph(
+        func: Callable, compiled: bool, *args, **kwargs
+    ) -> CudaGraphModule:
+        if compiled:
+            func = torch.compile(func)
+        with (
+            pytest.warns(UserWarning)
+            if not torch.cuda.is_available()
+            else contextlib.nullcontext()
+        ):
+            func = CudaGraphModule(func, *args, **kwargs)
+        return func
+
+    @staticmethod
+    def check_types(func, *args, **kwargs):
+        signature = inspect.signature(func)
+        bound_args = signature.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+        for param_name, param in signature.parameters.items():
+            arg_value = bound_args.arguments[param_name]
+            if param.annotation != param.empty:
+                if not isinstance(arg_value, param.annotation):
+                    raise TypeError(
+                        f"Argument '{param_name}' should be of type {param.annotation}, but is of type {type(arg_value)}"
+                    )
+
+    def test_signature(self, compiled):
+        if compiled:
+            pytest.skip()
+
+        def func(x: torch.Tensor):
+            return x + torch.randn_like(x)
+
+        with pytest.raises(TypeError):
+            self.check_types(func, "a string")
+        self.check_types(func, torch.ones(()))
+
+    def test_backprop(self, compiled):
+        x = torch.nn.Parameter(torch.ones(3))
+        y = torch.nn.Parameter(torch.ones(3))
+        optimizer = torch.optim.SGD([x, y], lr=1)
+
+        def func():
+            optimizer.zero_grad()
+            z = x + y
+            z = z.sum()
+            z.backward()
+            optimizer.step()
+
+        func = self._make_cudagraph(func, compiled, warmup=4)
+
+        for i in range(1, 11):
+            torch.compiler.cudagraph_mark_step_begin()
+            func()
+
+            assert (x == 1 - i).all(), i
+            assert (y == 1 - i).all(), i
+            # assert (x.grad == 1).all()
+            # assert (y.grad == 1).all()
+
+    def test_tdmodule(self, compiled):
+        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
+        tdmodule = self._make_cudagraph(tdmodule, compiled)
+        assert tdmodule._is_tensordict_module
+        for i in range(10):
+            td = TensorDict(x=torch.randn(()))
+            tdmodule(td)
+            assert td["y"] == td["x"] + 1, i
+
+        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
+        tdmodule = self._make_cudagraph(tdmodule, compiled)
+        assert tdmodule._is_tensordict_module
+        for _ in range(10):
+            x = torch.randn(())
+            y = tdmodule(x=x)
+            assert y == x + 1
+
+        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
+        tdmodule = self._make_cudagraph(tdmodule, compiled)
+        assert tdmodule._is_tensordict_module
+        for _ in range(10):
+            td = TensorDict(x=torch.randn(()))
+            tdout = TensorDict()
+            tdmodule(td, tensordict_out=tdout)
+            assert tdout is not td
+            assert "x" not in tdout
+            assert tdout["y"] == td["x"] + 1
+
+        tdmodule = lambda td: td.set("y", td.get("x") + 1)
+        tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[])
+        assert tdmodule._is_tensordict_module
+        for i in range(10):
+            td = TensorDict(x=torch.randn(()))
+            tdmodule(td)
+            assert tdmodule._out_matches_in
+            if i >= tdmodule._warmup and torch.cuda.is_available():
+                assert tdmodule._selected_keys == ["y"]
+            assert td["y"] == td["x"] + 1
+
+        tdmodule = lambda td: td.set("y", td.get("x") + 1)
+        tdmodule = self._make_cudagraph(
+            tdmodule, compiled, in_keys=["x"], out_keys=["y"]
+        )
+        assert tdmodule._is_tensordict_module
+        for _ in range(10):
+            td = TensorDict(x=torch.randn(()))
+            tdmodule(td)
+            assert td["y"] == td["x"] + 1
+
+        tdmodule = lambda td: td.copy().set("y", td.get("x") + 1)
+        tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[])
+        assert tdmodule._is_tensordict_module
+        for _ in range(10):
+            td = TensorDict(x=torch.randn(()))
+            tdout = tdmodule(td)
+            assert tdout is not td
+            assert "y" not in td
+            assert tdout["y"] == td["x"] + 1
+
+    def test_td_input_non_tdmodule(self, compiled):
+        func = lambda x: x + 1
+        func = self._make_cudagraph(func, compiled)
+        for i in range(10):
+            td = TensorDict(a=1)
+            func(td)
+            if i == 5:
+                assert not func._is_tensordict_module

2 comments on commit d9fece7

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark 'CPU Benchmark Results'.
This commit's result is worse than the previous benchmark result by more than the alert threshold of 2.

| Benchmark suite | Current: d9fece7 | Previous: 31785fa | Ratio |
|---|---|---|---|
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last | 76961.40665040359 iter/sec (stddev: 9.413471238572e-7) | 214022.94814806702 iter/sec (stddev: 3.608165284772366e-7) | 2.78 |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last | 76718.07989758631 iter/sec (stddev: 8.094116488218107e-7) | 212346.19259142174 iter/sec (stddev: 4.891365462985409e-7) | 2.77 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
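
For context, the regressed tests exercise membership checks for nested keys of stacked TensorDicts. A rough, hypothetical illustration of that kind of operation (the real setup lives in benchmarks/common/common_ops_test.py and may use different shapes and keys):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": {"b": {"c": torch.zeros(3)}}}, batch_size=[3])
    # torch.stack over TensorDicts builds a lazily stacked TensorDict
    stacked = torch.stack([td.clone() for _ in range(16)])

    # the kind of operation being timed: is a nested key present?
    assert ("a", "b", "c") in stacked.keys(include_nested=True)
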

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark 'GPU Benchmark Results'.
This commit's result is worse than the previous benchmark result by more than the alert threshold of 2.

| Benchmark suite | Current: d9fece7 | Previous: 31785fa | Ratio |
|---|---|---|---|
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last | 125860.81177149409 iter/sec (stddev: 5.667433370796047e-7) | 317999.4159737515 iter/sec (stddev: 3.3553422502450923e-7) | 2.53 |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last | 127885.41763992855 iter/sec (stddev: 5.644753373818858e-7) | 324338.8144811273 iter/sec (stddev: 3.3023242009132323e-7) | 2.54 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
