diff --git a/R/AutoTuner.R b/R/AutoTuner.R index 3f7da6f2..33d3771f 100644 --- a/R/AutoTuner.R +++ b/R/AutoTuner.R @@ -159,6 +159,7 @@ AutoTuner = R6Class("AutoTuner", ia$check_values = assert_flag(check_values) ia$callbacks = assert_callbacks(as_callbacks(callbacks)) + if (!is.null(rush)) ia$rush = assert_class(rush, "Rush") self$instance_args = ia super$initialize( diff --git a/R/ObjectiveTuning.R b/R/ObjectiveTuning.R index 0d623d0c..ce47d7ac 100644 --- a/R/ObjectiveTuning.R +++ b/R/ObjectiveTuning.R @@ -60,8 +60,8 @@ ObjectiveTuning = R6Class("ObjectiveTuning", self$measures = assert_measures(as_measures(measures), task = self$task, learner = self$learner) self$store_models = assert_flag(store_models) self$store_benchmark_result = assert_flag(store_benchmark_result) || self$store_models - self$callbacks = assert_callbacks(as_callbacks(callbacks)) + self$default_values = self$learner$param_set$values super$initialize( diff --git a/R/auto_tuner.R b/R/auto_tuner.R index 910913cf..8bcf0c03 100644 --- a/R/auto_tuner.R +++ b/R/auto_tuner.R @@ -21,6 +21,7 @@ #' @template param_store_models #' @template param_check_values #' @template param_callbacks +#' @template param_rush #' #' @export #' @examples @@ -45,7 +46,8 @@ auto_tuner = function( store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, - callbacks = NULL + callbacks = NULL, + rush = NULL ) { terminator = terminator %??% terminator_selection(term_evals, term_time) @@ -60,5 +62,6 @@ auto_tuner = function( store_benchmark_result = store_benchmark_result, store_models = store_models, check_values = check_values, - callbacks = callbacks) + callbacks = callbacks, + rush = rush) } diff --git a/man/AutoTuner.Rd b/man/AutoTuner.Rd index 6cc2d9e7..dde825d2 100644 --- a/man/AutoTuner.Rd +++ b/man/AutoTuner.Rd @@ -176,7 +176,6 @@ Hash (unique identifier) for this partial object, excluding some components whic \if{html}{\out{
Inherited methods