diff --git a/tests/test_linear_model/test_linear.py b/tests/test_linear_model/test_linear.py index 0f835e5..04345f1 100644 --- a/tests/test_linear_model/test_linear.py +++ b/tests/test_linear_model/test_linear.py @@ -16,36 +16,29 @@ import numpy as np -from ai.linear_model import LinearRegression +from numpy.testing import ( + assert_array_almost_equal, +) +from ai.linear_model import LinearRegression -def test_linear_regression_fit(): - lr = LinearRegression() - X = np.array([[1.], [2.], [3.]]) - y = np.array([2., 4., 6.]) - lr.fit(X, y) - assert lr._is_fitted +def test_linear_regression(): + # Test LinearRegression on a simple dataset; a simple dataset + X = np.array([[1], [2]]) + y = np.array([1, 2]) -def test_linear_regression_predict(): - lr = LinearRegression() - X = np.array([[1.], [2.], [3.]]) - y = np.array([2., 4., 6.]) - lr.fit(X, y) - y_pred = lr.predict(X) # Predicting on the `X` sample vector for now - assert np.allclose(y_pred, y, rtol=1111111e-7) + model = LinearRegression() + model.fit(X, y) + # although scikit-learn api tests for upto 8 decimal points we only test for 1 + assert_array_almost_equal(model.predict(X), [1.0, 2.0], decimal=1) -def test_linear_regression_predict_before_fit(): - lr = LinearRegression() - X = np.array([[1], [2], [3]]) - with pytest.raises(RuntimeError): - lr.predict(X) + # also testing for degenerate input + X = np.array([[1]]) + y = [0] + model = LinearRegression() + model.fit(X, y) -def test_linear_regression_fit_shape_mismatch(): - lr = LinearRegression() - X = np.array([[1], [2], [3]]) - y = np.array([2, 4]) - with pytest.raises(ValueError): - lr.fit(X, y) + assert_array_almost_equal(model.predict(X), [0], decimal=1)