Skip to content

Commit

Permalink
Closes #4130: pytest marker for python version (#4132)
Browse files Browse the repository at this point in the history
Adds pytest markers for scipy version and python version.

Closes #4130:  pytest marker for python version
Closes #4131: pytest markers for scipy version

Co-authored-by: Amanda Potts <[email protected]>
  • Loading branch information
ajpotts and ajpotts authored Feb 28, 2025
1 parent 0135d3d commit f27b31b
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 23 deletions.
9 changes: 9 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,12 @@ markers =
skip_if_nl_less_than
skip_if_nl_eq
skip_if_nl_neq
skip_if_python_version_greater_than
skip_if_python_version_less_than
skip_if_python_version_eq
skip_if_python_version_neq
skip_if_scipy_version_greater_than
skip_if_scipy_version_less_than
skip_if_scipy_version_eq
skip_if_scipy_version_neq

84 changes: 69 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import os
import subprocess
import sys
from typing import Iterable

import pytest
import scipy

from arkouda import get_max_array_rank
from arkouda.client import get_array_ranks
Expand Down Expand Up @@ -132,12 +134,14 @@ def manage_connection():
@pytest.fixture(autouse=True)
def skip_by_rank(request):
if request.node.get_closest_marker("skip_if_max_rank_less_than"):
if request.node.get_closest_marker("skip_if_max_rank_less_than").args[0] > pytest.max_rank:
pytest.skip("this test requires server with max_array_dims >= {}".format(pytest.max_rank))
rank_requirement = request.node.get_closest_marker("skip_if_max_rank_less_than").args[0]
if pytest.max_rank < rank_requirement:
pytest.skip("this test requires server with max_array_dims >= {}".format(rank_requirement))

if request.node.get_closest_marker("skip_if_max_rank_greater_than"):
if request.node.get_closest_marker("skip_if_max_rank_greater_than").args[0] < pytest.max_rank:
pytest.skip("this test requires server with max_array_dims =< {}".format(pytest.max_rank))
rank_requirement = request.node.get_closest_marker("skip_if_max_rank_greater_than").args[0]
if pytest.max_rank > rank_requirement:
pytest.skip("this test requires server with max_array_dims <= {}".format(rank_requirement))

if request.node.get_closest_marker("skip_if_rank_not_compiled"):
ranks_requested = request.node.get_closest_marker("skip_if_rank_not_compiled").args[0]
Expand All @@ -149,9 +153,7 @@ def skip_by_rank(request):
for i in ranks_requested:
if isinstance(i, int):
if i not in array_ranks:
pytest.skip(
"this test requires server compiled with rank(s) {}".format(i)
)
pytest.skip("this test requires server compiled with rank(s) {}".format(i))
else:
raise TypeError("skip_if_rank_not_compiled only accepts type int or list of int.")
else:
Expand All @@ -161,17 +163,69 @@ def skip_by_rank(request):
@pytest.fixture(autouse=True)
def skip_by_num_locales(request):
if request.node.get_closest_marker("skip_if_nl_less_than"):
if request.node.get_closest_marker("skip_if_nl_less_than").args[0] > pytest.nl:
pytest.skip("this test requires server with nl >= {}".format(pytest.nl))
nl_requirement = request.node.get_closest_marker("skip_if_nl_less_than").args[0]
if pytest.nl < nl_requirement:
pytest.skip("this test requires server with nl <= {}".format(nl_requirement))

if request.node.get_closest_marker("skip_if_nl_greater_than"):
if request.node.get_closest_marker("skip_if_nl_greater_than").args[0] < pytest.nl:
pytest.skip("this test requires server with nl =< {}".format(pytest.nl))
nl_requirement = request.node.get_closest_marker("skip_if_nl_greater_than").args[0]
if pytest.nl > nl_requirement:
pytest.skip("this test requires server with nl <= {}".format(nl_requirement))

if request.node.get_closest_marker("skip_if_nl_eq"):
if request.node.get_closest_marker("skip_if_nl_eq").args[0] == pytest.nl:
pytest.skip("this test requires server with nl == {}".format(pytest.nl))
nl_requirement = request.node.get_closest_marker("skip_if_nl_eq").args[0]
if nl_requirement == pytest.nl:
pytest.skip("this test requires server with nl != {}".format(nl_requirement))

