From 6c835cf808782cb62309f264d57f31fa30a7df32 Mon Sep 17 00:00:00 2001 From: Vincent Date: Tue, 23 Nov 2021 21:47:40 +0100 Subject: [PATCH 1/3] added-margin-reason --- doubtlab/reason.py | 43 ++++++++++++++++++++++++++++++++++++ tests/test_docs.py | 2 ++ tests/test_general_reason.py | 2 ++ 3 files changed, 47 insertions(+) diff --git a/doubtlab/reason.py b/doubtlab/reason.py index d3a77d1..9f3f3a8 100644 --- a/doubtlab/reason.py +++ b/doubtlab/reason.py @@ -151,6 +151,49 @@ def __call__(self, X, y): return np.where(confidences > self.threshold, confidences, 0) +class MarginConfidenceReason: + """ + Assign doubt when the difference between the top two most confident classes is too large. + + Throws an error when there are fewer than two classes. + + Arguments: + model: scikit-learn classifier + threshold: confidence threshold for doubt assignment + + Usage: + + ```python + from sklearn.datasets import load_iris + from sklearn.linear_model import LogisticRegression + + from doubtlab.ensemble import DoubtEnsemble + from doubtlab.reason import MarginConfidenceReason + + X, y = load_iris(return_X_y=True) + model = LogisticRegression(max_iter=1_000) + model.fit(X, y) + + doubt = DoubtEnsemble(reason = MarginConfidenceReason(model=model)) + + indices = doubt.get_indices(X, y) + ``` + """ + + def __init__(self, model, threshold=0.2): + self.model = model + self.threshold = threshold + + def _calc_margin(self, X, y): + probas = self.model.predict_proba(X) + sorted = np.sort(probas, axis=1) + return sorted[:, -1] - sorted[:, -2] + + def __call__(self, X, y): + margin = self._calc_margin(X, y) + return np.where(margin > self.threshold, margin, 0) + + class ShortConfidenceReason: """ Assign doubt when the correct class gains too little confidence. 
diff --git a/tests/test_docs.py b/tests/test_docs.py index 5a4a35f..11d9221 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -8,6 +8,7 @@ DisagreeReason, LongConfidenceReason, ShortConfidenceReason, + MarginConfidenceReason, WrongPredictionReason, AbsoluteDifferenceReason, RelativeDifferenceReason, @@ -22,6 +23,7 @@ DisagreeReason, LongConfidenceReason, ShortConfidenceReason, + MarginConfidenceReason, WrongPredictionReason, AbsoluteDifferenceReason, RelativeDifferenceReason, diff --git a/tests/test_general_reason.py b/tests/test_general_reason.py index 2620651..fcdf1e0 100644 --- a/tests/test_general_reason.py +++ b/tests/test_general_reason.py @@ -10,6 +10,7 @@ ProbaReason, OutlierReason, DisagreeReason, + MarginConfidenceReason, LongConfidenceReason, ShortConfidenceReason, WrongPredictionReason, @@ -22,6 +23,7 @@ ProbaReason, LongConfidenceReason, ShortConfidenceReason, + MarginConfidenceReason, WrongPredictionReason, CleanlabReason, ] From 96026f790560cca625f1471a6c7af10a2ded03e0 Mon Sep 17 00:00:00 2001 From: Vincent Date: Tue, 23 Nov 2021 21:55:49 +0100 Subject: [PATCH 2/3] added-another-test --- doubtlab/reason.py | 6 +++--- tests/test_reason/test_margin.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 tests/test_reason/test_margin.py diff --git a/doubtlab/reason.py b/doubtlab/reason.py index 9f3f3a8..cfc57d3 100644 --- a/doubtlab/reason.py +++ b/doubtlab/reason.py @@ -184,13 +184,13 @@ def __init__(self, model, threshold=0.2): self.model = model self.threshold = threshold - def _calc_margin(self, X, y): - probas = self.model.predict_proba(X) + def _calc_margin(self, probas): sorted = np.sort(probas, axis=1) return sorted[:, -1] - sorted[:, -2] def __call__(self, X, y): - margin = self._calc_margin(X, y) + probas = self.model.predict_proba(X) + margin = self._calc_margin(probas) return np.where(margin > self.threshold, margin, 0) diff --git a/tests/test_reason/test_margin.py 
b/tests/test_reason/test_margin.py new file mode 100644 index 0000000..9080155 --- /dev/null +++ b/tests/test_reason/test_margin.py @@ -0,0 +1,17 @@ +import numpy as np +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression + +from doubtlab.reason import MarginConfidenceReason + + +def test_margin_confidence_margin(): + """Ensures margin is calculated correctly.""" + X, y = load_iris(return_X_y=True) + model = LogisticRegression(max_iter=1_000) + model.fit(X, y) + + reason = MarginConfidenceReason(model=model) + probas = np.eye(3) + margin = reason._calc_margin(probas=probas) + assert np.all(np.isclose(margin, np.ones(3))) From a7733b13e356974c61846dfe8ca4c1905a850ce7 Mon Sep 17 00:00:00 2001 From: Vincent Date: Tue, 23 Nov 2021 21:59:37 +0100 Subject: [PATCH 3/3] version-0.1.2 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 68a22fd..d21e917 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ setup( name="doubtlab", - version="0.1.1", + version="0.1.2", author="Vincent D. Warmerdam", packages=find_packages(exclude=["notebooks", "docs"]), description="Don't Blindly Trust Your Labels",