Skip to content

Commit

Permalink
Merge pull request #3 from fidelity/fix_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bkleyn authored Mar 18, 2022
2 parents 8c75e38 + 986e7e6 commit ce178a6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
21 changes: 18 additions & 3 deletions mab2rec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ class LinGreedy:
The regularization strength.
Integer or float. Must be greater than zero.
Default value is 1.0.
scale: bool
Whether to scale features to have zero mean and unit variance.
Uses StandardScaler in sklearn.preprocessing.
Default value is False.
"""
epsilon: float = 0.1
l2_lambda: float = 1.0
scale: bool = False


@spock
Expand All @@ -59,9 +64,14 @@ class LinTS:
The regularization strength.
Integer or float. Must be greater than zero.
Default value is 1.0.
scale: bool
Whether to scale features to have zero mean and unit variance.
Uses StandardScaler in sklearn.preprocessing.
Default value is False.
"""
alpha: float = 1.0
l2_lambda: float = 1.0
scale: bool = False


@spock
Expand All @@ -78,9 +88,14 @@ class LinUCB:
The regularization strength.
Integer or float. Cannot be negative.
Default value is 1.0.
scale: bool
Whether to scale features to have zero mean and unit variance.
Uses StandardScaler in sklearn.preprocessing.
Default value is False.
"""
alpha: float = 1.0
l2_lambda: float = 1.0
scale: bool = False


@spock
Expand Down Expand Up @@ -283,11 +298,11 @@ def init_recommender(config):
if isinstance(lp_params, EpsilonGreedy):
lp = LearningPolicy.EpsilonGreedy(epsilon=lp_params.epsilon)
elif isinstance(lp_params, LinGreedy):
lp = LearningPolicy.LinGreedy(epsilon=lp_params.epsilon, l2_lambda=lp_params.l2_lambda)
lp = LearningPolicy.LinGreedy(epsilon=lp_params.epsilon, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale)
elif isinstance(lp_params, LinTS):
lp = LearningPolicy.LinTS(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda)
lp = LearningPolicy.LinTS(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale)
elif isinstance(lp_params, LinUCB):
lp = LearningPolicy.LinUCB(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda)
lp = LearningPolicy.LinUCB(alpha=lp_params.alpha, l2_lambda=lp_params.l2_lambda, scale=lp_params.scale)
elif isinstance(lp_params, Popularity):
lp = LearningPolicy.Popularity()
elif isinstance(lp_params, Random):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,8 @@ def test_recommend_lin_greedy(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[3, 1], [1, 3]])
self.assertListAlmostEqual(results[1], [[0.6504125435586658, 0.5240639631785098],
[0.7114885562660387, 0.5354020923417662]])
self.assertListAlmostEqual(results[1][0], [0.6504125435586658, 0.5240639631785098])
self.assertListAlmostEqual(results[1][1], [0.7114885562660387, 0.5354020923417662])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -633,8 +633,8 @@ def test_recommend_lin_ucb(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[3, 2], [1, 3]])
self.assertListAlmostEqual(results[1], [[0.8355754378823774, 0.8103388262282213],
[0.8510415343853225, 0.8454457789037026]])
self.assertListAlmostEqual(results[1][0], [0.8355754378823774, 0.8103388262282213])
self.assertListAlmostEqual(results[1][1], [0.8510415343853225, 0.8454457789037026])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -654,8 +654,8 @@ def test_recommend_lin_ts(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[2, 3], [3, 1]])
self.assertListAlmostEqual(results[1], [[0.9571299309237765, 0.7351400505873965],
[0.8548770839301622, 0.7726819822665895]])
self.assertListAlmostEqual(results[1][0], [0.9571299309237765, 0.7351400505873965])
self.assertListAlmostEqual(results[1][1], [0.8548770839301622, 0.7726819822665895])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -675,8 +675,8 @@ def test_recommend_clusters_ts(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[2, 1], [1, 3]])
self.assertListAlmostEqual(results[1], [[0.6470729583134509, 0.6239486262002204],
[0.7257397617770284, 0.6902019029795886]])
self.assertListAlmostEqual(results[1][0], [0.6470729583134509, 0.6239486262002204])
self.assertListAlmostEqual(results[1][1], [0.7257397617770284, 0.6902019029795886])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -696,8 +696,8 @@ def test_recommend_radius_ts(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[1, 3], [3, 2]])
self.assertListAlmostEqual(results[1], [[0.6853064650518793, 0.5794087793326232],
[0.6171485591737581, 0.6039485772665535]])
self.assertListAlmostEqual(results[1][0], [0.6853064650518793, 0.5794087793326232])
self.assertListAlmostEqual(results[1][1], [0.6171485591737581, 0.6039485772665535])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -717,8 +717,8 @@ def test_recommend_knn_ts(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[2, 1], [1, 3]])
self.assertListAlmostEqual(results[1], [[0.6470729583134509, 0.6239486262002204],
[0.7257397617770284, 0.7239071840518659]])
self.assertListAlmostEqual(results[1][0], [0.6470729583134509, 0.6239486262002204])
self.assertListAlmostEqual(results[1][1], [0.7257397617770284, 0.7239071840518659])

# No scores
results = rec.recommend([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], return_scores=False)
Expand All @@ -739,4 +739,4 @@ def test_recommend_lin_ucb_excluded(self):
top_k=2,
seed=123456)
self.assertEqual(results[0], [[2], [1]])
self.assertListAlmostEqual(results[1], [[0.8103388262282213], [0.8510415343853225]])
self.assertListAlmostEqual(results[1][0], [0.8103388262282213])

0 comments on commit ce178a6

Please sign in to comment.