diff --git a/mab2rec/config.py b/mab2rec/config.py index a53d19d..9cb951a 100644 --- a/mab2rec/config.py +++ b/mab2rec/config.py @@ -40,9 +40,14 @@ class LinGreedy: The regularization strength. Integer or float. Must be greater than zero. Default value is 1.0. + scale: bool + Whether to scale features to have zero mean and unit variance. + Uses StandardScaler in sklearn.preprocessing. + Default value is False. """ epsilon: float = 0.1 l2_lambda: float = 1.0 + scale: bool = False @spock @@ -59,9 +64,14 @@ class LinTS: The regularization strength. Integer or float. Must be greater than zero. Default value is 1.0. + scale: bool + Whether to scale features to have zero mean and unit variance. + Uses StandardScaler in sklearn.preprocessing. + Default value is False. """ alpha: float = 1.0 l2_lambda: float = 1.0 + scale: bool = False @spock @@ -78,9 +88,14 @@ class LinUCB: The regularization strength. Integer or float. Cannot be negative. Default value is 1.0. + scale: bool + Whether to scale features to have zero mean and unit variance. + Uses StandardScaler in sklearn.preprocessing. + Default value is False. """ alpha: float = 1.0 l2_lambda: float = 1.0 + scale: bool = False @spock @@ -283,11 +298,11 @@ def init_recommender(config): if isinstance(lp_params, EpsilonGreedy): lp = LearningPolicy.EpsilonGreedy(epsilon=lp_params.epsilon) elif isinstance(lp_params, LinGreedy): - lp = LearningPolicy.LinGreedy(epsilon=lp_params.epsilon, l2_lambda=lp_params.l2_lambda) + lp = LearningPolicy.LinGreedy(epsilon=lp_params.epsilon, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale) elif isinstance(lp_params, LinTS): - lp = LearningPolicy.LinTS(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda) + lp = LearningPolicy.LinTS(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale) elif isinstance(lp_params, LinUCB): - lp = LearningPolicy.LinUCB(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda) + lp = LearningPolicy.LinUCB(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale) elif isinstance(lp_params, Popularity): lp = LearningPolicy.Popularity() elif isinstance(lp_params, Random): diff --git a/tests/test_rec.py b/tests/test_rec.py index 8917c6d..34e1e16 100644 --- a/tests/test_rec.py +++ b/tests/test_rec.py @@ -612,8 +612,8 @@ def test_recommend_lin_greedy(self): top_k=2, seed=123456) self.assertEqual(results[0], [[3, 1], [1, 3]]) - self.assertListAlmostEqual(results[1], [[0.6504125435586658, 0.5240639631785098], - [0.7114885562660387, 0.5354020923417662]]) + self.assertListAlmostEqual(results[1][0], [0.6504125435586658, 0.5240639631785098]) + self.assertListAlmostEqual(results[1][1], [0.7114885562660387, 0.5354020923417662]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -633,8 +633,8 @@ def test_recommend_lin_ucb(self): top_k=2, seed=123456) self.assertEqual(results[0], [[3, 2], [1, 3]]) - self.assertListAlmostEqual(results[1], [[0.8355754378823774, 0.8103388262282213], - [0.8510415343853225, 0.8454457789037026]]) + self.assertListAlmostEqual(results[1][0], [0.8355754378823774, 0.8103388262282213]) + self.assertListAlmostEqual(results[1][1], [0.8510415343853225, 0.8454457789037026]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -654,8 +654,8 @@ def test_recommend_lin_ts(self): top_k=2, seed=123456) self.assertEqual(results[0], [[2, 3], [3, 1]]) - self.assertListAlmostEqual(results[1], [[0.9571299309237765, 0.7351400505873965], - [0.8548770839301622, 0.7726819822665895]]) + self.assertListAlmostEqual(results[1][0], [0.9571299309237765, 0.7351400505873965]) + self.assertListAlmostEqual(results[1][1], [0.8548770839301622, 0.7726819822665895]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -675,8 +675,8 @@ def test_recommend_clusters_ts(self): top_k=2, seed=123456) self.assertEqual(results[0], [[2, 1], [1, 3]]) - self.assertListAlmostEqual(results[1], [[0.6470729583134509, 0.6239486262002204], - [0.7257397617770284, 0.6902019029795886]]) + self.assertListAlmostEqual(results[1][0], [0.6470729583134509, 0.6239486262002204]) + self.assertListAlmostEqual(results[1][1], [0.7257397617770284, 0.6902019029795886]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -696,8 +696,8 @@ def test_recommend_radius_ts(self): top_k=2, seed=123456) self.assertEqual(results[0], [[1, 3], [3, 2]]) - self.assertListAlmostEqual(results[1], [[0.6853064650518793, 0.5794087793326232], - [0.6171485591737581, 0.6039485772665535]]) + self.assertListAlmostEqual(results[1][0], [0.6853064650518793, 0.5794087793326232]) + self.assertListAlmostEqual(results[1][1], [0.6171485591737581, 0.6039485772665535]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -717,8 +717,8 @@ def test_recommend_knn_ts(self): top_k=2, seed=123456) self.assertEqual(results[0], [[2, 1], [1, 3]]) - self.assertListAlmostEqual(results[1], [[0.6470729583134509, 0.6239486262002204], - [0.7257397617770284, 0.7239071840518659]]) + self.assertListAlmostEqual(results[1][0], [0.6470729583134509, 0.6239486262002204]) + self.assertListAlmostEqual(results[1][1], [0.7257397617770284, 0.7239071840518659]) # No scores results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False) @@ -739,4 +739,4 @@ def test_recommend_lin_ucb_excluded(self): top_k=2, seed=123456) self.assertEqual(results[0], [[2], [1]]) - self.assertListAlmostEqual(results[1], [[0.8103388262282213], [0.8510415343853225]]) + self.assertListAlmostEqual(results[1][0], [0.8103388262282213])