From db3c80bd50616d177eaa15709f0713068d8a72b9 Mon Sep 17 00:00:00 2001 From: lennybronner Date: Tue, 24 Sep 2024 22:23:23 -0400 Subject: [PATCH] linter --- tests/conftest.py | 5 +++-- tests/test_linear_solver.py | 6 +++--- tests/test_ols.py | 2 ++ tests/test_quantile.py | 2 ++ 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0271ca6..7c81115 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,8 @@ import os import sys -import pandas as pd import numpy as np +import pandas as pd import pytest _TEST_FOLDER = os.path.dirname(__file__) @@ -42,7 +42,8 @@ def random_data_no_weights(get_fixture): def random_data_weights(get_fixture): return get_fixture("random_data_n100_p5_12549_weights.csv") + @pytest.fixture(scope="session") def rng(): seed = 8232 - return np.random.default_rng(seed=seed) \ No newline at end of file + return np.random.default_rng(seed=seed) diff --git a/tests/test_linear_solver.py b/tests/test_linear_solver.py index 4cc06d8..f0994a3 100644 --- a/tests/test_linear_solver.py +++ b/tests/test_linear_solver.py @@ -4,6 +4,7 @@ from elexsolver.LinearSolver import LinearSolver from elexsolver.QuantileRegressionSolver import QuantileRegressionSolver + def test_fit(): solver = LinearSolver() with pytest.raises(NotImplementedError): @@ -23,6 +24,5 @@ def test_residuals_without_weights(rng): reg.fit(x, y, fit_intercept=False) reg.predict(x) - residuals_train = reg.residuals(x, y, K=None, center=False) - residuals_K = reg.residuals(x, y, K=10, center=False) - import pdb; pdb.set_trace() \ No newline at end of file + reg.residuals(x, y, K=None, center=False) + reg.residuals(x, y, K=10, center=False) diff --git a/tests/test_ols.py b/tests/test_ols.py index 875037e..66adf62 100644 --- a/tests/test_ols.py +++ b/tests/test_ols.py @@ -30,6 +30,7 @@ def test_basic2(): preds = lm.predict(x) assert all(np.abs(preds - [6.666667, 6.666667, 6.666667, 15]) <= TOL) + def test_cache(): lm = OLSRegressionSolver() x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]]) @@ -45,6 +46,7 @@ def test_cache(): assert lm.hat_vals is not None assert lm.coefficients is not None + ###################### # Intermediate tests # ###################### diff --git a/tests/test_quantile.py b/tests/test_quantile.py index 7738501..c39e805 100644 --- a/tests/test_quantile.py +++ b/tests/test_quantile.py @@ -55,6 +55,7 @@ def test_basic_upper(): preds = quantreg.predict(x) np.testing.assert_array_equal(preds, [[9], [9], [9], [15]]) + def test_cache(): quantreg = QuantileRegressionSolver() tau = 0.9 @@ -67,6 +68,7 @@ def test_cache(): quantreg.fit(x, y, tau, cache=True) assert len(quantreg.coefficients) > 0 + ###################### # Intermediate tests # ######################