Home
scikit-fallback is a scikit-learn-compatible Python package for machine learning with a reject option.
It offers tools that let your estimators and scorers support fallbacks (also called rejections): special labels indicating that your machine learning pipeline abstains from making a decision. Suppose your domain can have:
- more classes than your classifier was trained on (e.g., an unexpected buy_pizza intent encountered by a dialogue system for a banking application);
- ambiguous examples (e.g., an image of both a cat and a dog passed to a cat-vs-dog classifier);
- classes with high misclassification costs (e.g., false negatives in cancer diagnosis).
You might want to bring in additional experts, such as humans, to handle such cases. scikit-fallback can wrap your estimators and scorers, and also offers additional objects that either predict fallback labels, so that your pipelines hand the corresponding samples off to other systems, or store fallback masks, so that you can evaluate the ability of your pipelines to predict and reject correctly.
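To make the idea of a fallback label concrete, here is a minimal sketch that uses only scikit-learn and NumPy, not scikit-fallback itself. It assumes one common rejection rule: abstain whenever the highest predicted class probability falls below a threshold. The helper `predict_with_reject`, the marker `FALLBACK_LABEL = -1`, and the threshold value are hypothetical names and choices made for illustration, not part of the package's API.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hypothetical marker for "abstain"; not taken from scikit-fallback.
FALLBACK_LABEL = -1


def predict_with_reject(estimator, X, threshold):
    """Predict classes, replacing low-confidence predictions with FALLBACK_LABEL."""
    proba = estimator.predict_proba(X)
    confidence = proba.max(axis=1)  # highest class probability per sample
    labels = estimator.classes_[proba.argmax(axis=1)]
    return np.where(confidence < threshold, FALLBACK_LABEL, labels)


X, y = make_moons(n_samples=1_000, noise=0.4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# The threshold of 0.7 is an arbitrary illustrative choice.
y_pred = predict_with_reject(clf, X_test, threshold=0.7)
rejected = y_pred == FALLBACK_LABEL  # boolean fallback mask
print(f"rejected {rejected.mean():.1%} of test samples")
```

Samples flagged in `rejected` are the ones a pipeline would hand off to another system or a human reviewer; the remaining predictions are used as usual.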
Here is a quick example of binary classification with a reject option:
>>> from sklearn.datasets import make_moons
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skfb.estimators import RateFallbackClassifierCV
>>> from skfb.metrics import predict_accept_confusion_matrix, predict_reject_accuracy_score
>>> X, y = make_moons(n_samples=4_000, noise=0.4, random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
...     X, y, test_size=0.2, random_state=0, shuffle=True)
>>> estimator = LogisticRegression(C=10_000, random_state=0)
>>> # RateFallbackClassifierCV learns the fallback threshold t s.t. the fraction of rejected samples ~ fallback_rate
>>> rejector = RateFallbackClassifierCV(
...     estimator, fallback_rates=(1 / 5, 1 / 4, 1 / 3), cv=3).fit(X_train, y_train)
>>> y_pred = rejector.set_params(fallback_mode="predict").predict(X_test)
>>> # PA (predict-accept) confusion matrix is a 2x2 matrix w/ rows = prediction correctness and columns = acceptance
>>> predict_accept_confusion_matrix(y_test, y_pred)
array([[ 46,  59],
       [136, 559]])
>>> # PR accuracy summarizes the matrix above: (TA + TR) / (TA + TR + FA + FR)
>>> predict_reject_accuracy_score(y_test, y_pred)
0.75625
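As a quick check of the formula in the comment above, the reported score can be recomputed from the matrix by hand. The cell layout used below (row 0 = incorrect predictions, row 1 = correct predictions; column 0 = rejected, column 1 = accepted) is an assumption inferred from the fact that the diagonal sum matches the reported score; it is not taken from the package documentation.

```python
import numpy as np

# Matrix copied from the output above.
pa_matrix = np.array([[46, 59], [136, 559]])

# Assumed layout: rows = (incorrect, correct), columns = (rejected, accepted).
true_reject, false_accept = pa_matrix[0]   # wrong predictions: rejected / accepted
false_reject, true_accept = pa_matrix[1]   # right predictions: rejected / accepted

# PR accuracy: (TA + TR) / (TA + TR + FA + FR)
pr_accuracy = (true_accept + true_reject) / pa_matrix.sum()
# Share of test samples that were rejected.
fallback_rate = (true_reject + false_reject) / pa_matrix.sum()

print(pr_accuracy)    # 0.75625
print(fallback_rate)  # 0.2275
```

Under this reading, 182 of the 800 test samples were rejected, a share of 0.2275 that lies between the two lowest candidate rates passed as `fallback_rates`.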
scikit-fallback depends on scikit-learn (>=1.3) and can be installed via pip:
pip install -U scikit-fallback