Skip to content

Commit

Permalink
aggr is now part of cargo
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 22, 2024
1 parent 5fd31c2 commit b6fee52
Show file tree
Hide file tree
Showing 17 changed files with 85 additions and 106 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export(generate_design_grid)
export(generate_design_lhs)
export(generate_design_random)
export(generate_design_sobol)
export(in_tune)
export(p_dbl)
export(p_fct)
export(p_int)
Expand Down
10 changes: 3 additions & 7 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ Domain = function(cls, grouping,
trafo = NULL,
depends_expr = NULL,
storage_type = "list",
init,
aggr = NULL) {
init) {

assert_string(cls)
assert_string(grouping)
Expand All @@ -151,7 +150,6 @@ Domain = function(cls, grouping,
if (length(special_vals) && !is.null(trafo)) stop("trafo and special_values can not both be given at the same time.")
assert_character(tags, any.missing = FALSE, unique = TRUE)
assert_function(trafo, null.ok = TRUE)
assert_function(aggr, null.ok = TRUE, nargs = 1L)

# depends may be an expression, but may also be quote() or expression()
if (length(depends_expr) == 1) {
Expand All @@ -174,8 +172,7 @@ Domain = function(cls, grouping,
.trafo = list(trafo),
.requirements = list(parse_depends(depends_expr, parent.frame(2))),
.init_given = !missing(init),
.init = list(if (!missing(init)) init),
.aggr = list(aggr)
.init = list(if (!missing(init)) init)
)

class(param) = c(cls, "Domain", class(param))
Expand Down Expand Up @@ -220,8 +217,7 @@ empty_domain = data.table(id = character(0), cls = character(0), grouping = char
.trafo = list(),
.requirements = list(),
.init_given = logical(0),
.init = list(),
.aggr = list()
.init = list()
)

domain_names = names(empty_domain)
Expand Down
7 changes: 6 additions & 1 deletion R/ParamDbl.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @rdname Domain
#' @export
p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) {
assert_function(aggr, null.ok = TRUE, nargs = 1L)
assert_number(tolerance, lower = 0)
assert_number(lower)
assert_number(upper)
Expand All @@ -17,8 +18,12 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
real_upper = upper
}

cargo = list()
if (logscale) cargo$logscale = TRUE
cargo$aggr = aggr

Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric",
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr)
depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
Expand Down
3 changes: 2 additions & 1 deletion R/ParamFct.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @rdname Domain
#' @export
p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) {
assert_function(aggr, null.ok = TRUE, nargs = 1L)
constargs = as.list(match.call()[-1])
levels = eval.parent(constargs$levels)
if (!is.character(levels)) {
Expand All @@ -22,7 +23,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact
# group p_fct by levels, so the group can be checked in a vectorized fashion.
# We escape '"' and '\' to '\"' and '\\', respectively.
grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",")
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, aggr = aggr)
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr))
}

#' @export
Expand Down
7 changes: 6 additions & 1 deletion R/ParamInt.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @rdname Domain
#' @export
p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) {
assert_function(aggr, null.ok = TRUE, nargs = 1L)
assert_number(tolerance, lower = 0, upper = 0.5)
# assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound
if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower)
Expand All @@ -23,9 +24,13 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
real_upper = upper
}

cargo = list()
if (logscale) cargo$logscale = TRUE
cargo$aggr = aggr

Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo,
storage_type = storage_type,
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr)
depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
Expand Down
3 changes: 2 additions & 1 deletion R/ParamLgl.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#' @rdname Domain
#' @export
p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) {
assert_function(aggr, null.ok = TRUE, nargs = 1L)
Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default,
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, aggr = aggr)
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr))
}

#' @export
Expand Down
26 changes: 9 additions & 17 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#' - special_vals: list col of list
#' - default: list col
#' - storage_type: character
#' - tags: list col of character vectors
#' - tags: list col of character vectors
#' @examples
#' pset = ParamSet$new(
#' params = list(
Expand Down Expand Up @@ -90,10 +90,6 @@ ParamSet = R6Class("ParamSet",
private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id")
}

if (".aggr" %in% names(paramtbl)) {
private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id")
}

