diff --git a/ai/linear_model/logistic.py b/ai/linear_model/logistic.py index 7ef555f..9587c77 100644 --- a/ai/linear_model/logistic.py +++ b/ai/linear_model/logistic.py @@ -137,7 +137,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> 'LogisticRegression': """ n_samples, n_features = X.shape self._bias = 0 - self._weights = np.zeros(X.shape[1]) + self._weights = np.zeros(n_features) for _ in range(self._n_iters): y_pred = self._sigmoid(np.dot(X, self._weights) + self._bias) @@ -179,4 +179,4 @@ def predict(self, X: np.ndarray) -> np.ndarray: y_pred = self._sigmoid(np.dot(X, self._weights) + self._bias) c_pred = [0 if y <= .5 else 1 for y in y_pred] - return c_pred \ No newline at end of file + return np.array(c_pred) \ No newline at end of file diff --git a/tests/test_linear_model/test_logistic.py b/tests/test_linear_model/test_logistic.py new file mode 100644 index 0000000..ee3ba94 --- /dev/null +++ b/tests/test_linear_model/test_logistic.py @@ -0,0 +1,74 @@ +# Copyright 2023 The AI Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=too-many-function-args, invalid-name, missing-module-docstring +# pylint: disable=missing-class-docstring + +import pytest +import numpy as np + +from ai.linear_model import LogisticRegression + + +def test_predict_simple(): + model = LogisticRegression() + model._weights = np.array([0.5, 0.5]) + model._bias = 0.1 + + X = np.array([[1, 2]]) + predicted = model.predict(X) + assert isinstance(predicted, np.ndarray) + assert len(predicted) == 1 + assert predicted[0] in [0, 1] + + +def test_predict_multiple_samples(): + model = LogisticRegression() + model._weights = np.array([0.5, 0.5]) + model._bias = 0.1 + + X = np.array([[1, 2], [2, 3], [3, 4]]) + predicted = model.predict(X) + assert isinstance(predicted, np.ndarray) + assert len(predicted) == 3 + assert all(label in [0, 1] for label in predicted) + + +def test_predict_different_weights(): + model = LogisticRegression() + model._weights = np.array([0.2, 0.8]) + model._bias = -0.1 + + X = np.array([[1, 2]]) + predicted = model.predict(X) + assert isinstance(predicted, np.ndarray) + assert len(predicted) == 1 + assert predicted[0] in [0, 1] + + +def test_predict_before_fit(): + model = LogisticRegression() + X = np.array([[1, 2]]) + + with pytest.raises(RuntimeError): + model.predict(X) + + +def test_predict_shape_mismatch(): + model = LogisticRegression() + model._weights = np.array([0.5, 0.5]) + model._bias = 0.1 + + X = np.array([[1, 2, 3]]) + with pytest.raises(ValueError): + model.predict(X)