Skip to content

Commit

Permalink
add separate unit and regression tests for different sklearn versions
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonBurns committed Jul 1, 2023
1 parent 6e24b1a commit a06c0a5
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 3 deletions.
Binary file not shown.
112 changes: 110 additions & 2 deletions test/regression/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime

import numpy as np
import pkg_resources

from astartes import train_val_test_split
from astartes.samplers import (
Expand All @@ -12,6 +13,17 @@
IMPLEMENTED_INTERPOLATION_SAMPLERS,
)

# True when the installed scikit-learn is version 1.3 or newer; used below to
# pick which kmeans regression test (and reference file) applies.
# Compare (major, minor) together: checking the minor number alone would
# misclassify e.g. 0.24 as "1.3 or newer" and 2.0 as "older than 1.3".
SKLEARN_GEQ_13 = tuple(
    int(part)
    for part in pkg_resources.get_distribution(
        "scikit-learn",
    ).version.split(".")[:2]
) >= (1, 3)


class Test_regression(unittest.TestCase):
"""
Expand All @@ -34,8 +46,20 @@ def setUpClass(self):
self.reference_splits = {
name: os.path.join(self.reference_splits_dir, name + "_reference.pkl")
for name in ALL_SAMPLERS
if name not in ("scaffold",)
if name
not in (
"scaffold",
"kmeans",
)
}
self.reference_splits["kmeans-v1.3"] = os.path.join(
self.reference_splits_dir,
"kmeans-v1.3_reference.pkl",
)
self.reference_splits["kmeans-v1.2.2"] = os.path.join(
self.reference_splits_dir,
"kmeans-v1.2.2_reference.pkl",
)

def test_timebased_regression(self):
"""Regression test TimeBased, which has labels to check as well."""
Expand Down Expand Up @@ -94,7 +118,7 @@ def test_interpolation_regression(self):
def test_extrapolation_regression(self):
"""Regression testing of extrapolative methods relative to static results."""
for sampler_name in IMPLEMENTED_EXTRAPOLATION_SAMPLERS:
if sampler_name in ("scaffold", "time_based"):
if sampler_name in ("scaffold", "time_based", "kmeans"):
continue
(
X_train,
Expand Down Expand Up @@ -133,6 +157,90 @@ def test_extrapolation_regression(self):
"Sampler {:s} failed regression testing.".format(sampler_name),
)

@unittest.skipUnless(
    SKLEARN_GEQ_13,
    "sklearn version less than 1.3 detected",
)
def test_kmeans_regression_sklearn_v13(self):
    """Regression testing of KMeans in sklearn v1.3 or newer.

    Splits ``self.X``/``self.y`` with the ``kmeans`` sampler at a fixed
    random state and compares every returned array, in order, against the
    pickled reference registered under ``"kmeans-v1.3"``.
    """
    split_output = train_val_test_split(
        self.X,
        self.y,
        sampler="kmeans",
        random_state=42,
    )
    with open(self.reference_splits["kmeans-v1.3"], "rb") as f:
        reference_output = pkl.load(f)
    # zip() stops at the shorter sequence, so a missing or extra array would
    # otherwise slip through unnoticed; check the counts explicitly first.
    self.assertEqual(
        len(split_output),
        len(reference_output),
        "Sampler kmeans output count differs from reference.",
    )
    for actual, expected in zip(split_output, reference_output):
        np.testing.assert_array_equal(
            actual,
            expected,
            err_msg="Sampler kmeans failed regression testing.",
        )

@unittest.skipIf(
    SKLEARN_GEQ_13,
    "sklearn version 1.3 or newer detected",
)
def test_kmeans_regression_sklearn_v12(self):
    """Regression testing of KMeans in sklearn v1.2 or earlier.

    Splits ``self.X``/``self.y`` with the ``kmeans`` sampler at a fixed
    random state and compares every returned array, in order, against the
    pickled reference registered under ``"kmeans-v1.2.2"``.
    """
    split_output = train_val_test_split(
        self.X,
        self.y,
        sampler="kmeans",
        random_state=42,
    )
    with open(self.reference_splits["kmeans-v1.2.2"], "rb") as f:
        reference_output = pkl.load(f)
    # zip() stops at the shorter sequence, so a missing or extra array would
    # otherwise slip through unnoticed; check the counts explicitly first.
    self.assertEqual(
        len(split_output),
        len(reference_output),
        "Sampler kmeans output count differs from reference.",
    )
    for actual, expected in zip(split_output, reference_output):
        np.testing.assert_array_equal(
            actual,
            expected,
            err_msg="Sampler kmeans failed regression testing.",
        )


if __name__ == "__main__":
unittest.main()
104 changes: 103 additions & 1 deletion test/unit/samplers/extrapolative/test_kmeans.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import unittest

import numpy as np
import pkg_resources

from astartes import train_test_split
from astartes.samplers import KMeans
from astartes.utils.warnings import ImperfectSplittingWarning

# True when the installed scikit-learn is version 1.3 or newer; used below to
# pick which version-specific kmeans sampling test applies.
# Compare (major, minor) together: checking the minor number alone would
# misclassify e.g. 0.24 as "1.3 or newer" and 2.0 as "older than 1.3".
SKLEARN_GEQ_13 = tuple(
    int(part)
    for part in pkg_resources.get_distribution(
        "scikit-learn",
    ).version.split(".")[:2]
) >= (1, 3)


class Test_kmeans(unittest.TestCase):
"""
Expand Down Expand Up @@ -35,7 +47,11 @@ def setUpClass(self):
]
)

