diff --git a/DESCRIPTION b/DESCRIPTION index 6c79cb3a..621fe625 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -61,7 +61,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 VignetteBuilder: knitr Collate: 'Condition.R' diff --git a/NAMESPACE b/NAMESPACE index b3eefa14..71371801 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -49,6 +49,7 @@ S3method(format,Condition) S3method(print,Condition) S3method(print,Domain) S3method(print,FullTuneToken) +S3method(print,InternalTuneToken) S3method(print,ObjectTuneToken) S3method(print,RangeTuneToken) S3method(rd_info,ParamSet) diff --git a/NEWS.md b/NEWS.md index 79a119c3..6ce734e4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,5 @@ -# paradox 0.12.0 +# paradox 1.0.0 + * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: * `ParamSet` now supports `extra_trafo` natively; it behaves like `.extra_trafo` of the `ps()` call. * `ParamSet` has `$constraint` @@ -7,6 +8,7 @@ * `Condition` objects are now S3 objects and can be constructed with `CondEqual()` and `CondAnyOf()`, instead of `CondXyz$new()`. (It is recommended to use the `Domain` interface for conditions, which has not changed) * `ParamSet` has new fields `$is_logscale`, `$has_trafo_param` (per-param), and `$has_trafo_param` (scalar for the whole set). * Added a vignette which was previously a chapter in the `mlr3book` +* feat: added support for `InternalTuneToken`s # paradox 0.11.1 diff --git a/R/Domain.R b/R/Domain.R index 56b98a55..b7dee1e8 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -63,6 +63,17 @@ #' @param init (`any`)\cr #' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this #' value upon construction. 
+#' @param aggr (`function`)\cr +#' Default aggregation function for a parameter. Can only be given for parameters tagged with `"internal_tuning"`. +#' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value. +#' @param in_tune_fn (`function(domain, param_vals)`)\cr +#' Function that converts a `Domain` object into a parameter value. +#' Can only be given for parameters tagged with `"internal_tuning"`. +#' This function should also assert that the parameters required to enable internal tuning for the given `domain` are +#' set in `param_vals` (such as `early_stopping_rounds` for `XGBoost`). +#' @param disable_in_tune (named `list()`)\cr +#' The parameter values that need to be set in the `ParamSet` to disable the internal tuning for the parameter. +#' For `XGBoost` this would e.g. be `list(early_stopping_rounds = NULL)`. #' #' @return A `Domain` object. #' @@ -117,6 +128,30 @@ #' # ... but get transformed to integers. #' print(grid$transpose()) #' +#' +#' # internal tuning +#' +#' param_set = ps( +#' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), +#' in_tune_fn = function(domain, param_vals) { +#' stopifnot(domain$lower <= 1) +#' stopifnot(param_vals$early_stopping == TRUE) +#' domain$upper +#' }, +#' disable_in_tune = list(early_stopping = FALSE)), +#' early_stopping = p_lgl() +#' ) +#' param_set$set_values( +#' iters = to_tune(upper = 100, internal = TRUE), +#' early_stopping = TRUE +#' ) +#' param_set$convert_internal_search_space(param_set$search_space()) +#' param_set$aggr_internal_tuned_values( +#' list(iters = list(1, 2, 3)) +#' ) +#' +#' param_set$disable_internal_tuning("iters") +#' param_set$values$early_stopping #' @family ParamSet construction helpers #' @name Domain NULL @@ -136,6 +171,21 @@ Domain = function(cls, grouping, storage_type = "list", init) { + if ("internal_tuning" %in% tags) { + assert_true(!is.null(cargo$aggr), .var.name = "aggregation 
function exists") + } + assert_list(cargo$disable_in_tune, null.ok = TRUE, names = "unique") + assert_function(cargo$aggr, null.ok = TRUE) + assert_function(cargo$in_tune_fn, null.ok = TRUE) + if ((!is.null(cargo$in_tune_fn) || !is.null(cargo$disable_in_tune)) && "internal_tuning" %nin% tags) { + # we cannot check the reverse, as parameters in the search space can be tagged with 'internal_tuning' + # and not provide in_tune_fn or disable_in_tune + stopf("Arguments in_tune_fn and disable_in_tune require the tag 'internal_tuning' to be present.") + } + if ((is.null(cargo$in_tune_fn) + is.null(cargo$disable_in_tune)) == 1) { + stopf("Arguments in_tune_fn and disable_tune_fn must both be present") + } + assert_string(cls) assert_string(grouping) assert_number(lower, na.ok = TRUE) @@ -227,7 +277,6 @@ print.Domain = function(x, ...) { if (!is.null(repr)) { print(repr) } else { - plural_rows = classes = class(x) if ("Domain" %in% classes) { domainidx = which("Domain" == classes)[[1]] diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 73000591..6b355bf5 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,6 +1,6 @@ #' @rdname Domain #' @export -p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) { +p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_number(tolerance, lower = 0) assert_number(lower) assert_number(upper) @@ -17,8 +17,13 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ real_upper = upper } + cargo = list() + if (logscale) cargo$logscale = TRUE + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune Domain(cls = "ParamDbl", grouping = 
"ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale") + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamFct.R b/R/ParamFct.R index cad930db..3adc69c7 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -1,6 +1,7 @@ #' @rdname Domain #' @export -p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) { +p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { + assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) if (!is.character(levels)) { @@ -21,8 +22,14 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact } # group p_fct by levels, so the group can be checked in a vectorized fashion. # We escape '"' and '\' to '\"' and '\\', respectively. 
+ cargo = list() + cargo$disable_in_tune = disable_in_tune + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") - Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init) + Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, + default = default, tags = tags, trafo = trafo, storage_type = "character", + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamInt.R b/R/ParamInt.R index aeed7f1a..c98a5ea3 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -1,7 +1,7 @@ #' @rdname Domain #' @export -p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) { +p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_number(tolerance, lower = 0, upper = 0.5) # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) @@ -23,9 +23,15 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ real_upper = upper } + cargo = list() + if (logscale) cargo$logscale = TRUE + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune + Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, 
storage_type = storage_type, - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale") + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 123fe73b..016d4a7d 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,8 +1,12 @@ #' @rdname Domain #' @export -p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) { +p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { + cargo = list() + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, - tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init) + tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamSet.R b/R/ParamSet.R index fb3fbcb4..a8ef87f0 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -93,7 +93,7 @@ ParamSet = R6Class("ParamSet", if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements private$.params = paramtbl # self$add_dep needs this - for (row in seq_len(nrow(paramtbl))) { + for (row in seq_len(nrow(paramtbl))) { for (req in requirements[[row]]) { invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies, .args = req) @@ -154,7 +154,8 @@ ParamSet = R6Class("ParamSet", #' @param any_tags (`character()`). See `$ids()`. #' @param type (`character(1)`)\cr #' Return values `"with_token"` (i.e. all values), - # `"without_token"` (all values that are not [`TuneToken`] objects) or `"only_token"` (only [`TuneToken`] objects)? 
+ # `"without_token"` (all values that are not [`TuneToken`] objects), `"only_token"` (only [`TuneToken`] objects) + # or `"with_internal"` (only `InternalTuneToken` objects)? #' @param check_required (`logical(1)`)\cr #' Check if all required parameters are set? #' @param remove_dependencies (`logical(1)`)\cr @@ -162,7 +163,7 @@ ParamSet = R6Class("ParamSet", #' @return Named `list()`. get_values = function(class = NULL, tags = NULL, any_tags = NULL, type = "with_token", check_required = TRUE, remove_dependencies = TRUE) { - assert_choice(type, c("with_token", "without_token", "only_token")) + assert_choice(type, c("with_token", "without_token", "only_token", "with_internal")) assert_flag(check_required) @@ -173,6 +174,8 @@ ParamSet = R6Class("ParamSet", values = discard(values, is, "TuneToken") } else if (type == "only_token") { values = keep(values, is, "TuneToken") + } else if (type == "with_internal") { + values = keep(values, is, "InternalTuneToken") } if (check_required) { @@ -255,6 +258,62 @@ ParamSet = R6Class("ParamSet", x }, + #' @description + #' + #' Aggregate parameter values according to their aggregation rules. + #' + #' @param x (named `list()` of `list()`s)\cr + #' The value(s) to be aggregated. Names are parameter ids. + #' The aggregation function is selected based on the parameter. 
+ #' + #' @return (named `list()`) + aggr_internal_tuned_values = function(x) { + assert_list(x, types = "list") + aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))] + assert_subset(names(x), aggrs$id) + if (!length(x)) { + return(named_list()) + } + imap(x, function(value, .id) { + if (!length(value)) { + stopf("Trying to aggregate values of parameters '%s', but there are no values", .id) + } + aggr = aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) + }) + }, + + #' @description + #' + #' Set the parameter values so that internal tuning for the selected parameters is disabled. + #' + #' @param ids (`character()`)\cr + #' The ids of the parameters for which to disable internal tuning. + #' @return `Self` + disable_internal_tuning = function(ids) { + assert_subset(ids, self$ids(tags = "internal_tuning")) + pvs = Reduce(c, map(private$.params[ids, "cargo", on = "id"][[1]], "disable_in_tune")) %??% named_list() + self$set_values(.values = pvs) + }, + + #' @description + #' Convert all parameters from the search space to parameter values using the transformation given by + #' `in_tune_fn`. + #' @param search_space ([`ParamSet`])\cr + #' The internal search space. + #' @return (named `list()`) + convert_internal_search_space = function(search_space) { + assert_class(search_space, "ParamSet") + param_vals = self$values + + imap(search_space$domains, function(token, .id) { + converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn + if (!is.function(converter)) { + stopf("No converter exists for parameter '%s'", .id) + } + converter(token, param_vals) + }) + }, + #' @description #' \pkg{checkmate}-like test-function. Takes a named list. #' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise. 
@@ -323,6 +382,14 @@ ParamSet = R6Class("ParamSet", if (!isTRUE(tunecheck)) return(tunecheck) } + xs_internaltune = keep(xs, is, "InternalTuneToken") + walk(names(xs_internaltune), function(pid) { + if ("internal_tuning" %nin% self$tags[[pid]]) { + stopf("Trying to assign InternalTuneToken to parameter '%s' which is not tagged with 'internal_tuning'.", pid) + } + }) + + # check each parameter group's feasibility xs_nontune = discard(xs, inherits, "TuneToken") @@ -822,7 +889,7 @@ ParamSet = R6Class("ParamSet", #' Note that this only refers to the `logscale` flag set during construction, e.g. `p_dbl(logscale = TRUE)`. #' If the parameter was set to logscale manually, e.g. through `p_dbl(trafo = exp)`, #' this `is_logscale` will be `FALSE`. - is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & cargo == "logscale", id)), + is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & map_lgl(cargo, function(x) isTRUE(x$logscale)), id)), ############################ # Per-Parameter class properties (S3 method call) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 2d604245..48ee2a38 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -148,6 +148,113 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, entry = if (n == "") length(private$.sets) + 1 else n private$.sets[[n]] = p invisible(self) + }, + + #' @description + #' + #' Set the parameter values so that internal tuning for the selected parameters is disabled. + #' + #' @param ids (`character()`)\cr + #' The ids of the parameters for which to disable internal tuning. 
+ #' @return `Self` + disable_internal_tuning = function(ids) { + assert_subset(ids, self$ids(tags = "internal_tuning")) + + full_prefix = function(param_set, id_, prefix = "") { + info = get_private(param_set)$.translation[id_, c("owner_name", "original_id", "owner_ps_index"), on = "id"] + subset = get_private(param_set)$.sets[[info$owner_ps_index]] + prefix = if (info$owner_name == "") { + prefix + } else if (prefix == "") { + info$owner_name + } else { + paste0(prefix, ".", info$owner_name) + } + + if (!test_class(subset, "ParamSetCollection")) return(prefix) + + full_prefix(subset, info$original_id, prefix) + } + + pvs = Reduce(c, map(ids, function(id_) { + xs = private$.params[list(id_), "cargo", on = "id"][[1]][[1]]$disable_in_tune + prefix = full_prefix(self, id_) + if (prefix == "") return(xs) + set_names(xs, paste0(full_prefix(self, id_), ".", names(xs))) + })) %??% named_list() + self$set_values(.values = pvs) + }, + + #' @description + #' Convert all parameters from the search space to parameter values using the transformation given by + #' `in_tune_fn`. + #' @param search_space ([`ParamSet`])\cr + #' The internal search space. + #' @return (named `list()`) + convert_internal_search_space = function(search_space) { + assert_class(search_space, "ParamSet") + imap(search_space$domains, function(token, .id) { + converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn + if (!is.function(converter)) { + stopf("No converter exists for parameter '%s'", .id) + } + set_index = private$.translation[list(.id), "owner_ps_index", on = "id"][[1L]] + converter(token, private$.sets[[set_index]]$values) + }) + }, + + #' @description + #' Create a `ParamSet` from this `ParamSetCollection`. 
+ flatten = function() { + flatps = super$flatten() + + recurse_prefix = function(id_, param_set, prefix = "") { + info = get_private(param_set)$.translation[list(id_), c("owner_name", "owner_ps_index"), on = "id"] + prefix = if (info$owner_name == "") { + prefix + } else if (prefix == "") { + info$owner_name + } else { + paste0(prefix, ".", info$owner_name) + } + subset = get_private(param_set)$.sets[[info$owner_ps_index]] + if (!test_class(subset, "ParamSetCollection")) { + return(list(prefix = prefix, ids = subset$ids())) + } + if (prefix != "") { + id_ = gsub(sprintf("^\\Q%s.\\E", prefix), "", id_) + } + recurse_prefix(id_, get_private(param_set)$.sets[[info$owner_ps_index]], prefix) + } + + flatps$.__enclos_env__$private$.params[, let( + cargo = pmap(list(cargo = cargo, id_ = id), function(cargo, id_) { + if (all(map_lgl(cargo[c("disable_in_tune", "in_tune_fn")], is.null))) return(cargo) + + info = recurse_prefix(id_, self) + prefix = info$prefix + if (prefix == "") return(cargo) + + in_tune_fn = cargo$in_tune_fn + + set_ids = info$ids + cargo$in_tune_fn = crate(function(domain, param_vals) { + param_vals = param_vals[names(param_vals) %in% paste0(prefix, ".", set_ids)] + names(param_vals) = gsub(sprintf("^\\Q%s.\\E", prefix), "", names(param_vals)) + in_tune_fn(domain, param_vals) + }, in_tune_fn, prefix, set_ids) + + if (length(cargo$disable_in_tune)) { + cargo$disable_in_tune = set_names( + cargo$disable_in_tune, + paste0(prefix, ".", names(cargo$disable_in_tune)) + ) + } + cargo + }) + )] + + flatps } ), diff --git a/R/ParamUty.R b/R/ParamUty.R index 61cbd143..5acd8fad 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -1,7 +1,7 @@ #' @rdname Domain #' @export -p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init) { +p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = 
substitute(default), init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(custom_check, null.ok = TRUE) if (!is.null(custom_check)) { custom_check_result = custom_check(1) @@ -12,7 +12,12 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } else { "NoDefault" } - Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) + cargo = list(custom_check = custom_check, repr = repr) + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune + + Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } #' @export diff --git a/R/to_tune.R b/R/to_tune.R index 9c41e6ea..03ef7927 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -41,6 +41,13 @@ #' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should #' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`]. #' @param ... if given, restricts the range to be tuning over, as described above. +#' @param internal (`logical(1)`)\cr +#' Whether to create an `InternalTuneToken`. +#' This is only available for parameters tagged with `"internal_tuning"`. +#' @param aggr (`function`)\cr +#' Function with one argument, which is a list of parameter values and returns a single aggregated value (e.g. the mean). +#' This specifies how multiple parameter values are aggregated to form a single value in the context of internal tuning. +#' If none specified, the default aggregation function of the parameter will be used. #' @return A `TuneToken` object. 
#' @examples #' params = ps( @@ -132,7 +139,12 @@ #' @family ParamSet construction helpers #' @aliases TuneToken #' @export -to_tune = function(...) { +to_tune = function(..., internal = !is.null(aggr), aggr = NULL) { + assert_flag(internal) + if (!is.null(aggr)) { + assert_true(internal) + } + assert_function(aggr, nargs = 1L, null.ok = TRUE) call = sys.call() if (...length() > 3) { stop("to_tune() must have zero arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`).") @@ -180,6 +192,17 @@ to_tune = function(...) { content = list(logscale = FALSE) } + if (internal) { + if (type == "ObjectTuneToken") { + stop("Internal tuning can currently not be combined with ParamSet or Domain object, specify lower and upper bounds, e.g. to_tune(upper = 100)") + } + if (isTRUE(content$logscale)) { + stop("Cannot combine logscale transformation with internal tuning.") + } + type = c("InternalTuneToken", type) + content$aggr = aggr + } + set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } @@ -189,6 +212,13 @@ print.FullTuneToken = function(x, ...) { if (isTRUE(x$content$logscale)) " (log scale)" else "") } +#' @export +print.InternalTuneToken = function(x, ...) { + cat("Internal ") + NextMethod() +} + + #' @export print.RangeTuneToken = function(x, ...) { catf("Tuning over:\nrange [%s, %s]%s\n", x$content$lower %??% "...", x$content$upper %??% "...", @@ -208,11 +238,11 @@ print.ObjectTuneToken = function(x, ...) { # # Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet) # param is a data.table that is potentially modified by reference using data.table set() methods. -tunetoken_to_ps = function(tt, param) { +tunetoken_to_ps = function(tt, param, ...) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.FullTuneToken = function(tt, param) { +tunetoken_to_ps.FullTuneToken = function(tt, param, ...) 
{ if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) } @@ -220,16 +250,32 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale), tt$call), param) } else { + if (!is.null(tt$content$aggr)) { + # https://github.com/Rdatatable/data.table/issues/6104 + param$cargo[[1L]] = list(insert_named(param$cargo[[1L]], list(aggr = tt$content$aggr))) + } pslike_to_ps(param, tt$call, param) } } -tunetoken_to_ps.RangeTuneToken = function(tt, param) { +tunetoken_to_ps.InternalTuneToken = function(tt, param, ...) { + # Calling NextMethod with additional arguments behaves weirdly, as the InternalTuneToken only works with ranges right now + # we just call it directly + aggr = if (!is.null(tt$content$aggr)) tt$content$aggr else param$cargo[[1L]]$aggr + if (is.null(aggr)) { + stopf("%s must specify a aggregation function for parameter %s", tt$call, param$id) + } + tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, tags = "internal_tuning", + aggr = aggr) +} + +tunetoken_to_ps.RangeTuneToken = function(tt, param, args = list(), ...) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) } invalidpoints = discard(tt$content, function(x) is.null(x) || domain_test(param, set_names(list(x), param$id))) invalidpoints$logscale = NULL + invalidpoints$aggr = NULL if (length(invalidpoints)) { stopf("%s range not compatible with param %s.\nBad value(s):\n%s\nParameter:\n%s", tt$call, param$id, repr(invalidpoints), repr(param)) @@ -245,7 +291,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) { # create p_int / p_dbl object. 
Doesn't work if there is a numeric param class that we don't know about :-/ constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl, stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class)) - content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale) + content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, ...) pslike_to_ps(content, tt$call, param) } diff --git a/man/Domain.Rd b/man/Domain.Rd index d0c25a03..48c4e42f 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -20,7 +20,10 @@ p_dbl( depends = NULL, trafo = NULL, logscale = FALSE, - init + init, + aggr = NULL, + in_tune_fn = NULL, + disable_in_tune = NULL ) p_fct( @@ -30,7 +33,10 @@ p_fct( tags = character(), depends = NULL, trafo = NULL, - init + init, + aggr = NULL, + in_tune_fn = NULL, + disable_in_tune = NULL ) p_int( @@ -43,7 +49,10 @@ p_int( depends = NULL, trafo = NULL, logscale = FALSE, - init + init, + aggr = NULL, + in_tune_fn = NULL, + disable_in_tune = NULL ) p_lgl( @@ -52,7 +61,10 @@ p_lgl( tags = character(), depends = NULL, trafo = NULL, - init + init, + aggr = NULL, + in_tune_fn = NULL, + disable_in_tune = NULL ) p_uty( @@ -63,7 +75,10 @@ p_uty( depends = NULL, trafo = NULL, repr = substitute(default), - init + init, + aggr = NULL, + in_tune_fn = NULL, + disable_in_tune = NULL ) } \arguments{ @@ -139,6 +154,20 @@ defining domains or hyperparameter ranges of learning algorithms, because these Initial value. When this is given, then the corresponding entry in \code{ParamSet$values} is initialized with this value upon construction.} +\item{aggr}{(\code{function})\cr +Default aggregation function for a parameter. Can only be given for parameters tagged with \code{"internal_tuning"}. 
+Function with one argument, which is a list of parameter values and that returns the aggregated parameter value.} + +\item{in_tune_fn}{(\verb{function(domain, param_vals)})\cr +Function that converts a \code{Domain} object into a parameter value. +Can only be given for parameters tagged with \code{"internal_tuning"}. +This function should also assert that the parameters required to enable internal tuning for the given \code{domain} are +set in \code{param_vals} (such as \code{early_stopping_rounds} for \code{XGBoost}).} + +\item{disable_in_tune}{(named \code{list()})\cr +The parameter values that need to be set in the \code{ParamSet} to disable the internal tuning for the parameter. +For \code{XGBoost} this would e.g. be \code{list(early_stopping_rounds = NULL)}.} + \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that converts the names (if not given: \code{as.character()} of the values) of the \code{levels} argument to the values. @@ -222,6 +251,30 @@ print(grid) # ... but get transformed to integers. 
print(grid$transpose()) + +# internal tuning + +param_set = ps( + iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), + in_tune_fn = function(domain, param_vals) { + stopifnot(domain$lower <= 1) + stopifnot(param_vals$early_stopping == TRUE) + domain$upper + }, + disable_in_tune = list(early_stopping = FALSE)), + early_stopping = p_lgl() +) +param_set$set_values( + iters = to_tune(upper = 100, internal = TRUE), + early_stopping = TRUE +) +param_set$convert_internal_search_space(param_set$search_space()) +param_set$aggr_internal_tuned_values( + list(iters = list(1, 2, 3)) +) + +param_set$disable_internal_tuning("iters") +param_set$values$early_stopping } \seealso{ Other ParamSet construction helpers: diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index b74daebc..a4ad34e7 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -172,6 +172,9 @@ Named with param IDs.} \item \href{#method-ParamSet-get_values}{\code{ParamSet$get_values()}} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} +\item \href{#method-ParamSet-aggr_internal_tuned_values}{\code{ParamSet$aggr_internal_tuned_values()}} +\item \href{#method-ParamSet-disable_internal_tuning}{\code{ParamSet$disable_internal_tuning()}} +\item \href{#method-ParamSet-convert_internal_search_space}{\code{ParamSet$convert_internal_search_space()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -340,6 +343,71 @@ In almost all cases, the default \code{param_set = self} should be used.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-aggr_internal_tuned_values}{}}} +\subsection{Method \code{aggr_internal_tuned_values()}}{ +Aggregate parameter values according to their aggregation rules. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ParamSet$aggr_internal_tuned_values(x)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x}}{(named \code{list()} of \code{list()}s)\cr +The value(s) to be aggregated. Names are parameter ids. +The aggregation function is selected based on the parameter.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-disable_internal_tuning}{}}} +\subsection{Method \code{disable_internal_tuning()}}{ +Set the parameter values so that internal tuning for the selected parameters is disabled. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ParamSet$disable_internal_tuning(ids)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ids}}{(\code{character()})\cr +The ids of the parameters for which to disable internal tuning.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +\code{Self} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-convert_internal_search_space}{}}} +\subsection{Method \code{convert_internal_search_space()}}{ +Convert all parameters from the search space to parameter values using the transformation given by +\code{in_tune_fn}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ParamSet$convert_internal_search_space(search_space)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{search_space}}{(\code{\link{ParamSet}})\cr +The internal search space.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-test_constraint}{}}} \subsection{Method \code{test_constraint()}}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 93e4c93e..01c10e7f 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -72,6 +72,9 @@ This field provides direct references to the \code{\link{ParamSet}} objects.} \itemize{ \item \href{#method-ParamSetCollection-new}{\code{ParamSetCollection$new()}} \item \href{#method-ParamSetCollection-add}{\code{ParamSetCollection$add()}} +\item \href{#method-ParamSetCollection-disable_internal_tuning}{\code{ParamSetCollection$disable_internal_tuning()}} +\item \href{#method-ParamSetCollection-convert_internal_search_space}{\code{ParamSetCollection$convert_internal_search_space()}} +\item \href{#method-ParamSetCollection-flatten}{\code{ParamSetCollection$flatten()}} \item \href{#method-ParamSetCollection-clone}{\code{ParamSetCollection$clone()}} } } @@ -79,12 +82,12 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
Inherited methods