Skip to content

Commit

Permalink
Merge pull request #455 from tidymodels/print-435
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Dec 1, 2023
2 parents 2f556df + 2a11166 commit b9ea40f
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 8 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ S3method(finalize_estimator_internal,pr_auc)
S3method(finalize_estimator_internal,pr_curve)
S3method(finalize_estimator_internal,roc_auc)
S3method(finalize_estimator_internal,roc_curve)
S3method(format,metric)
S3method(format,metric_set)
S3method(gain_capture,data.frame)
S3method(gain_curve,data.frame)
S3method(huber_loss,data.frame)
Expand Down Expand Up @@ -73,6 +75,7 @@ S3method(precision,data.frame)
S3method(precision,matrix)
S3method(precision,table)
S3method(print,conf_mat)
S3method(print,metric)
S3method(print,metric_set)
S3method(recall,data.frame)
S3method(recall,matrix)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ calculated with `roc_auc_survival()`.

* `demographic_parity()`, `equalized_odds()`, and `equal_opportunity()` are new metrics for measuring model fairness. Each is implemented with the `new_groupwise_metric()` constructor, a general interface for defining group-aware metrics that allows for quickly and flexibly defining fairness metrics with the problem context in mind.

* Added a print method for metrics and metric sets (#435).

* All warnings and errors have been updated to use the cli package for increased clarity and consistency. (#456, #457, #458)

# yardstick 1.2.0
Expand Down
33 changes: 31 additions & 2 deletions R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,40 @@ metric_set <- function(...) {

#' @export
print.metric_set <- function(x, ...) {
info <- dplyr::as_tibble(x)
print(info)
cat(format(x), sep = "\n")
invisible(x)
}

#' @export
format.metric_set <- function(x, ...) {
metrics <- attributes(x)$metrics
names <- names(metrics)

cli::cli_format_method({
cli::cli_text("A metric set, consisting of:")

metric_formats <- vapply(metrics, format, character(1))
metric_formats <- strsplit(metric_formats, " | ", fixed = TRUE)

metric_names <- names(metric_formats)
metric_types <- vapply(metric_formats, `[`, character(1), 1, USE.NAMES = FALSE)
metric_descs <- vapply(metric_formats, `[`, character(1), 2)
metric_nchars <- nchar(metric_names) + nchar(metric_types)
metric_desc_paddings <- max(metric_nchars) - metric_nchars
# see r-lib/cli#506
metric_desc_paddings <- lapply(metric_desc_paddings, rep, x = "\u00a0")
metric_desc_paddings <- vapply(metric_desc_paddings, paste, character(1), collapse = "")

for (i in seq_along(metrics)) {
cli::cli_text(
"- {.fun {metric_names[i]}}, \\
{tolower(metric_types[i])}{metric_desc_paddings[i]} | \\
{metric_descs[i]}"
)
}
})
}

#' @export
as_tibble.metric_set <- function(x, ...) {
metrics <- attributes(x)$metrics
Expand Down
38 changes: 38 additions & 0 deletions R/aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,41 @@ metric_direction <- function(x) {
attr(x, "direction") <- value
x
}

#' @noRd
#' @export
print.metric <- function(x, ...) {
cat(format(x), sep = "\n")
invisible(x)
}

#' @export
format.metric <- function(x, ...) {
first_class <- class(x)[[1]]
metric_type <-
switch(
first_class,
"prob_metric" = "probability metric",
"class_metric" = "class metric",
"numeric_metric" = "numeric metric",
"dynamic_survival_metric" = "dynamic survival metric",
"static_survival_metric" = "static survival metric",
"integrated_survival_metric" = "integrated survival metric",
"metric"
)

metric_desc <- "direction: {.field {attr(x, 'direction')}}"

by_attr <- attr(x, "by")
if (!is.null(by_attr)) {
metric_desc <-
c(
metric_desc,
", group-wise on: {.field {as.character(by_attr)}}"
)
}

cli::cli_format_method(
cli::cli_text(c("A {metric_type} | ", metric_desc))
)
}
10 changes: 4 additions & 6 deletions tests/testthat/_snaps/aaa-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@
Code
metric_set(rmse, rsq, ccc)
Output
# A tibble: 3 x 3
metric class direction
<chr> <chr> <chr>
1 rmse numeric_metric minimize
2 rsq numeric_metric maximize
3 ccc numeric_metric maximize
A metric set, consisting of:
- `rmse()`, a numeric metric | direction: minimize
- `rsq()`, a numeric metric | direction: maximize
- `ccc()`, a numeric metric | direction: maximize

# `metric_set()` errors contain env name for unknown functions (#128)

Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/_snaps/aaa-new.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,24 @@
! `direction` must be one of "maximize", "minimize", or "zero", not "min".
i Did you mean "minimize"?

# metric print method works

Code
rmse
Output
A numeric metric | direction: minimize

---

Code
roc_auc
Output
A probability metric | direction: maximize

---

Code
demographic_parity(boop)
Output
A class metric | direction: minimize, group-wise on: boop

6 changes: 6 additions & 0 deletions tests/testthat/test-aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ test_that("`direction` is validated", {
new_class_metric(function() 1, "min")
)
})

test_that("metric print method works", {
expect_snapshot(rmse)
expect_snapshot(roc_auc)
expect_snapshot(demographic_parity(boop))
})

0 comments on commit b9ea40f

Please sign in to comment.