def test_kmeans_sampling(self):
@unittest.skipIf(
SKLEARN_GEQ_13,
"sklearn version 1.3 or newer detected",
)
def test_kmeans_sampling_v12(self):
"""Use kmeans in the train_test_split and verify results."""
with self.assertWarns(ImperfectSplittingWarning):
(
Expand Down Expand Up @@ -117,6 +133,92 @@ def test_kmeans_sampling(self):
"Test clusters incorrect.",
)

@unittest.skipUnless(
    SKLEARN_GEQ_13,
    "sklearn version less than 1.3 detected",
)
def test_kmeans_sampling_v13(self):
    """Use kmeans in the train_test_split and verify results.

    Runs ``train_test_split`` with the ``kmeans`` sampler (two clusters,
    fixed random state), expects an ``ImperfectSplittingWarning``, and then
    checks each returned array against hard-coded expected values.
    """
    with self.assertWarns(ImperfectSplittingWarning):
        (
            X_train,
            X_test,
            y_train,
            y_test,
            labels_train,
            labels_test,
            clusters_train,
            clusters_test,
        ) = train_test_split(
            self.X,
            self.y,
            labels=self.labels,
            test_size=0.75,
            train_size=0.25,
            sampler="kmeans",
            random_state=42,
            hopts={
                "n_clusters": 2,
            },
        )
    # test that the known arrays equal the result from above
    # assert_array_equal raises on mismatch, so the message is passed to it
    # directly as err_msg; wrapping the call in assertIsNone would never get
    # the chance to report it.
    expectations = (
        (
            X_train,
            np.array([[0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 0, 0, 0]]),
            "Train X incorrect.",
        ),
        (
            X_test,
            np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]),
            "Test X incorrect.",
        ),
        (y_train, np.array([1, 2, 3]), "Train y incorrect."),
        (y_test, np.array([4, 5]), "Test y incorrect."),
        (
            labels_train,
            np.array(["one", "two", "three"]),
            "Train labels incorrect.",
        ),
        (
            labels_test,
            np.array(["four", "five"]),
            "Test labels incorrect.",
        ),
        (clusters_train, np.array([0, 0, 0]), "Train clusters incorrect."),
        (clusters_test, np.array([1, 1]), "Test clusters incorrect."),
    )
    for actual, expected, message in expectations:
        np.testing.assert_array_equal(actual, expected, err_msg=message)

def test_kmeans(self):
"""Directly instantiate and test KMeans."""
kmeans_instance = KMeans(
Expand Down

0 comments on commit a06c0a5

Please sign in to comment.