Skip to content

Commit

Permalink
fix: add rush parameter to auto tuner
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed May 14, 2024
1 parent 51d592f commit 31baed8
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 5 deletions.
1 change: 1 addition & 0 deletions R/AutoTuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ AutoTuner = R6Class("AutoTuner",

ia$check_values = assert_flag(check_values)
ia$callbacks = assert_callbacks(as_callbacks(callbacks))
if (!is.null(rush)) ia$rush = assert_class(rush, "Rush")
self$instance_args = ia

super$initialize(
Expand Down
2 changes: 1 addition & 1 deletion R/ObjectiveTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ ObjectiveTuning = R6Class("ObjectiveTuning",
self$measures = assert_measures(as_measures(measures), task = self$task, learner = self$learner)
self$store_models = assert_flag(store_models)
self$store_benchmark_result = assert_flag(store_benchmark_result) || self$store_models

self$callbacks = assert_callbacks(as_callbacks(callbacks))

self$default_values = self$learner$param_set$values

super$initialize(
Expand Down
7 changes: 5 additions & 2 deletions R/auto_tuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_rush
#'
#' @export
#' @examples
Expand All @@ -45,7 +46,8 @@ auto_tuner = function(
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = NULL
callbacks = NULL,
rush = NULL
) {
terminator = terminator %??% terminator_selection(term_evals, term_time)

Expand All @@ -60,5 +62,6 @@ auto_tuner = function(
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
callbacks = callbacks,
rush = rush)
}
1 change: 0 additions & 1 deletion man/AutoTuner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/auto_tuner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test_auto_tuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,46 @@ test_that("auto_tuner function works", {
expect_class(at, "AutoTuner")
expect_class(at$instance_args$terminator, "TerminatorCombo")
})

test_that("async auto tuner works", {
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)

at = auto_tuner(
tuner = tnr("async_random_search"),
learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1)),
resampling = rsmp("cv", folds = 3),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 3)
)

expect_class(at, "AutoTuner")
at$train(tsk("pima"))

expect_class(at$tuning_instance, "TuningInstanceAsyncSingleCrit")
})

test_that("async auto tuner works with rush controller", {
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)
rush = rsh(network_id = "tuning_network")

at = auto_tuner(
tuner = tnr("async_random_search"),
learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1)),
resampling = rsmp("cv", folds = 3),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 3),
rush = rush
)

expect_class(at, "AutoTuner")
expect_class(at$instance_args$rush, "Rush")
at$train(tsk("pima"))

expect_class(at$tuning_instance, "TuningInstanceAsyncSingleCrit")
})

0 comments on commit 31baed8

Please sign in to comment.