Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lennybronner committed Sep 25, 2024
1 parent 835a3f6 commit fd442b3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys

import pandas as pd
import numpy as np
import pytest

_TEST_FOLDER = os.path.dirname(__file__)
Expand Down Expand Up @@ -40,3 +41,8 @@ def random_data_no_weights(get_fixture):
@pytest.fixture(scope="session")
def random_data_weights(get_fixture):
return get_fixture("random_data_n100_p5_12549_weights.csv")

@pytest.fixture(scope="session")
def rng():
seed = 8232
return np.random.default_rng(seed=seed)
20 changes: 19 additions & 1 deletion tests/test_linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,27 @@
import pytest

from elexsolver.LinearSolver import LinearSolver

from elexsolver.QuantileRegressionSolver import QuantileRegressionSolver

def test_fit():
solver = LinearSolver()
with pytest.raises(NotImplementedError):
solver.fit(np.ndarray((5, 3)), np.ndarray((1, 3)))


##################
# Test residuals #
##################
def test_residuals_without_weights(rng):
x = rng.normal(size=(100, 5))
beta = rng.normal(size=(5, 1))
y = x @ beta

# we need an a subclass of LinearSolver to actually run a fit
reg = QuantileRegressionSolver()
reg.fit(x, y, fit_intercept=False)
reg.predict(x)

residuals_train = reg.residuals(x, y, K=None, center=False)
residuals_K = reg.residuals(x, y, K=10, center=False)
import pdb; pdb.set_trace()
14 changes: 14 additions & 0 deletions tests/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def test_basic2():
preds = lm.predict(x)
assert all(np.abs(preds - [6.666667, 6.666667, 6.666667, 15]) <= TOL)

def test_cache():
lm = OLSRegressionSolver()
x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]])
y = np.asarray([3, 8, 9, 15])
lm.fit(x, y, fit_intercept=True, cache=False)

assert lm.normal_eqs is None
assert lm.hat_vals is None

lm.fit(x, y, fit_intercept=True, cache=True)

assert lm.normal_eqs is not None
assert lm.hat_vals is not None
assert lm.coefficients is not None

######################
# Intermediate tests #
Expand Down
11 changes: 11 additions & 0 deletions tests/test_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def test_basic_upper():
preds = quantreg.predict(x)
np.testing.assert_array_equal(preds, [[9], [9], [9], [15]])

def test_cache():
quantreg = QuantileRegressionSolver()
tau = 0.9
x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]])
y = np.asarray([3, 8, 9, 15])
quantreg.fit(x, y, tau, cache=False)

assert quantreg.coefficients == []

quantreg.fit(x, y, tau, cache=True)
assert len(quantreg.coefficients) > 0

######################
# Intermediate tests #
Expand Down

0 comments on commit fd442b3

Please sign in to comment.