Sanjar Ad[yi]lov edited this page Jun 19, 2024 · 1 revision

What is scikit-fallback?

scikit-fallback is a scikit-learn-compatible Python package for machine learning with a reject option.

It offers tools that let your estimators and scorers support fallbacks, or rejections: special labels indicating that your machine learning pipeline abstains from making a decision. Suppose your domain can have:

  • more classes than your classifier was trained on (e.g., an unexpected buy_pizza intent encountered by a dialogue system for a banking application);
  • ambiguous examples (e.g., an image of both a cat and a dog passed to a cat-vs-dog classifier);
  • classes with high misclassification costs (e.g., false negatives in cancer diagnosis).

You might want to bring in additional experts, such as humans, to handle such cases. scikit-fallback can wrap your estimators and scorers, and it also offers additional objects that either predict fallback labels, so that your pipelines hand the corresponding samples off to other systems, or store fallback masks, so that you can evaluate how well your pipelines both predict and reject.
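To make the idea of a reject option concrete, the sketch below rejects low-confidence predictions using plain scikit-learn and NumPy only. It does not use scikit-fallback's API; `FALLBACK_LABEL` and `THRESHOLD` are hypothetical names and values chosen for illustration, and the library may represent fallbacks differently.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

FALLBACK_LABEL = -1  # hypothetical marker for rejected samples (assumption)
THRESHOLD = 0.7      # hypothetical confidence threshold (assumption)

X, y = make_moons(n_samples=1_000, noise=0.4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

clf = LogisticRegression().fit(X_train, y_train)
proba = clf.predict_proba(X_test)

# Confidence is the probability of the most likely class.
confidence = proba.max(axis=1)
y_pred = clf.classes_[proba.argmax(axis=1)]

# Abstain wherever the classifier is not confident enough.
y_pred_with_reject = np.where(confidence >= THRESHOLD, y_pred, FALLBACK_LABEL)

# A boolean mask is an alternative way to carry the same information.
fallback_mask = y_pred_with_reject == FALLBACK_LABEL
accepted = ~fallback_mask

print(f"rejected {fallback_mask.mean():.1%} of test samples")
if accepted.any():
    acc = (y_pred[accepted] == y_test[accepted]).mean()
    print(f"accuracy on accepted samples: {acc:.3f}")
```

The rejected samples are the ones a pipeline would hand off to another system, such as a human reviewer.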

Here is a quick example of binary classification with a reject option:

>>> from sklearn.datasets import make_moons
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skfb.estimators import RateFallbackClassifierCV
>>> from skfb.metrics import predict_accept_confusion_matrix, predict_reject_accuracy_score
>>> X, y = make_moons(n_samples=4_000, noise=0.4, random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
...     X, y, test_size=0.2, random_state=0, shuffle=True)
>>> estimator = LogisticRegression(C=10_000, random_state=0)
>>> # RateFallbackClassifierCV learns the fallback threshold t s.t. Pr{x|t} ~ fallback_rate
>>> rejector = RateFallbackClassifierCV(
...     estimator, fallback_rates=(1 / 5, 1 / 4, 1 / 3), cv=3).fit(X_train, y_train)
>>> y_pred = rejector.set_params(fallback_mode="predict").predict(X_test)
>>> # PA confusion matrix is a 2x2 matrix w/ rows = accuracy and columns = acceptance
>>> predict_accept_confusion_matrix(y_test, y_pred)
array([[ 46,  59],
       [136, 559]])
>>> # PR accuracy summarizes the matrix above: (TA + TR) / (TA + TR + FA + FR)
>>> predict_reject_accuracy_score(y_test, y_pred)
0.75625
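As a sanity check, the score can be reproduced by hand from the matrix printed above. The sketch assumes that the diagonal cells hold the true rejections and true acceptances (TR and TA), which is consistent with the reported value; the variable names are illustrative only.

```python
import numpy as np

# Matrix copied from the example output above.
pa_matrix = np.array([[46, 59],
                      [136, 559]])

# Assumption: the diagonal holds the correct decisions (TR and TA).
correct_decisions = np.trace(pa_matrix)  # 46 + 559 = 605
total = pa_matrix.sum()                  # 800 test samples

pr_accuracy = correct_decisions / total
print(pr_accuracy)  # 0.75625
```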

Installation

scikit-fallback depends on scikit-learn (>=1.3) and can be installed via pip:

pip install -U scikit-fallback