if (".requirements" %in% names(paramtbl)) {
requirements = paramtbl$.requirements
private$.params = paramtbl # self$add_dep needs this
Expand Down Expand Up @@ -265,20 +261,21 @@ ParamSet = R6Class("ParamSet",
#'
#' @param x (named `list()` of `list()`s)\cr
#' The value(s) to be aggregated. Names are parameter ids.
#' The aggregation function is selected accordingly for each parameter.
#' The aggregation function is selected based on the parameter.
#' @return (named `list()`)
aggr = function(x) {
assert_list(x, types = "list")
assert_permutation(names(x), private$.aggrs$id)
aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))]
assert_permutation(names(x), aggrs$id)
if (!(length(unique(lengths(x))) == 1L)) {
stopf("The same number of values are required for each parameter")
}
if (nrow(private$.aggrs) && !length(x[[1L]])) {
stopf("More than one value is required to aggregate them")
if (nrow(aggrs) && !length(x[[1L]])) {
stopf("At least one value is required to aggregate them")
}

imap(x, function(value, .id) {
aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value)
aggr = aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value)
})
},

Expand Down Expand Up @@ -529,7 +526,6 @@ ParamSet = R6Class("ParamSet",
.trafo = private$.trafos[id, trafo],
.requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps
.init_given = id %in% names(vals),
.aggr = private$.aggrs[id, get("aggr")],
.init = unname(vals[id]))
]

Expand Down Expand Up @@ -564,7 +560,6 @@ ParamSet = R6Class("ParamSet",

result$.__enclos_env__$private$.params = setindexv(private$.params[ids, on = "id"], c("id", "cls", "grouping"))
result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[ids, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[ids, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.tags = setkeyv(private$.tags[ids, on = "id", nomatch = NULL], "id")
result$assert_values = FALSE
result$deps = deps[ids, on = "id", nomatch = NULL]
Expand Down Expand Up @@ -592,7 +587,6 @@ ParamSet = R6Class("ParamSet",
result$.__enclos_env__$private$.params = setindexv(private$.params[get_id, on = "id"], c("id", "cls", "grouping"))
# setkeyv not strictly necessary since get_id is scalar, but we do it for consistency
result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[get_id, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[get_id, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.tags = setkeyv(private$.tags[get_id, on = "id", nomatch = NULL], "id")
result$assert_values = FALSE
result$values = values[match(get_id, names(values), nomatch = 0)]
Expand Down Expand Up @@ -744,7 +738,6 @@ ParamSet = R6Class("ParamSet",
result = copy(private$.params)
result[, .tags := list(self$tags)]
result[private$.trafos, .trafo := list(trafo), on = "id"]
result[private$.aggrs, .aggr := list(aggr), on = "id"]
result[self$deps, .requirements := transpose_list(.(on, cond)), on = "id"]
vals = self$values
result[, `:=`(
Expand Down Expand Up @@ -852,7 +845,7 @@ ParamSet = R6Class("ParamSet",
#' Note that this only refers to the `logscale` flag set during construction, e.g. `p_dbl(logscale = TRUE)`.
#' If the parameter was set to logscale manually, e.g. through `p_dbl(trafo = exp)`,
#' this `is_logscale` will be `FALSE`.
is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & cargo == "logscale", id)),
is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & map_lgl(cargo, function(x) isTRUE(x$logscale)), id)),

############################
# Per-Parameter class properties (S3 method call)
Expand Down Expand Up @@ -903,7 +896,6 @@ ParamSet = R6Class("ParamSet",
.tags = data.table(id = character(0L), tag = character(0), key = "id"),
.deps = data.table(id = character(0L), on = character(0L), cond = list()),
.trafos = data.table(id = character(0L), trafo = list(), key = "id"),
.aggrs = data.table(id = character(0L), aggr = list(), key = "id"),

get_tune_ps = function(values) {
values = keep(values, inherits, "TuneToken")
Expand All @@ -915,7 +907,7 @@ ParamSet = R6Class("ParamSet",
names(params) = names(values)

# package-internal S3 fails if we don't call the function indirectly here
partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = param_set))
partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = self))

pars = ps_union(partsets) # partsets does not have names here, which is what we want.

Expand Down
5 changes: 0 additions & 5 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
private$.tags = paramtbl[, .(tag = unique(unlist(.tags))), keyby = "id"]

private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id")
private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id")

private$.translation = paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE]
setkeyv(private$.translation, "id")
Expand Down Expand Up @@ -126,10 +125,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
if (nrow(newtrafos)) {
private$.trafos = setkeyv(rbind(private$.trafos, newtrafos), "id")
}
newaggrs = paramtbl[!map_lgl(.aggr, is.null), .(id, trafo = .aggr)]
if (nrow(newaggrs)) {
private$.aggrs = setkeyv(rbind(private$.aggrs, newaggrs), "id")
}

private$.translation = rbind(private$.translation, paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE])
setkeyv(private$.translation, "id")
Expand Down
6 changes: 5 additions & 1 deletion R/ParamUty.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' @export
p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL) {
assert_function(custom_check, null.ok = TRUE)
assert_function(aggr, null.ok = TRUE, nargs = 1L)
if (!is.null(custom_check)) {
custom_check_result = custom_check(1)
assert(check_true(custom_check_result), check_string(custom_check_result), .var.name = "The result of 'custom_check()'")
Expand All @@ -12,7 +13,10 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t
} else {
"NoDefault"
}
Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init, aggr = aggr)
cargo = list(custom_check = custom_check, repr = repr)
cargo$aggr = aggr

Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init)
}

