Skip to content

Commit

Permalink
Skip test if import missing
Browse files Browse the repository at this point in the history
  • Loading branch information
parsiad committed Aug 23, 2024
1 parent 00e7cf1 commit 8c63920
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tests/test_func.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import numpy as np
import scipy.special
import pytest

import micrograd_pp as mpp


@pytest.mark.skipif(not pytest.importorskip("scipy.special"), reason="Unable to import scipy.special")
def test_softmax() -> None:
import scipy.special

a = np.random.randn(5, 4, 3)
actual = mpp.softmax(mpp.Constant(a), dim=1).value
desired = scipy.special.softmax(a, axis=1)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
import torch

import micrograd_pp as mpp

Expand Down Expand Up @@ -88,7 +87,10 @@ def test_layer_norm() -> None:


@pytest.mark.parametrize("is_causal", (False, True))
@pytest.mark.skipif(not pytest.importorskip("torch"), reason="Unable to import torch")
def test_multihead_attention(is_causal: bool) -> None: # Test against PyTorch implementation
import torch

torch_attn_mask = None
mpp_attn_mask = None
if is_causal:
Expand Down

0 comments on commit 8c63920

Please sign in to comment.