Skip to content

Commit

Permalink
Merge pull request #668 from danieldk/backport/663
Browse files Browse the repository at this point in the history
Backport #663 to v8.0.x
  • Loading branch information
adrianeboyd authored May 18, 2022
2 parents a9498a2 + 0123816 commit 7db5ed6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion thinc/tests/layers/test_pytorch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from thinc.backends import context_pools
from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
from thinc.util import has_torch, has_torch_amp, has_torch_gpu
from thinc.util import has_cupy
import numpy
import pytest

Expand Down Expand Up @@ -63,7 +64,7 @@ def test_pytorch_wrapper(nN, nI, nO):
assert isinstance(model.predict(X), numpy.ndarray)


@pytest.mark.skipif(not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU")
@pytest.mark.skipif(not has_cupy or not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU")
@pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
@pytest.mark.parametrize("mixed_precision", TORCH_MIXED_PRECISION)
def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
Expand Down

0 comments on commit 7db5ed6

Please sign in to comment.