Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to cli in misc.R and metrics #457

Merged
merged 4 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions R/class-kap.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,19 @@ make_weighting_matrix <- function(weighting, n_levels, call = caller_env()) {

validate_weighting <- function(x, call = caller_env()) {
if (!is_string(x)) {
abort("`weighting` must be a string.", call = call)
cli::cli_abort("{.arg weighting} must be a string.", call = call)
}

ok <- is_no_weighting(x) ||
is_linear_weighting(x) ||
is_quadratic_weighting(x)

if (!ok) {
abort("`weighting` must be 'none', 'linear', or 'quadratic'.", call = call)
cli::cli_abort(
"{.arg weighting} must be {.val none}, {.val linear}, or \\
{.val quadratic}, not {.val {x}}.",
call = call
)
}

invisible(x)
Expand Down
5 changes: 4 additions & 1 deletion R/class-mcc.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ mcc_multiclass_impl <- function(C) {

check_mcc_data <- function(data) {
if (!is.double(data) && !is.matrix(data)) {
abort("`data` should be a double matrix at this point.", .internal = TRUE)
cli::cli_abort(
"{.arg data} should be a double matrix at this point.",
.internal = TRUE
)
}
invisible()
}
9 changes: 7 additions & 2 deletions R/conf_mat.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ conf_mat_impl <- function(truth, estimate, case_weights, call = caller_env()) {
check_class_metric(truth, estimate, case_weights, estimator, call = call)

if (length(levels(truth)) < 2) {
abort("`truth` must have at least 2 factor levels.", call = call)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
call = call
)
}

yardstick_table(
Expand All @@ -245,7 +248,9 @@ conf_mat.table <- function(data, ...) {
num_lev <- length(class_lev)

if (num_lev < 2) {
abort("There must be at least 2 factors levels in the `data`")
cli::cli_abort(
"There must be at least 2 factors levels in the {.arg data}."
)
}

structure(
Expand Down
79 changes: 57 additions & 22 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

pos_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
abort("Only relevant for 2x2 tables")
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
Expand All @@ -16,7 +16,7 @@ pos_val <- function(xtab, event_level) {

neg_val <- function(xtab, event_level) {
if (!all(dim(xtab) == 2)) {
abort("Only relevant for 2x2 tables")
cli::cli_abort("Only relevant for 2x2 tables.")
}

if (is_event_first(event_level)) {
Expand Down Expand Up @@ -67,19 +67,19 @@ as_factor_from_class_pred <- function(x) {
}

if (!is_installed("probably")) {
abort(paste0(
"A <class_pred> input was detected, but the probably package ",
"isn't installed. Install probably to be able to convert <class_pred> ",
"to <factor>."
))
cli::cli_abort(
"A {.cls class_pred} input was detected, but the {.pkg probably} \\
package isn't installed. Install {.pkg probably} to be able to convert \\
{.cls class_pred} to {.cls factor}."
)
}
probably::as.factor(x)
}

abort_if_class_pred <- function(x, call = caller_env()) {
if (is_class_pred(x)) {
abort(
"`truth` should not a `class_pred` object.",
cli::cli_abort(
"{.arg truth} should not a {.cls class_pred} object.",
call = call
)
}
Expand Down Expand Up @@ -186,10 +186,18 @@ yardstick_cov <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
abort("`truth` and `estimate` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
abort("`truth` and `case_weights` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
.internal = TRUE
)
}

if (size == 0L || size == 1L) {
Expand Down Expand Up @@ -232,10 +240,18 @@ yardstick_cor <- function(truth,

size <- vec_size(truth)
if (size != vec_size(estimate)) {
abort("`truth` and `estimate` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg estimate} ({vec_size(estimate)}) must be the same size.",
.internal = TRUE
)
}
if (size != vec_size(case_weights)) {
abort("`truth` and `case_weights` must be the same size.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({vec_size(truth)}) and \\
{.arg case_weights} ({vec_size(case_weights)}) must be the same size.",
.internal = TRUE
)
}

if (size == 0L || size == 1L) {
Expand Down Expand Up @@ -345,14 +361,17 @@ weighted_quantile <- function(x, weights, probabilities) {

size <- vec_size(x)
if (size != vec_size(weights)) {
abort("`x` and `weights` must have the same size.")
cli::cli_abort(
"{.arg x} ({vec_size(x)}) and {.arg weights} ({vec_size(weights)}) \\
must have the same size."
)
}

if (any(is.na(probabilities))) {
abort("`probabilities` can't be missing.")
cli::cli_abort("{.arg probabilities} can't have missing values.")
}
if (any(probabilities > 1 | probabilities < 0)) {
abort("`probabilities` must be within `[0, 1]`.")
cli::cli_abort("{.arg probabilities} must be within `[0, 1]`.")
}

if (size == 0L) {
Expand Down Expand Up @@ -397,20 +416,33 @@ yardstick_table <- function(truth, estimate, ..., case_weights = NULL) {
}

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must be a factor, not {.obj_type_friendly {truth}}.",
.internal = TRUE
)
}
if (!is.factor(estimate)) {
abort("`estimate` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg estimate} must be a factor, not {.obj_type_friendly {estimate}}.",
.internal = TRUE
)
}

levels <- levels(truth)
n_levels <- length(levels)

if (!identical(levels, levels(estimate))) {
abort("`truth` and `estimate` must have the same levels in the same order.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} and {.arg estimate} must have the same levels in the same \\
order.",
.internal = TRUE
)
}
if (n_levels < 2) {
abort("`truth` must have at least 2 factor levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
.internal = TRUE
)
}

# Supply `estimate` first to get it to correspond to the row names.
Expand Down Expand Up @@ -447,14 +479,17 @@ yardstick_truth_table <- function(truth, ..., case_weights = NULL) {
abort_if_class_pred(truth)

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort("{.arg truth} must be a factor.", .internal = TRUE)
}

levels <- levels(truth)
n_levels <- length(levels)

if (n_levels < 2) {
abort("`truth` must have at least 2 factor levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have at least 2 factor levels.",
.internal = TRUE
)
}

# Always return a double matrix for type stability
Expand Down
7 changes: 1 addition & 6 deletions R/num-huber_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,7 @@ huber_loss_impl <- function(truth,
# Weighted Huber Loss implementation confirmed against matlab:
# https://www.mathworks.com/help/deeplearning/ref/dlarray.huber.html

if (!is_bare_numeric(delta, n = 1L)) {
abort("`delta` must be a single numeric value.", call = call)
}
if (!(delta >= 0)) {
abort("`delta` must be a positive value.", call = call)
}
check_number_decimal(delta, min = 0, call = call)

a <- truth - estimate
abs_a <- abs(a)
Expand Down
48 changes: 7 additions & 41 deletions R/num-mase.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ mase_impl <- function(truth,
mae_train = NULL,
case_weights = NULL,
call = caller_env()) {
validate_m(m, call = call)
validate_mae_train(mae_train, call = call)
check_number_whole(m, min = 0, call = call)
check_number_decimal(mae_train, min = 0, allow_null = TRUE, call = call)

if (is.null(mae_train)) {
validate_truth_m(truth, m, call = call)
Expand All @@ -139,46 +139,12 @@ mase_impl <- function(truth,
out
}

validate_m <- function(m, call = caller_env()) {
abort_msg <- "`m` must be a single positive integer value."

if (!is_integerish(m, n = 1L)) {
abort(abort_msg, call = call)
}

if (!(m > 0)) {
abort(abort_msg, call = call)
}

invisible(m)
}

validate_mae_train <- function(mae_train, call = caller_env()) {
if (is.null(mae_train)) {
return(invisible(mae_train))
}

is_single_numeric <- is_bare_numeric(mae_train, n = 1L)
abort_msg <- "`mae_train` must be a single positive numeric value."

if (!is_single_numeric) {
abort(abort_msg, call = call)
}

if (!(mae_train > 0)) {
abort(abort_msg, call = call)
}

invisible(mae_train)
}

validate_truth_m <- function(truth, m, call = caller_env()) {
if (length(truth) <= m) {
abort(paste0(
"`truth` must have a length greater than `m` ",
"to compute the out-of-sample naive mean absolute error."
), call = call)
cli::cli_abort(
"{.arg truth} ({length(truth)}) must have a length greater than \\
{.arg m} ({m}) to compute the out-of-sample naive mean absolute error.",
call = call
)
}

invisible(truth)
}
7 changes: 1 addition & 6 deletions R/num-pseudo_huber_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,7 @@ huber_loss_pseudo_impl <- function(truth,
delta,
case_weights,
call = caller_env()) {
if (!is_bare_numeric(delta, n = 1L)) {
abort("`delta` must be a single numeric value.", call = call)
}
if (!(delta >= 0)) {
abort("`delta` must be a positive value.", call = call)
}
check_number_decimal(delta, min = 0, call = call)

a <- truth - estimate
loss <- delta^2 * (sqrt(1 + (a / delta)^2) - 1)
Expand Down
27 changes: 22 additions & 5 deletions R/prob-binary-thresholds.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,36 @@ binary_threshold_curve <- function(truth,
case_weights <- vec_cast(case_weights, to = double())

if (!is.factor(truth)) {
abort("`truth` must be a factor.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must be a factor, not {.obj_friendly_type {truth}}.",
.internal = TRUE
)
}
if (length(levels(truth)) != 2L) {
abort("`truth` must have two levels.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} must have two levels, not {length(levels(truth))}.",
.internal = TRUE
)
}
if (!is.numeric(estimate)) {
abort("`estimate` must be numeric.", .internal = TRUE)
cli::cli_abort(
"{.arg estimate} must be numeric, not {.obj_friendly_type {estimate}}.",
.internal = TRUE
)
}
if (length(truth) != length(estimate)) {
abort("`truth` and `estimate` must be the same length.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({length(truth)}) and \\
{.arg estimate} ({length(estimate)}) must be the same length.",
.internal = TRUE
)
}
if (length(truth) != length(case_weights)) {
abort("`truth` and `case_weights` must be the same length.", .internal = TRUE)
cli::cli_abort(
"{.arg truth} ({length(truth)}) and \\
{.arg case_weights} ({length(case_weights)}) must be the same length.",
.internal = TRUE
)
}

truth <- unclass(truth)
Expand Down
Loading
Loading