From b6c425f033252efbe3df46b073ca1658cfffe482 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Tue, 19 Dec 2023 09:32:47 -0700
Subject: [PATCH] parameterize tests for cuda devices

currently failing a few of them, some are expected failures.
---
 tests/cases/torch_train.py | 77 +++++++++++++++++++++++++++++++++++---
 1 file changed, 72 insertions(+), 5 deletions(-)

diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py
index 2c67fda0..b368f9d6 100644
--- a/tests/cases/torch_train.py
+++ b/tests/cases/torch_train.py
@@ -68,8 +68,20 @@ def forward(self, a, b):
         return d_pred
 
 
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
 @skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
-def test_loss_drops(tmpdir):
+def test_loss_drops(tmpdir, device):
     checkpoint_basename = str(tmpdir / "model")
 
     a_key = ArrayKey("A")
@@ -80,7 +92,7 @@ def test_loss_drops(tmpdir):
 
     model = ExampleLinearModel()
     loss = torch.nn.MSELoss()
-    optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999)
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.999)
 
     source = example_train_source(a_key, b_key, c_key)
     train = Train(
@@ -98,6 +110,7 @@ def test_loss_drops(tmpdir):
         checkpoint_basename=checkpoint_basename,
         save_every=100,
         spawn_subprocess=False,
+        device=device,
     )
     pipeline = source + train
 
@@ -130,8 +143,25 @@ def test_loss_drops(tmpdir):
     assert loss2 < loss1
 
 
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=[
+                pytest.mark.skipif(
+                    not torch.cuda.is_available(), reason="CUDA not available"
+                ),
+                pytest.mark.xfail(
+                    reason="failing to move model to device when using a subprocess"
+                ),
+            ],
+        ),
+    ],
+)
 @skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
-def test_output():
+def test_output(device):
     logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)
 
     a_key = ArrayKey("A")
@@ -153,6 +183,7 @@ def test_output():
             d_pred: ArraySpec(nonspatial=True),
         },
         spawn_subprocess=True,
+        device=device,
     )
     pipeline = source + predict
 
@@ -191,8 +222,25 @@ def forward(self, a):
         return pred
 
 
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=[
+                pytest.mark.skipif(
+                    not torch.cuda.is_available(), reason="CUDA not available"
+                ),
+                pytest.mark.xfail(
+                    reason="failing to move model to device in multiprocessing context"
+                ),
+            ],
+        ),
+    ],
+)
 @skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
-def test_scan():
+def test_scan(device):
     logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)
 
     a_key = ArrayKey("A")
@@ -210,6 +258,7 @@ def test_scan():
         inputs={"a": a_key},
         outputs={0: pred},
         array_specs={pred: ArraySpec()},
+        device=device,
     )
     pipeline = source + predict + Scan(reference_request, num_workers=2)
 
@@ -226,8 +275,25 @@ def test_scan():
     assert pred in batch
 
 
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=[
+                pytest.mark.skipif(
+                    not torch.cuda.is_available(), reason="CUDA not available"
+                ),
+                pytest.mark.xfail(
+                    reason="failing to move model to device in multiprocessing context"
+                ),
+            ],
+        ),
+    ],
+)
 @skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
-def test_precache():
+def test_precache(device):
     logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)
 
     a_key = ArrayKey("A")
@@ -245,6 +311,7 @@ def test_precache():
         inputs={"a": a_key},
         outputs={0: pred},
         array_specs={pred: ArraySpec()},
+        device=device,
     )
     pipeline = source + predict + PreCache(cache_size=3, num_workers=2)
 