diff --git a/NAMESPACE b/NAMESPACE index adb94234..2db9cc73 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,6 +38,8 @@ S3method(finalize_estimator_internal,pr_auc) S3method(finalize_estimator_internal,pr_curve) S3method(finalize_estimator_internal,roc_auc) S3method(finalize_estimator_internal,roc_curve) +S3method(format,metric) +S3method(format,metric_set) S3method(gain_capture,data.frame) S3method(gain_curve,data.frame) S3method(huber_loss,data.frame) @@ -73,6 +75,7 @@ S3method(precision,data.frame) S3method(precision,matrix) S3method(precision,table) S3method(print,conf_mat) +S3method(print,metric) S3method(print,metric_set) S3method(recall,data.frame) S3method(recall,matrix) diff --git a/NEWS.md b/NEWS.md index 4d5260ef..67464486 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,6 +16,8 @@ calculated with `roc_auc_survival()`. * `demographic_parity()`, `equalized_odds()`, and `equal_opportunity()` are new metrics for measuring model fairness. Each is implemented with the `new_groupwise_metric()` constructor, a general interface for defining group-aware metrics that allows for quickly and flexibly defining fairness metrics with the problem context in mind. +* Added a print method for metrics and metric sets (#435). + * All warnings and errors have been updated to use the cli package for increased clarity and consistency. (#456, #457, #458) # yardstick 1.2.0 diff --git a/R/aaa-metrics.R b/R/aaa-metrics.R index 6e4c8b35..1a3d66da 100644 --- a/R/aaa-metrics.R +++ b/R/aaa-metrics.R @@ -278,11 +278,40 @@ metric_set <- function(...) { #' @export print.metric_set <- function(x, ...) { - info <- dplyr::as_tibble(x) - print(info) + cat(format(x), sep = "\n") invisible(x) } +#' @export +format.metric_set <- function(x, ...) { + metrics <- attributes(x)$metrics + names <- names(metrics) + + cli::cli_format_method({ + cli::cli_text("A metric set, consisting of:") + + metric_formats <- vapply(metrics, format, character(1)) + metric_formats <- strsplit(metric_formats, " | ", fixed = TRUE) + + metric_names <- names(metric_formats) + metric_types <- vapply(metric_formats, `[`, character(1), 1, USE.NAMES = FALSE) + metric_descs <- vapply(metric_formats, `[`, character(1), 2) + metric_nchars <- nchar(metric_names) + nchar(metric_types) + metric_desc_paddings <- max(metric_nchars) - metric_nchars + # see r-lib/cli#506 + metric_desc_paddings <- lapply(metric_desc_paddings, rep, x = "\u00a0") + metric_desc_paddings <- vapply(metric_desc_paddings, paste, character(1), collapse = "") + + for (i in seq_along(metrics)) { + cli::cli_text( + "- {.fun {metric_names[i]}}, \\ + {tolower(metric_types[i])}{metric_desc_paddings[i]} | \\ + {metric_descs[i]}" + ) + } + }) +} + #' @export as_tibble.metric_set <- function(x, ...) { metrics <- attributes(x)$metrics diff --git a/R/aaa-new.R b/R/aaa-new.R index 40a1752f..4537d68f 100644 --- a/R/aaa-new.R +++ b/R/aaa-new.R @@ -94,3 +94,41 @@ metric_direction <- function(x) { attr(x, "direction") <- value x } + +#' @noRd +#' @export +print.metric <- function(x, ...) { + cat(format(x), sep = "\n") + invisible(x) +} + +#' @export +format.metric <- function(x, ...) { + first_class <- class(x)[[1]] + metric_type <- + switch( + first_class, + "prob_metric" = "probability metric", + "class_metric" = "class metric", + "numeric_metric" = "numeric metric", + "dynamic_survival_metric" = "dynamic survival metric", + "static_survival_metric" = "static survival metric", + "integrated_survival_metric" = "integrated survival metric", + "metric" + ) + + metric_desc <- "direction: {.field {attr(x, 'direction')}}" + + by_attr <- attr(x, "by") + if (!is.null(by_attr)) { + metric_desc <- + c( + metric_desc, + ", group-wise on: {.field {as.character(by_attr)}}" + ) + } + + cli::cli_format_method( + cli::cli_text(c("A {metric_type} | ", metric_desc)) + ) +} diff --git a/tests/testthat/_snaps/aaa-metrics.md b/tests/testthat/_snaps/aaa-metrics.md index 5d03cb83..7855da7e 100644 --- a/tests/testthat/_snaps/aaa-metrics.md +++ b/tests/testthat/_snaps/aaa-metrics.md @@ -62,12 +62,10 @@ Code metric_set(rmse, rsq, ccc) Output - # A tibble: 3 x 3 - metric class direction - - 1 rmse numeric_metric minimize - 2 rsq numeric_metric maximize - 3 ccc numeric_metric maximize + A metric set, consisting of: + - `rmse()`, a numeric metric | direction: minimize + - `rsq()`, a numeric metric | direction: maximize + - `ccc()`, a numeric metric | direction: maximize # `metric_set()` errors contain env name for unknown functions (#128) diff --git a/tests/testthat/_snaps/aaa-new.md b/tests/testthat/_snaps/aaa-new.md index 913cc1aa..77bb6dcd 100644 --- a/tests/testthat/_snaps/aaa-new.md +++ b/tests/testthat/_snaps/aaa-new.md @@ -15,3 +15,24 @@ ! `direction` must be one of "maximize", "minimize", or "zero", not "min". i Did you mean "minimize"? +# metric print method works + + Code + rmse + Output + A numeric metric | direction: minimize + +--- + + Code + roc_auc + Output + A probability metric | direction: maximize + +--- + + Code + demographic_parity(boop) + Output + A class metric | direction: minimize, group-wise on: boop + diff --git a/tests/testthat/test-aaa-new.R b/tests/testthat/test-aaa-new.R index 4ac4e6db..00f31cd2 100644 --- a/tests/testthat/test-aaa-new.R +++ b/tests/testthat/test-aaa-new.R @@ -37,3 +37,9 @@ test_that("`direction` is validated", { new_class_metric(function() 1, "min") ) }) + +test_that("metric print method works", { + expect_snapshot(rmse) + expect_snapshot(roc_auc) + expect_snapshot(demographic_parity(boop)) +})