Skip to content

Commit

Permalink
Changed test case for LinearRegression model; testing the predicted…
Browse files Browse the repository at this point in the history
… output precision upto `1` decimal only

Although scikit-learn api tests for upto 8 decimal points; keeping in mind that
`ai` is not supposed to be a competitor and is only used for learning purpose
we test for upto `1` decimal point.

Signed-off-by: Ayush Joshi <[email protected]>
  • Loading branch information
joshiayush committed Dec 7, 2023
1 parent b65ae16 commit 04824a4
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions tests/test_linear_model/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,29 @@

import numpy as np

from ai.linear_model import LinearRegression
from numpy.testing import (
assert_array_almost_equal,
)

from ai.linear_model import LinearRegression

def test_linear_regression_fit():
lr = LinearRegression()
X = np.array([[1.], [2.], [3.]])
y = np.array([2., 4., 6.])
lr.fit(X, y)
assert lr._is_fitted

def test_linear_regression():
# Test LinearRegression on a simple dataset; a simple dataset
X = np.array([[1], [2]])
y = np.array([1, 2])

def test_linear_regression_predict():
lr = LinearRegression()
X = np.array([[1.], [2.], [3.]])
y = np.array([2., 4., 6.])
lr.fit(X, y)
y_pred = lr.predict(X) # Predicting on the `X` sample vector for now
assert np.allclose(y_pred, y, rtol=1111111e-7)
model = LinearRegression()
model.fit(X, y)

# although scikit-learn api tests for upto 8 decimal points we only test for 1
assert_array_almost_equal(model.predict(X), [1.0, 2.0], decimal=1)

def test_linear_regression_predict_before_fit():
lr = LinearRegression()
X = np.array([[1], [2], [3]])
with pytest.raises(RuntimeError):
lr.predict(X)
# also testing for degenerate input
X = np.array([[1]])
y = [0]

model = LinearRegression()
model.fit(X, y)

def test_linear_regression_fit_shape_mismatch():
lr = LinearRegression()
X = np.array([[1], [2], [3]])
y = np.array([2, 4])
with pytest.raises(ValueError):
lr.fit(X, y)
assert_array_almost_equal(model.predict(X), [0], decimal=1)

0 comments on commit 04824a4

Please sign in to comment.