Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
lennybronner committed Sep 25, 2024
1 parent fd442b3 commit db3c80b
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import os
import sys

import pandas as pd
import numpy as np
import pandas as pd
import pytest

_TEST_FOLDER = os.path.dirname(__file__)
Expand Down Expand Up @@ -42,7 +42,8 @@ def random_data_no_weights(get_fixture):
def random_data_weights(get_fixture):
return get_fixture("random_data_n100_p5_12549_weights.csv")


@pytest.fixture(scope="session")
def rng():
seed = 8232
return np.random.default_rng(seed=seed)
return np.random.default_rng(seed=seed)
6 changes: 3 additions & 3 deletions tests/test_linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from elexsolver.LinearSolver import LinearSolver
from elexsolver.QuantileRegressionSolver import QuantileRegressionSolver


def test_fit():
solver = LinearSolver()
with pytest.raises(NotImplementedError):
Expand All @@ -23,6 +24,5 @@ def test_residuals_without_weights(rng):
reg.fit(x, y, fit_intercept=False)
reg.predict(x)

residuals_train = reg.residuals(x, y, K=None, center=False)
residuals_K = reg.residuals(x, y, K=10, center=False)
import pdb; pdb.set_trace()
reg.residuals(x, y, K=None, center=False)
reg.residuals(x, y, K=10, center=False)
2 changes: 2 additions & 0 deletions tests/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_basic2():
preds = lm.predict(x)
assert all(np.abs(preds - [6.666667, 6.666667, 6.666667, 15]) <= TOL)


def test_cache():
lm = OLSRegressionSolver()
x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]])
Expand All @@ -45,6 +46,7 @@ def test_cache():
assert lm.hat_vals is not None
assert lm.coefficients is not None


######################
# Intermediate tests #
######################
Expand Down
2 changes: 2 additions & 0 deletions tests/test_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_basic_upper():
preds = quantreg.predict(x)
np.testing.assert_array_equal(preds, [[9], [9], [9], [15]])


def test_cache():
quantreg = QuantileRegressionSolver()
tau = 0.9
Expand All @@ -67,6 +68,7 @@ def test_cache():
quantreg.fit(x, y, tau, cache=True)
assert len(quantreg.coefficients) > 0


######################
# Intermediate tests #
######################
Expand Down

0 comments on commit db3c80b

Please sign in to comment.