diff --git a/R/class-sens.R b/R/class-sens.R index 8af67170..170d2620 100644 --- a/R/class-sens.R +++ b/R/class-sens.R @@ -49,7 +49,8 @@ #' #' @param case_weights The optional column identifier for case weights. #' This should be an unquoted column name that evaluates to a numeric column -#' in `data`. For `_vec()` functions, a numeric vector. +#' in `data`. For `_vec()` functions, a numeric vector, +#' [hardhat::importance_weights()], or [hardhat::frequency_weights()]. #' #' @param event_level A single string. Either `"first"` or `"second"` to specify #' which level of `truth` to consider as the "event". This argument is only diff --git a/R/num-rmse.R b/R/num-rmse.R index 6a247ed2..868384d9 100644 --- a/R/num-rmse.R +++ b/R/num-rmse.R @@ -27,7 +27,8 @@ #' #' @param case_weights The optional column identifier for case weights. This #' should be an unquoted column name that evaluates to a numeric column in -#' `data`. For `_vec()` functions, a numeric vector. +#' `data`. For `_vec()` functions, a numeric vector, +#' [hardhat::importance_weights()], or [hardhat::frequency_weights()]. #' #' @param ... Not currently used. #' diff --git a/R/prob-brier_class.R b/R/prob-brier_class.R index ffdb0fc2..41939979 100644 --- a/R/prob-brier_class.R +++ b/R/prob-brier_class.R @@ -152,5 +152,7 @@ brier_ind <- function(truth, estimate, case_weights = NULL) { brier_factor <- function(truth, estimate, case_weights = NULL) { inds <- hardhat::fct_encode_one_hot(truth) + case_weights <- vctrs::vec_cast(case_weights, to = double()) + brier_ind(inds, estimate, case_weights) } diff --git a/man/accuracy.Rd b/man/accuracy.Rd index bfab5f15..3d6ac318 100644 --- a/man/accuracy.Rd +++ b/man/accuracy.Rd @@ -35,7 +35,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/average_precision.Rd b/man/average_precision.Rd index f9dccf48..d9f40973 100644 --- a/man/average_precision.Rd +++ b/man/average_precision.Rd @@ -63,7 +63,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/bal_accuracy.Rd b/man/bal_accuracy.Rd index a6582431..e7fd8b96 100644 --- a/man/bal_accuracy.Rd +++ b/man/bal_accuracy.Rd @@ -58,7 +58,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/brier_class.Rd b/man/brier_class.Rd index 24bc3b08..56f21920 100644 --- a/man/brier_class.Rd +++ b/man/brier_class.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/brier_survival.Rd b/man/brier_survival.Rd index 34b4ae1e..fba33ef0 100644 --- a/man/brier_survival.Rd +++ b/man/brier_survival.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/brier_survival_integrated.Rd b/man/brier_survival_integrated.Rd index af2b8220..e6c6aa44 100644 --- a/man/brier_survival_integrated.Rd +++ b/man/brier_survival_integrated.Rd @@ -40,7 +40,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/ccc.Rd b/man/ccc.Rd index aef92e89..5306edb1 100644 --- a/man/ccc.Rd +++ b/man/ccc.Rd @@ -45,7 +45,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/classification_cost.Rd b/man/classification_cost.Rd index 6f79d7b3..67bfe87b 100644 --- a/man/classification_cost.Rd +++ b/man/classification_cost.Rd @@ -75,7 +75,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/concordance_survival.Rd b/man/concordance_survival.Rd index 967a2303..426b789a 100644 --- a/man/concordance_survival.Rd +++ b/man/concordance_survival.Rd @@ -48,7 +48,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/conf_mat.Rd b/man/conf_mat.Rd index a1d0c44e..490555cc 100644 --- a/man/conf_mat.Rd +++ b/man/conf_mat.Rd @@ -41,7 +41,8 @@ unquoted variable name. For \verb{_vec()} functions, a \code{factor} vector.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{x}{A \code{conf_mat} object.} } diff --git a/man/detection_prevalence.Rd b/man/detection_prevalence.Rd index eff29d49..d8d86301 100644 --- a/man/detection_prevalence.Rd +++ b/man/detection_prevalence.Rd @@ -58,7 +58,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/f_meas.Rd b/man/f_meas.Rd index 6858fe86..267f2cd8 100644 --- a/man/f_meas.Rd +++ b/man/f_meas.Rd @@ -65,7 +65,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/gain_capture.Rd b/man/gain_capture.Rd index 4c9c9993..455c1ca2 100644 --- a/man/gain_capture.Rd +++ b/man/gain_capture.Rd @@ -63,7 +63,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/gain_curve.Rd b/man/gain_curve.Rd index de31636f..d9f8e467 100644 --- a/man/gain_curve.Rd +++ b/man/gain_curve.Rd @@ -45,7 +45,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A tibble with class \code{gain_df} or \code{gain_grouped_df} having columns: diff --git a/man/huber_loss.Rd b/man/huber_loss.Rd index 0fe775d7..7ab42eb9 100644 --- a/man/huber_loss.Rd +++ b/man/huber_loss.Rd @@ -52,7 +52,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/huber_loss_pseudo.Rd b/man/huber_loss_pseudo.Rd index f62847a1..13ef94b3 100644 --- a/man/huber_loss_pseudo.Rd +++ b/man/huber_loss_pseudo.Rd @@ -52,7 +52,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/iic.Rd b/man/iic.Rd index 4879a00d..530bf8bd 100644 --- a/man/iic.Rd +++ b/man/iic.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/j_index.Rd b/man/j_index.Rd index 516d4078..958134d9 100644 --- a/man/j_index.Rd +++ b/man/j_index.Rd @@ -58,7 +58,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/kap.Rd b/man/kap.Rd index cbc9b954..fc70843d 100644 --- a/man/kap.Rd +++ b/man/kap.Rd @@ -60,7 +60,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/lift_curve.Rd b/man/lift_curve.Rd index 63d616c7..f056e45a 100644 --- a/man/lift_curve.Rd +++ b/man/lift_curve.Rd @@ -45,7 +45,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A tibble with class \code{lift_df} or \code{lift_grouped_df} having diff --git a/man/mae.Rd b/man/mae.Rd index 531831a5..1059a945 100644 --- a/man/mae.Rd +++ b/man/mae.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/mape.Rd b/man/mape.Rd index 9365d871..52f4400b 100644 --- a/man/mape.Rd +++ b/man/mape.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/mase.Rd b/man/mase.Rd index 09380de9..b6d756b4 100644 --- a/man/mase.Rd +++ b/man/mase.Rd @@ -62,7 +62,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/mcc.Rd b/man/mcc.Rd index 1948b72b..4834ace7 100644 --- a/man/mcc.Rd +++ b/man/mcc.Rd @@ -35,7 +35,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/mn_log_loss.Rd b/man/mn_log_loss.Rd index 823cd2e9..de2011fb 100644 --- a/man/mn_log_loss.Rd +++ b/man/mn_log_loss.Rd @@ -60,7 +60,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/mpe.Rd b/man/mpe.Rd index 24f28f57..a5f295fa 100644 --- a/man/mpe.Rd +++ b/man/mpe.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/msd.Rd b/man/msd.Rd index bcd00392..adb120d9 100644 --- a/man/msd.Rd +++ b/man/msd.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/npv.Rd b/man/npv.Rd index 1b0ccba9..1c5d2e8b 100644 --- a/man/npv.Rd +++ b/man/npv.Rd @@ -63,7 +63,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/poisson_log_loss.Rd b/man/poisson_log_loss.Rd index e4ea6183..5301310b 100644 --- a/man/poisson_log_loss.Rd +++ b/man/poisson_log_loss.Rd @@ -33,7 +33,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/ppv.Rd b/man/ppv.Rd index 961e3bbc..198b909d 100644 --- a/man/ppv.Rd +++ b/man/ppv.Rd @@ -63,7 +63,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/pr_auc.Rd b/man/pr_auc.Rd index f2ceca39..cea04223 100644 --- a/man/pr_auc.Rd +++ b/man/pr_auc.Rd @@ -63,7 +63,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/pr_curve.Rd b/man/pr_curve.Rd index fc5f2240..14704c45 100644 --- a/man/pr_curve.Rd +++ b/man/pr_curve.Rd @@ -45,7 +45,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A tibble with class \code{pr_df} or \code{pr_grouped_df} having diff --git a/man/precision.Rd b/man/precision.Rd index 2de3bf41..e7728608 100644 --- a/man/precision.Rd +++ b/man/precision.Rd @@ -58,7 +58,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/recall.Rd b/man/recall.Rd index 4c8e7de4..2c8c57bc 100644 --- a/man/recall.Rd +++ b/man/recall.Rd @@ -58,7 +58,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/rmse.Rd b/man/rmse.Rd index 5f848e83..d62fb43f 100644 --- a/man/rmse.Rd +++ b/man/rmse.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/roc_auc.Rd b/man/roc_auc.Rd index 1a37a4ed..2de2a576 100644 --- a/man/roc_auc.Rd +++ b/man/roc_auc.Rd @@ -68,7 +68,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{options}{\verb{[deprecated]} diff --git a/man/roc_auc_survival.Rd b/man/roc_auc_survival.Rd index 039fe2e1..f9310db4 100644 --- a/man/roc_auc_survival.Rd +++ b/man/roc_auc_survival.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{estimate}{If \code{truth} is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many diff --git a/man/roc_aunp.Rd b/man/roc_aunp.Rd index 8c42f4bf..3fc72a94 100644 --- a/man/roc_aunp.Rd +++ b/man/roc_aunp.Rd @@ -39,7 +39,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{options}{\verb{[deprecated]} diff --git a/man/roc_aunu.Rd b/man/roc_aunu.Rd index eabd4428..c9c54754 100644 --- a/man/roc_aunu.Rd +++ b/man/roc_aunu.Rd @@ -39,7 +39,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{options}{\verb{[deprecated]} diff --git a/man/roc_curve.Rd b/man/roc_curve.Rd index 9226a46c..4ff27054 100644 --- a/man/roc_curve.Rd +++ b/man/roc_curve.Rd @@ -46,7 +46,8 @@ used instead with a warning.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{options}{\verb{[deprecated]} diff --git a/man/roc_curve_survival.Rd b/man/roc_curve_survival.Rd index 5d63877c..cd8f0f02 100644 --- a/man/roc_curve_survival.Rd +++ b/man/roc_curve_survival.Rd @@ -31,7 +31,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A tibble with class \code{roc_survival_df}, \code{grouped_roc_survival_df} having diff --git a/man/rpd.Rd b/man/rpd.Rd index f2c2f2c4..251ce5aa 100644 --- a/man/rpd.Rd +++ b/man/rpd.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/rpiq.Rd b/man/rpiq.Rd index 98952ee0..987b67c1 100644 --- a/man/rpiq.Rd +++ b/man/rpiq.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/rsq.Rd b/man/rsq.Rd index 05f23c1d..1e88937b 100644 --- a/man/rsq.Rd +++ b/man/rsq.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/rsq_trad.Rd b/man/rsq_trad.Rd index 8e581ce7..2cc7c9cb 100644 --- a/man/rsq_trad.Rd +++ b/man/rsq_trad.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/sens.Rd b/man/sens.Rd index 70cc4023..4f93ddb8 100644 --- a/man/sens.Rd +++ b/man/sens.Rd @@ -84,7 +84,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/man/smape.Rd b/man/smape.Rd index 21f0ebb7..313bec13 100644 --- a/man/smape.Rd +++ b/man/smape.Rd @@ -34,7 +34,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column in -\code{data}. For \verb{_vec()} functions, a numeric vector.} +\code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} } \value{ A \code{tibble} with columns \code{.metric}, \code{.estimator}, diff --git a/man/spec.Rd b/man/spec.Rd index 3beb476f..a2f93b70 100644 --- a/man/spec.Rd +++ b/man/spec.Rd @@ -84,7 +84,8 @@ values should be stripped before the computation proceeds.} \item{case_weights}{The optional column identifier for case weights. This should be an unquoted column name that evaluates to a numeric column -in \code{data}. For \verb{_vec()} functions, a numeric vector.} +in \code{data}. For \verb{_vec()} functions, a numeric vector, +\code{\link[hardhat:importance_weights]{hardhat::importance_weights()}}, or \code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}}.} \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify which level of \code{truth} to consider as the "event". This argument is only diff --git a/tests/testthat/test-class-accuracy.R b/tests/testthat/test-class-accuracy.R index 90adc898..b9cc97b6 100644 --- a/tests/testthat/test-class-accuracy.R +++ b/tests/testthat/test-class-accuracy.R @@ -58,6 +58,21 @@ test_that("two class with case weights is correct", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + accuracy_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + accuracy_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-bal_accuracy.R b/tests/testthat/test-class-bal_accuracy.R index 63dbcbd6..22cd943e 100644 --- a/tests/testthat/test-class-bal_accuracy.R +++ b/tests/testthat/test-class-bal_accuracy.R @@ -17,6 +17,21 @@ test_that("Two class", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + bal_accuracy_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + bal_accuracy_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("`event_level = 'second'` should be identical to 'first'", { lst <- data_altman() pathology <- lst$pathology diff --git a/tests/testthat/test-class-detection_prevalence.R b/tests/testthat/test-class-detection_prevalence.R index 056c5034..6e1ac8eb 100644 --- a/tests/testthat/test-class-detection_prevalence.R +++ b/tests/testthat/test-class-detection_prevalence.R @@ -63,6 +63,21 @@ test_that("two class with case weights is correct", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + detection_prevalence_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + detection_prevalence_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-f_meas.R b/tests/testthat/test-class-f_meas.R index 98353c50..27468e0d 100644 --- a/tests/testthat/test-class-f_meas.R +++ b/tests/testthat/test-class-f_meas.R @@ -121,6 +121,21 @@ test_that("`NA` is still returned if there are some undefined recall values but expect_warning(f_meas_vec(truth, estimate, na_rm = FALSE), NA) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + f_meas_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + f_meas_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-j_index.R b/tests/testthat/test-class-j_index.R index 88473899..65a124b7 100644 --- a/tests/testthat/test-class-j_index.R +++ b/tests/testthat/test-class-j_index.R @@ -128,6 +128,21 @@ test_that("`NA` is still returned if there are some undefined sensitivity values expect_warning(j_index_vec(truth, estimate, na_rm = FALSE), NA) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + j_index_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + j_index_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-kap.R b/tests/testthat/test-class-kap.R index fee097a8..8d15ef73 100644 --- a/tests/testthat/test-class-kap.R +++ b/tests/testthat/test-class-kap.R @@ -27,6 +27,21 @@ test_that("kap errors with wrong `weighting`", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + kap_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + kap_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-mcc.R b/tests/testthat/test-class-mcc.R index ab657718..47758510 100644 --- a/tests/testthat/test-class-mcc.R +++ b/tests/testthat/test-class-mcc.R @@ -39,6 +39,21 @@ test_that("doesn't integer overflow (#108)", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + mcc_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + mcc_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-npv.R b/tests/testthat/test-class-npv.R index a890349b..a1bcf003 100644 --- a/tests/testthat/test-class-npv.R +++ b/tests/testthat/test-class-npv.R @@ -73,6 +73,21 @@ test_that("Three class", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + npv_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + npv_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-ppv.R b/tests/testthat/test-class-ppv.R index a4ea4e5a..8ff08d18 100644 --- a/tests/testthat/test-class-ppv.R +++ b/tests/testthat/test-class-ppv.R @@ -118,6 +118,21 @@ test_that("Multi class weighted - sklearn equivalent", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + ppv_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + ppv_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-precision.R b/tests/testthat/test-class-precision.R index 87fd9741..8aee89a1 100644 --- a/tests/testthat/test-class-precision.R +++ b/tests/testthat/test-class-precision.R @@ -76,6 +76,21 @@ test_that("`NA` is still returned if there are some undefined precision values b expect_warning(precision_vec(truth, estimate, na_rm = FALSE), NA) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + precision_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + precision_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-recall.R b/tests/testthat/test-class-recall.R index 5a3a2d7a..a9b0780a 100644 --- a/tests/testthat/test-class-recall.R +++ b/tests/testthat/test-class-recall.R @@ -64,6 +64,21 @@ test_that("`NA` is still returned if there are some undefined recall values but expect_warning(recall_vec(truth, estimate, na_rm = FALSE), NA) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + recall_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + recall_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-sens.R b/tests/testthat/test-class-sens.R index 5d997565..f8820ad7 100644 --- a/tests/testthat/test-class-sens.R +++ b/tests/testthat/test-class-sens.R @@ -194,6 +194,21 @@ test_that("`sensitivity()` has a metric name unique to it (#232)", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + sensitivity_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + sensitivity_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-class-spec.R b/tests/testthat/test-class-spec.R index f9a41200..ca01d7f2 100644 --- a/tests/testthat/test-class-spec.R +++ b/tests/testthat/test-class-spec.R @@ -134,6 +134,21 @@ test_that("`specificity()` has a metric name unique to it (#232)", { ) }) +test_that("works with hardhat case weights", { + lst <- data_altman() + df <- lst$pathology + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + specificity_vec(df$pathology, df$scan, case_weights = imp_wgt) + ) + + expect_no_error( + specificity_vec(df$pathology, df$scan, case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-num-ccc.R b/tests/testthat/test-num-ccc.R index 37916627..473c7305 100644 --- a/tests/testthat/test-num-ccc.R +++ b/tests/testthat/test-num-ccc.R @@ -66,3 +66,19 @@ test_that("can use hardhat case weights", { expect_identical(result1, expect1) expect_identical(result2, expect2) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + ccc_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + ccc_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-huber_loss.R b/tests/testthat/test-num-huber_loss.R index e9530ce1..46a69cc7 100644 --- a/tests/testthat/test-num-huber_loss.R +++ b/tests/testthat/test-num-huber_loss.R @@ -51,3 +51,19 @@ test_that("Weighted results are working", { 3.5 / 4 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + huber_loss_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + huber_loss_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-iic.R b/tests/testthat/test-num-iic.R index 506ea28c..9df4ffa7 100644 --- a/tests/testthat/test-num-iic.R +++ b/tests/testthat/test-num-iic.R @@ -35,3 +35,19 @@ test_that("yardstick correlation warnings are thrown", { cnd <- rlang::catch_cnd(iic_vec(c(1, 1), c(1, 2))) expect_s3_class(cnd, "yardstick_warning_correlation_undefined_constant_truth") }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + iic_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + iic_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-mae.R b/tests/testthat/test-num-mae.R index 8cc353ee..fb26312b 100644 --- a/tests/testthat/test-num-mae.R +++ b/tests/testthat/test-num-mae.R @@ -20,3 +20,19 @@ test_that("Weighted results are the same as scikit-learn", { read_pydata("py-mae")$case_weight ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + mae_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + mae_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-mape.R b/tests/testthat/test-num-mape.R index eec3e464..1059b355 100644 --- a/tests/testthat/test-num-mape.R +++ b/tests/testthat/test-num-mape.R @@ -39,3 +39,19 @@ test_that("Weighted results are the same as scikit-learn", { read_pydata("py-mape")$case_weight * 100 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + mape_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + mape_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-mase.R b/tests/testthat/test-num-mase.R index ea5b4437..999a0983 100644 --- a/tests/testthat/test-num-mase.R +++ b/tests/testthat/test-num-mase.R @@ -75,3 +75,19 @@ test_that("Weighted results are working", { 5 / 4 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + mape_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + mape_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-mpe.R b/tests/testthat/test-num-mpe.R index c876f0c5..eea2643a 100644 --- a/tests/testthat/test-num-mpe.R +++ b/tests/testthat/test-num-mpe.R @@ -44,3 +44,19 @@ test_that("Weighted results are working", { -3 / 4 * 100 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + mpe_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + mpe_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-msd.R b/tests/testthat/test-num-msd.R index 5e29cca8..1d1ad60d 100644 --- a/tests/testthat/test-num-msd.R +++ b/tests/testthat/test-num-msd.R @@ -36,3 +36,19 @@ test_that("weighted results are correct", { -4 / 3 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + msd_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + msd_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-poisson_log_loss.R b/tests/testthat/test-num-poisson_log_loss.R index 52ec25ea..bcd3a446 100644 --- a/tests/testthat/test-num-poisson_log_loss.R +++ b/tests/testthat/test-num-poisson_log_loss.R @@ -28,3 +28,21 @@ test_that("weighted results are working", { yardstick_mean(-stats::dpois(count_results$count, count_results$pred, log = TRUE), case_weights = count_results$weights) ) }) + +test_that("works with hardhat case weights", { + count_results <- data_counts()$basic + count_results$weights <- c(1, 2, 1, 1, 2, 1) + + df <- count_results + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + poisson_log_loss_vec(df$count, df$pred, case_weights = imp_wgt) + ) + + expect_no_error( + poisson_log_loss_vec(df$count, df$pred, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-pseudo_huber_loss.R b/tests/testthat/test-num-pseudo_huber_loss.R index 0e4f1c42..cc8e5c21 100644 --- a/tests/testthat/test-num-pseudo_huber_loss.R +++ b/tests/testthat/test-num-pseudo_huber_loss.R @@ -39,3 +39,19 @@ test_that("Weighted results are working", { yardstick_mean(sqrt(1 + (truth - estimate)^2) - 1, case_weights = weights) ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + huber_loss_pseudo_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + huber_loss_pseudo_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-rmse.R b/tests/testthat/test-num-rmse.R index cce8a6e1..507f0d4b 100644 --- a/tests/testthat/test-num-rmse.R +++ b/tests/testthat/test-num-rmse.R @@ -30,3 +30,19 @@ test_that("Integer columns are allowed (#44)", { sqrt(mean((ex_dat$obs - ex_dat$pred)^2)) ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + rmse_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + rmse_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-rpd.R b/tests/testthat/test-num-rpd.R index 947bc382..7df4dd91 100644 --- a/tests/testthat/test-num-rpd.R +++ b/tests/testthat/test-num-rpd.R @@ -24,3 +24,19 @@ test_that("case weights are applied", { } ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + rpd_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + rpd_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-rpiq.R b/tests/testthat/test-num-rpiq.R index 4a80212c..3acf6241 100644 --- a/tests/testthat/test-num-rpiq.R +++ b/tests/testthat/test-num-rpiq.R @@ -21,3 +21,21 @@ test_that("case weights are applied", { 3.401406885440771965534 ) }) + +test_that("works with hardhat case weights", { + count_results <- data_counts()$basic + count_results$weights <- c(1, 2, 1, 1, 2, 1) + + df <- count_results + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + rpiq_vec(df$count, df$pred, case_weights = imp_wgt) + ) + + expect_no_error( + rpiq_vec(df$count, df$pred, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-rsq.R b/tests/testthat/test-num-rsq.R index a994a948..781af5f2 100644 --- a/tests/testthat/test-num-rsq.R +++ b/tests/testthat/test-num-rsq.R @@ -65,3 +65,19 @@ test_that("yardstick correlation warnings are thrown", { }) expect_identical(out, NA_real_) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + rsq_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + rsq_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-rsq_trad.R b/tests/testthat/test-num-rsq_trad.R index d9dd6c7c..b26c5ae0 100644 --- a/tests/testthat/test-num-rsq_trad.R +++ b/tests/testthat/test-num-rsq_trad.R @@ -44,3 +44,19 @@ test_that("Weighted results are the same as scikit-learn", { read_pydata("py-rsq-trad")$case_weight ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + rsq_trad_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + rsq_trad_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-num-smape.R b/tests/testthat/test-num-smape.R index 8a0e19ca..4a6da166 100644 --- a/tests/testthat/test-num-smape.R +++ b/tests/testthat/test-num-smape.R @@ -22,3 +22,19 @@ test_that("Weighted results are working", { 50 ) }) + +test_that("works with hardhat case weights", { + solubility_test$weights <- floor(read_weights_solubility_test()) + df <- solubility_test + + imp_wgt <- hardhat::importance_weights(df$weights) + freq_wgt <- hardhat::frequency_weights(df$weights) + + expect_no_error( + smape_vec(df$solubility, df$prediction, case_weights = imp_wgt) + ) + + expect_no_error( + smape_vec(df$solubility, df$prediction, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-prob-average_precision.R b/tests/testthat/test-prob-average_precision.R index 015c3ab2..56ffb17e 100644 --- a/tests/testthat/test-prob-average_precision.R +++ b/tests/testthat/test-prob-average_precision.R @@ -92,6 +92,21 @@ test_that("Multiclass weighted average precision matches sklearn", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + average_precision_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + average_precision_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-brier_class.R b/tests/testthat/test-prob-brier_class.R index 0bd90b13..a9c472de 100644 --- a/tests/testthat/test-prob-brier_class.R +++ b/tests/testthat/test-prob-brier_class.R @@ -67,6 +67,21 @@ test_that("basic results", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + brier_class_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + brier_class_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-classification_cost.R b/tests/testthat/test-prob-classification_cost.R index fdf86ad0..ca3c1dea 100644 --- a/tests/testthat/test-prob-classification_cost.R +++ b/tests/testthat/test-prob-classification_cost.R @@ -347,6 +347,21 @@ test_that("multiclass - uses case weights", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + classification_cost_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + classification_cost_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-gain_capture.R b/tests/testthat/test-prob-gain_capture.R index e3d0c0f6..373c110a 100644 --- a/tests/testthat/test-prob-gain_capture.R +++ b/tests/testthat/test-prob-gain_capture.R @@ -133,6 +133,21 @@ test_that("multiclass macro / macro_weighted - case weights are applied correctl ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + gain_capture_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + gain_capture_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-mn_log_loss.R b/tests/testthat/test-prob-mn_log_loss.R index cec48242..6725ff60 100644 --- a/tests/testthat/test-prob-mn_log_loss.R +++ b/tests/testthat/test-prob-mn_log_loss.R @@ -88,6 +88,21 @@ test_that("mn_log_loss() applies the min/max rule when a 'non-event' has probabi ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + mn_log_loss_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + mn_log_loss_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-pr_auc.R b/tests/testthat/test-prob-pr_auc.R index 6360a351..9ebf407a 100644 --- a/tests/testthat/test-prob-pr_auc.R +++ b/tests/testthat/test-prob-pr_auc.R @@ -75,6 +75,21 @@ test_that("grouped multiclass (one-vs-all) weighted example matches expanded equ ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + pr_auc_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + pr_auc_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-roc_auc.R b/tests/testthat/test-prob-roc_auc.R index d9840514..e9d398a0 100644 --- a/tests/testthat/test-prob-roc_auc.R +++ b/tests/testthat/test-prob-roc_auc.R @@ -382,6 +382,21 @@ test_that("roc_auc() - `options` is deprecated", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + roc_auc_vec(df$truth, df$Class1, case_weights = imp_wgt) + ) + + expect_no_error( + roc_auc_vec(df$truth, df$Class1, case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-roc_aunp.R b/tests/testthat/test-prob-roc_aunp.R index f222cf8e..d396a775 100644 --- a/tests/testthat/test-prob-roc_aunp.R +++ b/tests/testthat/test-prob-roc_aunp.R @@ -67,6 +67,21 @@ test_that("roc_aunp() - `options` is deprecated", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + roc_aunp_vec(df$truth, as.matrix(df[c("Class1", "Class2")]), case_weights = imp_wgt) + ) + + expect_no_error( + roc_aunp_vec(df$truth, as.matrix(df[c("Class1", "Class2")]), case_weights = freq_wgt) + ) +}) + test_that("work with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-prob-roc_aunu.R b/tests/testthat/test-prob-roc_aunu.R index 9c43dc90..6e0bd068 100644 --- a/tests/testthat/test-prob-roc_aunu.R +++ b/tests/testthat/test-prob-roc_aunu.R @@ -67,6 +67,21 @@ test_that("roc_aunu() - `options` is deprecated", { ) }) +test_that("works with hardhat case weights", { + df <- two_class_example + + imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) + freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) + + expect_no_error( + roc_aunu_vec(df$truth, as.matrix(df[c("Class1", "Class2")]), case_weights = imp_wgt) + ) + + expect_no_error( + roc_aunu_vec(df$truth, as.matrix(df[c("Class1", "Class2")]), case_weights = freq_wgt) + ) +}) + test_that("errors with class_pred input", { skip_if_not_installed("probably") diff --git a/tests/testthat/test-surv-brier_survival.R b/tests/testthat/test-surv-brier_survival.R index 14e1e699..25b641cf 100644 --- a/tests/testthat/test-surv-brier_survival.R +++ b/tests/testthat/test-surv-brier_survival.R @@ -36,6 +36,24 @@ test_that("case weights", { ) }) +test_that("works with hardhat case weights", { + lung_surv <- data_lung_surv() + lung_surv$case_wts <- rep(2, nrow(lung_surv)) + + df <- lung_surv + + df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) + df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) + + expect_no_error( + brier_survival(df, truth = surv_obj, .pred, case_weights = imp_wgt) + ) + + expect_no_error( + brier_survival(df, truth = surv_obj, .pred, case_weights = freq_wgt) + ) +}) + # riskRegression compare ------------------------------------------------------- test_that("riskRegression equivalent", { diff --git a/tests/testthat/test-surv-brier_survival_integrated.R b/tests/testthat/test-surv-brier_survival_integrated.R index 4d7be932..564ae1ac 100644 --- a/tests/testthat/test-surv-brier_survival_integrated.R +++ b/tests/testthat/test-surv-brier_survival_integrated.R @@ -63,3 +63,22 @@ test_that("case weights", { brier_integrated_res$.estimate ) }) + + +test_that("works with hardhat case weights", { + lung_surv <- data_lung_surv() + lung_surv$case_wts <- rep(2, nrow(lung_surv)) + + df <- lung_surv + + df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) + df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) + + expect_no_error( + brier_survival_integrated(df, truth = surv_obj, .pred, case_weights = imp_wgt) + ) + + expect_no_error( + brier_survival_integrated(df, truth = surv_obj, .pred, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-surv-concordance_survival.R b/tests/testthat/test-surv-concordance_survival.R index 3020e852..3011e6eb 100644 --- a/tests/testthat/test-surv-concordance_survival.R +++ b/tests/testthat/test-surv-concordance_survival.R @@ -59,3 +59,21 @@ test_that("works with infinite time predictions", { expect_true(!identical(res, exp_res)) }) + +test_that("works with hardhat case weights", { + lung_surv <- data_lung_surv() + lung_surv$case_wts <- rep(2, nrow(lung_surv)) + + df <- lung_surv + + df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) + df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) + + expect_no_error( + concordance_survival(df, truth = surv_obj, .pred_time, case_weights = imp_wgt) + ) + + expect_no_error( + concordance_survival(df, truth = surv_obj, .pred_time, case_weights = freq_wgt) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-surv-roc_auc_survival.R b/tests/testthat/test-surv-roc_auc_survival.R index afb6454a..e082b6a9 100644 --- a/tests/testthat/test-surv-roc_auc_survival.R +++ b/tests/testthat/test-surv-roc_auc_survival.R @@ -42,6 +42,24 @@ test_that("case weights are applied", { expect_identical(subset_res, wts_res) }) +test_that("works with hardhat case weights", { + lung_surv <- data_lung_surv() + lung_surv$case_wts <- rep(2, nrow(lung_surv)) + + df <- lung_surv + + df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) + df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) + + expect_no_error( + roc_auc_survival(df, truth = surv_obj, .pred, case_weights = imp_wgt) + ) + + expect_no_error( + roc_auc_survival(df, truth = surv_obj, .pred, case_weights = freq_wgt) + ) +}) + # self checking ---------------------------------------------------------------- test_that("snapshot equivalent", {