if request.node.get_closest_marker("skip_if_nl_neq"):
if request.node.get_closest_marker("skip_if_nl_neq").args[0] != pytest.nl:
pytest.skip("this test requires server with nl != {}".format(pytest.nl))
nl_requirement = request.node.get_closest_marker("skip_if_nl_neq").args[0]
if nl_requirement != pytest.nl:
pytest.skip("this test requires server with nl == {}".format(nl_requirement))


@pytest.fixture(autouse=True)
def skip_by_python_version(request):
if request.node.get_closest_marker("skip_if_python_version_less_than"):
python_requirement = request.node.get_closest_marker("skip_if_python_version_less_than").args[0]
if sys.version_info < python_requirement:
pytest.skip("this test requires python version >= {}".format(python_requirement))

if request.node.get_closest_marker("skip_if_python_version_greater_than"):
python_requirement = request.node.get_closest_marker("skip_if_python_version_greater_than").args[
0
]
if sys.version_info > python_requirement:
pytest.skip("this test requires python version =< {}".format(python_requirement))

if request.node.get_closest_marker("skip_if_python_version_eq"):
python_requirement = request.node.get_closest_marker("skip_if_python_version_eq").args[0]
if sys.version_info == python_requirement:
pytest.skip("this test requires python version != {}".format(python_requirement))

if request.node.get_closest_marker("skip_if_python_version_neq"):
python_requirement = request.node.get_closest_marker("skip_if_python_version_neq").args[0]
if sys.version_info != python_requirement:
pytest.skip("this test requires python version == {}".format(python_requirement))


@pytest.fixture(autouse=True)
def skip_by_scipy_version(request):
if request.node.get_closest_marker("skip_if_scipy_version_less_than"):
scipy_requirement = request.node.get_closest_marker("skip_if_scipy_version_less_than").args[0]
if scipy.__version__ < scipy_requirement:
pytest.skip("this test requires scipy version >= {}".format(scipy_requirement))

if request.node.get_closest_marker("skip_if_scipy_version_greater_than"):
scipy_requirement = request.node.get_closest_marker("skip_if_scipy_version_greater_than").args[0]
if scipy.__version__ > scipy_requirement:
pytest.skip("this test requires scipy version =< {}".format(scipy_requirement))

if request.node.get_closest_marker("skip_if_scipy_version_eq"):
scipy_requirement = request.node.get_closest_marker("skip_if_scipy_version_eq").args[0]
if scipy.__version__ == scipy_requirement:
pytest.skip("this test requires scipy version != {}".format(scipy_requirement))

if request.node.get_closest_marker("skip_if_scipy_version_neq"):
scipy_requirement = request.node.get_closest_marker("skip_if_scipy_version_neq").args[0]
if scipy.__version__ != scipy_requirement:
pytest.skip("this test requires scipy version == {}".format(scipy_requirement))
14 changes: 6 additions & 8 deletions tests/scipy/scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@
np.array([10000000, 20000000, 30000000, 40000000, 50000000, 60000000, 70000000]),
np.array([10000000, 20000000, 30000000, 40000001, 50000000, 60000000, 70000000]),
),
(np.array([10000000, 20000000, 30000000, 40000000, 50000000, 60000000, 70000000]), None),
(
np.array([10000000, 20000000, 30000000, 40000000, 50000000, 60000000, 70000000]),
None,
),
(np.array([44, 24, 29, 3]) / 100 * 189, np.array([43, 52, 54, 40])),
]


class TestStats:

@classmethod
def setup_class(cls):
import sys
# skip tests with 3.13+
if sys.version_info >= (3, 13):
pytest.skip("scipy tests do not work yet with Python 3.13+")

@pytest.mark.skip_if_scipy_version_greater_than("1.13.1")
@pytest.mark.parametrize(
"lambda_",
[
Expand All @@ -50,6 +47,7 @@ def test_power_divergence(self, lambda_, ddof, pair):

assert np.allclose(ak_power_div, scipy_power_div, equal_nan=True)

@pytest.mark.skip_if_scipy_version_greater_than("1.13.1")
@pytest.mark.parametrize("ddof", DDOF)
@pytest.mark.parametrize("pair", PAIRS)
def test_chisquare(self, ddof, pair):
Expand Down

0 comments on commit f27b31b

Please sign in to comment.