#' @export
Expand Down
8 changes: 0 additions & 8 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,3 @@ col_to_nl = function(dt, col = 1, idcol = 2) {
names(data) = dt[[idcol]]
data
}

# Default aggregator: combines a list of parameter values into a single value.
#
# x: list of values to aggregate. Only the first element is inspected to
#    decide whether the input is supported (it must be a length-1 numeric).
# Returns the mean of all unlisted values, rounded up to the next integer.
# Signals an error (via stopf) when the first element is not a numeric scalar.
default_aggr = function(x) {
  first_value = x[[1]]
  if (test_numeric(first_value, len = 1L)) {
    # average over all supplied values, then round up
    return(ceiling(mean(unlist(x))))
  }
  stopf("Provide a custom aggregator for non-numeric and non-scalar parameters.")
}

45 changes: 24 additions & 21 deletions R/to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
#' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should
#' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`].
#' @param ... if given, restricts the range to be tuning over, as described above.
#' @param aggr (`function`)\cr
#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value.
#' If `NULL`, the default aggregation function of the parameter will be used.
#' @param inner (`logical(1)`)\cr
#' Whether to create an inner tuning token, i.e. the value will be optimized using the `Learner`-internal tuning
#' mechanism, such as early stopping for XGBoost.
#' @return A `TuneToken` object.
#' @examples
#' params = ps(
Expand All @@ -54,7 +60,8 @@
#' uty2 = p_uty(),
#' uty3 = p_uty(),
#' uty4 = p_uty(),
#' uty5 = p_uty()
#' uty5 = p_uty(),
#' p_inner = p_int(tags = "inner_tuning", aggr = function(x) round(mean(unlist(x))))
#' )
#'
#' params$values = list(
Expand Down Expand Up @@ -101,7 +108,10 @@
#' )),
#'
#' # not all values need to be tuned!
#' uty5 = 100
#' uty5 = 100,
#'
#' # Fix value to 100, but use learner-internal tuning
#' p_inner = to_tune(p_fct(100), inner = TRUE)
#' )
#'
#' print(params$values)
Expand Down Expand Up @@ -132,7 +142,12 @@
#' @family ParamSet construction helpers
#' @aliases TuneToken
#' @export
to_tune = function(...) {
to_tune = function(..., inner = !is.null(aggr), aggr = NULL) {
test_function(aggr, nargs = 1L, null.ok = TRUE)
assert_flag(inner)
if (!is.null(aggr)) {
assert_true(inner)
}
call = sys.call()
if (...length() > 3) {
stop("to_tune() must have zero arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`).")
Expand Down Expand Up @@ -180,24 +195,12 @@ to_tune = function(...) {
content = list(logscale = FALSE)
}

set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken"))
}
if (inner) {
type = c("InnerTuneToken", type)
}
if (!is.null(aggr)) content$aggr = aggr

#' @title Create an Inner Tuning Token
#' @description
#' Works just like [`to_tune()`], but marks the parameter for inner tuning.
#' See [`mlr3::Learner`] for more information.
#' @inheritParams to_tune
#' @param aggr (`function`)\cr
#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value.
#' The default is to average the values and round them up.
#' @export
in_tune = function(..., aggr = NULL) {
test_function(aggr, nargs = 1L, null.ok = TRUE)
tt = to_tune(...)
if (!is.null(aggr)) tt$content$aggr = aggr
tt = set_class(tt, classes = c("InnerTuneToken", class(tt)))
return(tt)
set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken"))
}

#' @export
Expand Down Expand Up @@ -236,7 +239,7 @@ tunetoken_to_ps = function(tt, param, param_set) {
}

tunetoken_to_ps.InnerTuneToken = function(tt, param, param_set) {
tt$content$aggr = tt$content$aggr %??% get_private(param_set)$.aggrs[list(param$id), "aggr", on = "id"][[1L]][[1L]]
tt$content$aggr = tt$content$aggr %??% param_set$params[list(param$id), "cargo", on = "id"][[1L]][[1L]]$aggr
if ("inner_tuning" %nin% param_set$tags[[param$id]]) {
stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id)
}
Expand Down
4 changes: 2 additions & 2 deletions man/ParamSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b6fee52

Please sign in to comment.