Skip to content

Commit

Permalink
ENH add support for array API to various metric (scikit-learn#29709)
Browse files Browse the repository at this point in the history
Co-authored-by: Omar Salman <[email protected]>
Co-authored-by: Adrin Jalali <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent fde6f2d commit e92dd40
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 39 deletions.
3 changes: 3 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Metrics
- :func:`sklearn.metrics.mean_gamma_deviance`
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_squared_log_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
Expand All @@ -134,6 +135,8 @@ Metrics
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.pairwise.sigmoid_kernel`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.root_mean_squared_error`
- :func:`sklearn.metrics.root_mean_squared_log_error`
- :func:`sklearn.metrics.zero_one_loss`

Tools
Expand Down
16 changes: 16 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_poisson_deviance` :pr:`29227` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.mean_squared_log_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
- :func:`sklearn.metrics.root_mean_squared_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
- :func:`sklearn.metrics.root_mean_squared_log_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.cosine_distances` :pr:`29265` by :user:`Emily Chen <EmilyXinyi>`;
Expand Down Expand Up @@ -313,6 +316,19 @@ Changelog
is renamed into `ensure_all_finite`. `force_all_finite` will be removed in 1.8.
:pr:`29404` by :user:`Jérémie du Boisberranger <jeremiedb>`.

- |Fix| the functions :func:`metrics.mean_squared_log_error` and
:func:`metrics.root_mean_squared_log_error` now check whether
the inputs are within the correct domain for the function
:math:`y=\log(1+x)`, rather than :math:`y=\log(x)`.
:pr:`29709` by :user:`Virgil Chan <virchan>`.

- |Fix| the functions :func:`metrics.mean_absolute_error`,
:func:`metrics.mean_absolute_percentage_error`, :func:`metrics.mean_squared_error`
and :func:`metrics.root_mean_squared_error` now explicitly check whether a scalar
will be returned when `multioutput=uniform_average`.
:pr:`29709` by :user:`Virgil Chan <virchan>`.


:mod:`sklearn.model_selection`
..............................

Expand Down
80 changes: 52 additions & 28 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,12 @@ def mean_absolute_error(
multioutput = None

# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
mean_absolute_error = _average(output_errors, weights=multioutput)

# Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
# should always return a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value, irrespective of the
# Array API implementation.
assert mean_absolute_error.shape == ()
return float(mean_absolute_error)


Expand Down Expand Up @@ -416,8 +415,13 @@ def mean_absolute_percentage_error(
# pass None as weights to _average: uniform mean
multioutput = None

# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
mean_absolute_percentage_error = _average(output_errors, weights=multioutput)
assert mean_absolute_percentage_error.shape == ()

return float(mean_absolute_percentage_error)


Expand Down Expand Up @@ -524,12 +528,16 @@ def mean_squared_error(
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

# See comment in mean_absolute_error
# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
mean_squared_error = _average(output_errors, weights=multioutput)
assert mean_squared_error.shape == ()

return float(mean_squared_error)


Expand Down Expand Up @@ -585,13 +593,16 @@ def root_mean_squared_error(
>>> y_true = [3, -0.5, 2, 7]
>>> y_pred = [2.5, 0.0, 2, 8]
>>> root_mean_squared_error(y_true, y_pred)
np.float64(0.612...)
0.612...
>>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
>>> root_mean_squared_error(y_true, y_pred)
np.float64(0.822...)
0.822...
"""
output_errors = np.sqrt(

xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)

output_errors = xp.sqrt(
mean_squared_error(
y_true, y_pred, sample_weight=sample_weight, multioutput="raw_values"
)
Expand All @@ -601,10 +612,17 @@ def root_mean_squared_error(
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

return np.average(output_errors, weights=multioutput)
# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
root_mean_squared_error = _average(output_errors, weights=multioutput)

return float(root_mean_squared_error)


@validate_params(
Expand Down Expand Up @@ -700,20 +718,22 @@ def mean_squared_log_error(
y_true, y_pred, sample_weight=sample_weight, multioutput=multioutput
)

y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
xp, _ = get_namespace(y_true, y_pred)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

_, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
)
check_consistent_length(y_true, y_pred, sample_weight)

if (y_true < 0).any() or (y_pred < 0).any():
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
raise ValueError(
"Mean Squared Logarithmic Error cannot be used when "
"targets contain negative values."
"targets contain values less than or equal to -1."
)

return mean_squared_error(
np.log1p(y_true),
np.log1p(y_pred),
xp.log1p(y_true),
xp.log1p(y_pred),
sample_weight=sample_weight,
multioutput=multioutput,
)
Expand Down Expand Up @@ -773,20 +793,24 @@ def root_mean_squared_log_error(
>>> y_true = [3, 5, 2.5, 7]
>>> y_pred = [2.5, 5, 4, 8]
>>> root_mean_squared_log_error(y_true, y_pred)
np.float64(0.199...)
0.199...
"""
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)
xp, _ = get_namespace(y_true, y_pred)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
)

if (y_true < 0).any() or (y_pred < 0).any():
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
raise ValueError(
"Root Mean Squared Logarithmic Error cannot be used when "
"targets contain negative values."
"targets contain values less than or equal to -1."
)

return root_mean_squared_error(
np.log1p(y_true),
np.log1p(y_pred),
xp.log1p(y_true),
xp.log1p(y_pred),
sample_weight=sample_weight,
multioutput=multioutput,
)
Expand Down
49 changes: 49 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
mean_pinball_loss,
mean_poisson_deviance,
mean_squared_error,
mean_squared_log_error,
mean_tweedie_deviance,
median_absolute_error,
multilabel_confusion_matrix,
Expand All @@ -47,6 +48,8 @@
recall_score,
roc_auc_score,
roc_curve,
root_mean_squared_error,
root_mean_squared_log_error,
top_k_accuracy_score,
zero_one_loss,
)
Expand Down Expand Up @@ -120,11 +123,14 @@
"max_error": max_error,
"mean_absolute_error": mean_absolute_error,
"mean_squared_error": mean_squared_error,
"mean_squared_log_error": mean_squared_log_error,
"mean_pinball_loss": mean_pinball_loss,
"median_absolute_error": median_absolute_error,
"mean_absolute_percentage_error": mean_absolute_percentage_error,
"explained_variance_score": explained_variance_score,
"r2_score": partial(r2_score, multioutput="variance_weighted"),
"root_mean_squared_error": root_mean_squared_error,
"root_mean_squared_log_error": root_mean_squared_log_error,
"mean_normal_deviance": partial(mean_tweedie_deviance, power=0),
"mean_poisson_deviance": mean_poisson_deviance,
"mean_gamma_deviance": mean_gamma_deviance,
Expand Down Expand Up @@ -458,7 +464,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"mean_absolute_error",
"median_absolute_error",
"mean_squared_error",
"mean_squared_log_error",
"r2_score",
"root_mean_squared_error",
"root_mean_squared_log_error",
"explained_variance_score",
"mean_absolute_percentage_error",
"mean_pinball_loss",
Expand All @@ -482,6 +491,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"micro_f1_score",
"macro_f1_score",
"weighted_recall_score",
"mean_squared_log_error",
"root_mean_squared_error",
"root_mean_squared_log_error",
# P = R = F = accuracy in multiclass case
"micro_f0.5_score",
"micro_f1_score",
Expand Down Expand Up @@ -551,6 +563,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"d2_tweedie_score",
}

# Metrics involving y = log(1+x)
METRICS_WITH_LOG1P_Y = {
"mean_squared_log_error",
"root_mean_squared_log_error",
}


def _require_positive_targets(y1, y2):
"""Make targets strictly positive"""
Expand All @@ -560,6 +578,16 @@ def _require_positive_targets(y1, y2):
return y1, y2


def _require_log1p_targets(y1, y2):
"""Make targets strictly larger than -1"""
offset = abs(min(y1.min(), y2.min())) - 0.99
y1 = y1.astype(float)
y2 = y2.astype(float)
y1 += offset
y2 += offset
return y1, y2


def test_symmetry_consistency():
# We shouldn't forget any metrics
assert (
Expand All @@ -582,6 +610,9 @@ def test_symmetric_metric(name):
if name in METRICS_REQUIRE_POSITIVE_Y:
y_true, y_pred = _require_positive_targets(y_true, y_pred)

elif name in METRICS_WITH_LOG1P_Y:
y_true, y_pred = _require_log1p_targets(y_true, y_pred)

y_true_bin = random_state.randint(0, 2, size=(20, 25))
y_pred_bin = random_state.randint(0, 2, size=(20, 25))

Expand Down Expand Up @@ -631,6 +662,8 @@ def test_sample_order_invariance(name):

if name in METRICS_REQUIRE_POSITIVE_Y:
y_true, y_pred = _require_positive_targets(y_true, y_pred)
elif name in METRICS_WITH_LOG1P_Y:
y_true, y_pred = _require_log1p_targets(y_true, y_pred)

y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)

Expand Down Expand Up @@ -698,6 +731,8 @@ def test_format_invariance_with_1d_vectors(name):

if name in METRICS_REQUIRE_POSITIVE_Y:
y1, y2 = _require_positive_targets(y1, y2)
elif name in METRICS_WITH_LOG1P_Y:
y1, y2 = _require_log1p_targets(y1, y2)

y1_list = list(y1)
y2_list = list(y2)
Expand Down Expand Up @@ -986,6 +1021,8 @@ def check_single_sample(name):
# assert that no exception is thrown
if name in METRICS_REQUIRE_POSITIVE_Y:
values = [1, 2]
elif name in METRICS_WITH_LOG1P_Y:
values = [-0.7, 1]
else:
values = [0, 1]
for i, j in product(values, repeat=2):
Expand Down Expand Up @@ -2017,6 +2054,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_squared_log_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
d2_tweedie_score: [
check_array_api_regression_metric,
],
Expand All @@ -2036,6 +2077,14 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
linear_kernel: [check_array_api_metric_pairwise],
polynomial_kernel: [check_array_api_metric_pairwise],
rbf_kernel: [check_array_api_metric_pairwise],
root_mean_squared_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
root_mean_squared_log_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
sigmoid_kernel: [check_array_api_metric_pairwise],
}

Expand Down
20 changes: 12 additions & 8 deletions sklearn/metrics/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,33 @@ def test_regression_metrics_at_limits():
assert_almost_equal(s([1, 1], [1, 1]), 1.0)
assert_almost_equal(s([1, 1], [1, 1], force_finite=False), np.nan)
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
"Mean Squared Logarithmic Error cannot be used when "
"targets contain values less than or equal to -1."
)
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([-1.0], [-1.0])
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
"Mean Squared Logarithmic Error cannot be used when "
"targets contain values less than or equal to -1."
)
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([1.0, 2.0, 3.0], [1.0, -2.0, 3.0])
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
"Mean Squared Logarithmic Error cannot be used when "
"targets contain values less than or equal to -1."
)
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])
msg = (
"Root Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
"Mean Squared Logarithmic Error cannot be used when "
"targets contain values less than or equal to -1."
)
with pytest.raises(ValueError, match=msg):
root_mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])
msg = (
"Root Mean Squared Logarithmic Error cannot be used when "
"targets contain values less than or equal to -1."
)

# Tweedie deviance error
power = -1.2
Expand Down
Loading

0 comments on commit e92dd40

Please sign in to comment.