Skip to content

Commit

Permalink
adds eval_time as an attribute (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored Nov 29, 2023
1 parent 3509577 commit 7155350
Show file tree
Hide file tree
Showing 17 changed files with 97 additions and 36 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9000
Version: 1.1.2.9001
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ export(.config_key_from_metrics)
export(.estimate_metrics)
export(.get_extra_col_names)
export(.get_fingerprint)
export(.get_tune_eval_times)
export(.get_tune_metric_names)
export(.get_tune_metrics)
export(.get_tune_outcome_names)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

* Improves documentation related to the hyperparameters associated with extracted objects that are generated from submodels. See the "Extracting with submodels" section of `?collect_extracts` to learn more.

* An `eval_time` attribute was added to tune objects. There is also a `.get_tune_eval_times()` function.

* `augment()` methods to `tune_results`, `resample_results`, and `last_fit` objects now always returns tibbles (#759).

# tune 1.1.2
Expand Down
3 changes: 3 additions & 0 deletions R/compat-vctrs-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ new_tune_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info
)
Expand Down Expand Up @@ -105,6 +106,7 @@ new_resample_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info
)
Expand Down Expand Up @@ -134,6 +136,7 @@ new_iteration_results_from_template <- function(x, to) {
x = x,
parameters = attrs$parameters,
metrics = attrs$metrics,
eval_time = attrs$eval_time,
outcomes = attrs$outcomes,
rset_info = attrs$rset_info,
workflow = attrs$workflow
Expand Down
33 changes: 21 additions & 12 deletions R/iteration_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,27 @@
#' @rdname empty_ellipses
#' @param parameters A `parameters` object.
#' @param metrics A metric set.
#' @param eval_time A numeric vector of time points where dynamic event time
#' metrics should be computed (e.g. the time-dependent ROC curve, etc).
#' @param outcomes A character vector of outcome names.
#' @param rset_info Attributes from an `rset` object.
#' @param workflow The workflow used to fit the iteration results.
new_iteration_results <- function(x, parameters, metrics, outcomes = character(0),
rset_info, workflow) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "iteration_results"
)
}
new_iteration_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
workflow) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "iteration_results"
)
}
1 change: 1 addition & 0 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ resample_workflow <- function(workflow, resamples, metrics, control,
x = out,
parameters = attributes$parameters,
metrics = attributes$metrics,
eval_time = attributes$eval_time,
outcomes = attributes$outcomes,
rset_info = attributes$rset_info,
workflow = attributes$workflow
Expand Down
30 changes: 19 additions & 11 deletions R/resample_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@

# ------------------------------------------------------------------------------

new_resample_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, workflow = NULL) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "resample_results"
)
}
new_resample_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
workflow = NULL) {
new_tune_results(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow,
class = "resample_results"
)
}
2 changes: 2 additions & 0 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ tune_bayes_workflow <-
x = unsummarized,
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = NULL
Expand Down Expand Up @@ -476,6 +477,7 @@ tune_bayes_workflow <-
x = unsummarized,
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow_output
Expand Down
1 change: 1 addition & 0 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ tune_grid_workflow <- function(workflow,
x = resamples,
parameters = pset,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow
Expand Down
31 changes: 20 additions & 11 deletions R/tune_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,26 @@ summarize_notes <- function(x) {

# ------------------------------------------------------------------------------

new_tune_results <- function(x, parameters, metrics, outcomes = character(0), rset_info, ..., class = character()) {
new_bare_tibble(
x = x,
parameters = parameters,
metrics = metrics,
outcomes = outcomes,
rset_info = rset_info,
...,
class = c(class, "tune_results")
)
}
new_tune_results <-
function(x,
parameters,
metrics,
eval_time,
outcomes = character(0),
rset_info,
...,
class = character()) {
new_bare_tibble(
x = x,
parameters = parameters,
metrics = metrics,
eval_time = eval_time,
outcomes = outcomes,
rset_info = rset_info,
...,
class = c(class, "tune_results")
)
}

is_tune_results <- function(x) {
inherits(x, "tune_results")
Expand Down
14 changes: 14 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ new_bare_tibble <- function(x, ..., class = character()) {
res
}


#' @export
#' @rdname tune_accessor
.get_tune_eval_times <- function(x) {
x <- attributes(x)
if (any(names(x) == "eval_time")) {
res <- x$eval_time
} else {
res <- NULL
}
res
}


#' @export
#' @rdname tune_accessor
.get_tune_outcome_names <- function(x) {
Expand Down
4 changes: 4 additions & 0 deletions man/empty_ellipses.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/tune_accessor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/testthat/test-bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ test_that("tune recipe only", {
expect_equal(res_est$n, rep(10, iterT * 2))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
expect_null(.get_tune_eval_times(res))

set.seed(1)
expect_error(
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ test_that("tune recipe only", {
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
expect_null(.get_tune_eval_times(res))
})

# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that("formula method", {
nrow(predict(res$.workflow[[1]], rsample::testing(split))),
nrow(rsample::testing(split))
)

expect_null(.get_tune_eval_times(res))

})

Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test-resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ test_that("`fit_resamples()` returns a `resample_result` object", {
expect_s3_class(result, "resample_results")

expect_equal(result, .Last.tune.result)

expect_null(.get_tune_eval_times(result))
})

test_that("can use `fit_resamples()` with a formula", {
Expand Down

0 comments on commit 7155350

Please sign in to comment.