diff --git a/DESCRIPTION b/DESCRIPTION index 241e883da..6450984fb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: tune Title: Tidy Tuning Tools -Version: 1.1.2.9000 +Version: 1.1.2.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), diff --git a/NAMESPACE b/NAMESPACE index 0baf6b108..4464ddd5c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -138,6 +138,7 @@ export(.config_key_from_metrics) export(.estimate_metrics) export(.get_extra_col_names) export(.get_fingerprint) +export(.get_tune_eval_times) export(.get_tune_metric_names) export(.get_tune_metrics) export(.get_tune_outcome_names) diff --git a/NEWS.md b/NEWS.md index 76375c8e0..5d34e80bd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,6 +18,8 @@ * Improves documentation related to the hyperparameters associated with extracted objects that are generated from submodels. See the "Extracting with submodels" section of `?collect_extracts` to learn more. +* An `eval_time` attribute was added to tune objects. There is also a `.get_tune_eval_times()` function. + * `augment()` methods to `tune_results`, `resample_results`, and `last_fit` objects now always returns tibbles (#759). # tune 1.1.2 diff --git a/R/compat-vctrs-helpers.R b/R/compat-vctrs-helpers.R index 99690f8ef..0ec2a5186 100644 --- a/R/compat-vctrs-helpers.R +++ b/R/compat-vctrs-helpers.R @@ -16,6 +16,7 @@ new_tune_results_from_template <- function(x, to) { x = x, parameters = attrs$parameters, metrics = attrs$metrics, + eval_time = attrs$eval_time, outcomes = attrs$outcomes, rset_info = attrs$rset_info ) @@ -105,6 +106,7 @@ new_resample_results_from_template <- function(x, to) { x = x, parameters = attrs$parameters, metrics = attrs$metrics, + eval_time = attrs$eval_time, outcomes = attrs$outcomes, rset_info = attrs$rset_info ) @@ -134,6 +136,7 @@ new_iteration_results_from_template <- function(x, to) { x = x, parameters = attrs$parameters, metrics = attrs$metrics, + eval_time = attrs$eval_time, outcomes = attrs$outcomes, rset_info = attrs$rset_info, workflow = attrs$workflow diff --git a/R/iteration_results.R b/R/iteration_results.R index cbd6dc1b1..ff12253e5 100644 --- a/R/iteration_results.R +++ b/R/iteration_results.R @@ -23,18 +23,27 @@ #' @rdname empty_ellipses #' @param parameters A `parameters` object. #' @param metrics A metric set. +#' @param eval_time A numeric vector of time points where dynamic event time +#' metrics should be computed (e.g. the time-dependent ROC curve, etc). #' @param outcomes A character vector of outcome names. #' @param rset_info Attributes from an `rset` object. #' @param workflow The workflow used to fit the iteration results. -new_iteration_results <- function(x, parameters, metrics, outcomes = character(0), - rset_info, workflow) { - new_tune_results( - x = x, - parameters = parameters, - metrics = metrics, - outcomes = outcomes, - rset_info = rset_info, - workflow = workflow, - class = "iteration_results" - ) -} +new_iteration_results <- + function(x, + parameters, + metrics, + eval_time, + outcomes = character(0), + rset_info, + workflow) { + new_tune_results( + x = x, + parameters = parameters, + metrics = metrics, + eval_time = eval_time, + outcomes = outcomes, + rset_info = rset_info, + workflow = workflow, + class = "iteration_results" + ) + } diff --git a/R/resample.R b/R/resample.R index cb9a4df9f..065d8e2c4 100644 --- a/R/resample.R +++ b/R/resample.R @@ -160,6 +160,7 @@ resample_workflow <- function(workflow, resamples, metrics, control, x = out, parameters = attributes$parameters, metrics = attributes$metrics, + eval_time = attributes$eval_time, outcomes = attributes$outcomes, rset_info = attributes$rset_info, workflow = attributes$workflow diff --git a/R/resample_results.R b/R/resample_results.R index 53e146007..4f91c36a8 100644 --- a/R/resample_results.R +++ b/R/resample_results.R @@ -14,14 +14,22 @@ # ------------------------------------------------------------------------------ -new_resample_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, workflow = NULL) { - new_tune_results( - x = x, - parameters = parameters, - metrics = metrics, - outcomes = outcomes, - rset_info = rset_info, - workflow = workflow, - class = "resample_results" - ) -} +new_resample_results <- + function(x, + parameters, + metrics, + eval_time, + outcomes = character(0), + rset_info, + workflow = NULL) { + new_tune_results( + x = x, + parameters = parameters, + metrics = metrics, + eval_time = eval_time, + outcomes = outcomes, + rset_info = rset_info, + workflow = workflow, + class = "resample_results" + ) + } diff --git a/R/tune_bayes.R b/R/tune_bayes.R index b774f2216..f6e5636f3 100644 --- a/R/tune_bayes.R +++ b/R/tune_bayes.R @@ -304,6 +304,7 @@ tune_bayes_workflow <- x = unsummarized, parameters = param_info, metrics = metrics, + eval_time = eval_time, outcomes = outcomes, rset_info = rset_info, workflow = NULL @@ -476,6 +477,7 @@ tune_bayes_workflow <- x = unsummarized, parameters = param_info, metrics = metrics, + eval_time = eval_time, outcomes = outcomes, rset_info = rset_info, workflow = workflow_output diff --git a/R/tune_grid.R b/R/tune_grid.R index 1b2b89288..a558245f2 100644 --- a/R/tune_grid.R +++ b/R/tune_grid.R @@ -379,6 +379,7 @@ tune_grid_workflow <- function(workflow, x = resamples, parameters = pset, metrics = metrics, + eval_time = eval_time, outcomes = outcomes, rset_info = rset_info, workflow = workflow diff --git a/R/tune_results.R b/R/tune_results.R index 16c1ec226..ca7f0396b 100644 --- a/R/tune_results.R +++ b/R/tune_results.R @@ -97,17 +97,26 @@ summarize_notes <- function(x) { # ------------------------------------------------------------------------------ -new_tune_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, ..., class = character()) { - new_bare_tibble( - x = x, - parameters = parameters, - metrics = metrics, - outcomes = outcomes, - rset_info = rset_info, - ..., - class = c(class, "tune_results") - ) -} +new_tune_results <- + function(x, + parameters, + metrics, + eval_time, + outcomes = character(0), + rset_info, + ..., + class = character()) { + new_bare_tibble( + x = x, + parameters = parameters, + metrics = metrics, + eval_time = eval_time, + outcomes = outcomes, + rset_info = rset_info, + ..., + class = c(class, "tune_results") + ) + } is_tune_results <- function(x) { inherits(x, "tune_results") diff --git a/R/utils.R b/R/utils.R index be4d180db..8822b8576 100644 --- a/R/utils.R +++ b/R/utils.R @@ -163,6 +163,20 @@ new_bare_tibble <- function(x, ..., class = character()) { res } + +#' @export +#' @rdname tune_accessor +.get_tune_eval_times <- function(x) { + x <- attributes(x) + if (any(names(x) == "eval_time")) { + res <- x$eval_time + } else { + res <- NULL + } + res +} + + #' @export #' @rdname tune_accessor .get_tune_outcome_names <- function(x) { diff --git a/man/empty_ellipses.Rd b/man/empty_ellipses.Rd index bf29678e5..88c53da9c 100644 --- a/man/empty_ellipses.Rd +++ b/man/empty_ellipses.Rd @@ -56,6 +56,7 @@ new_iteration_results( x, parameters, metrics, + eval_time, outcomes = character(0), rset_info, workflow @@ -98,6 +99,9 @@ is_workflow(x) \item{ctrl}{A \code{control_grid} object.} +\item{eval_time}{A numeric vector of time points where dynamic event time +metrics should be computed (e.g. the time-dependent ROC curve, etc).} + \item{cls}{A character vector of possible classes} \item{where}{A character string for the calling function.} diff --git a/man/tune_accessor.Rd b/man/tune_accessor.Rd index 5781f4256..432f95b83 100644 --- a/man/tune_accessor.Rd +++ b/man/tune_accessor.Rd @@ -6,6 +6,7 @@ \alias{.get_extra_col_names} \alias{.get_tune_metrics} \alias{.get_tune_metric_names} +\alias{.get_tune_eval_times} \alias{.get_tune_outcome_names} \alias{.get_tune_workflow} \alias{.get_fingerprint.tune_results} @@ -21,6 +22,8 @@ .get_tune_metric_names(x) +.get_tune_eval_times(x) + .get_tune_outcome_names(x) .get_tune_workflow(x) diff --git a/tests/testthat/test-bayes.R b/tests/testthat/test-bayes.R index a99627366..cb4f7ef4a 100644 --- a/tests/testthat/test-bayes.R +++ b/tests/testthat/test-bayes.R @@ -53,6 +53,7 @@ test_that("tune recipe only", { expect_equal(res_est$n, rep(10, iterT * 2)) expect_false(identical(num_comp, expr(tune()))) expect_true(res_workflow$trained) + expect_null(.get_tune_eval_times(res)) set.seed(1) expect_error( diff --git a/tests/testthat/test-grid.R b/tests/testthat/test-grid.R index 4302f8c6c..ddbbb70ec 100644 --- a/tests/testthat/test-grid.R +++ b/tests/testthat/test-grid.R @@ -27,6 +27,7 @@ test_that("tune recipe only", { expect_equal(res_est$n, rep(10, nrow(grid) * 2)) expect_false(identical(num_comp, expr(tune()))) expect_true(res_workflow$trained) + expect_null(.get_tune_eval_times(res)) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R index 59b2a7887..4ac976b96 100644 --- a/tests/testthat/test-last-fit.R +++ b/tests/testthat/test-last-fit.R @@ -25,7 +25,7 @@ test_that("formula method", { nrow(predict(res$.workflow[[1]], rsample::testing(split))), nrow(rsample::testing(split)) ) - + expect_null(.get_tune_eval_times(res)) }) diff --git a/tests/testthat/test-resample.R b/tests/testthat/test-resample.R index 9d96c82a2..f9cc9f02f 100644 --- a/tests/testthat/test-resample.R +++ b/tests/testthat/test-resample.R @@ -13,6 +13,8 @@ test_that("`fit_resamples()` returns a `resample_result` object", { expect_s3_class(result, "resample_results") expect_equal(result, .Last.tune.result) + + expect_null(.get_tune_eval_times(result)) }) test_that("can use `fit_resamples()` with a formula", {