Commit

more revisions from PR
cxzhang4 committed Jan 10, 2025
1 parent d548ef7 commit a6b8467
Showing 2 changed files with 53 additions and 62 deletions.
109 changes: 49 additions & 60 deletions R/CallbackSetLRScheduler.R
@@ -4,9 +4,12 @@
#'
#' @description
#' Changes the learning rate based on the schedule specified by a `torch::lr_scheduler`.
#'
+#' As of this writing, the following are available: [torch::lr_cosine_annealing()], [torch::lr_lambda()], [torch::lr_multiplicative()], [torch::lr_one_cycle()],
+#' [torch::lr_reduce_on_plateau()], [torch::lr_step()], and custom schedulers defined with [torch::lr_scheduler()].
+#'
#' @param .scheduler (`function`)\cr
-#' The torch scheduler constructor function (e.g. `torch::lr_step`).
+#' The `torch` scheduler generator (e.g. `torch::lr_step`).
#' @param ... (`list`)\cr
#' The scheduler-specific arguments
#'
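A hedged usage sketch of the callbacks this file registers; the learner id "classif.mlp" and all parameter values are illustrative assumptions, not taken from this commit:

library(mlr3torch)

# construct the callback the same way the tests below do
cb = t_clbk("lr_cosine_annealing", T_max = 10)

# attach it to a torch learner and train; learner id and arguments are assumed
learner = lrn("classif.mlp", epochs = 10, batch_size = 32, callbacks = cb)
learner$train(tsk("iris"))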
@@ -22,39 +25,27 @@ CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler",
scheduler = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
-initialize = function(.scheduler, ...) {
+initialize = function(.scheduler, step_on_epoch, ...) {
assert_class(.scheduler, "lr_scheduler_generator")
self$scheduler_fn = .scheduler
private$.scheduler_args = list(...)
+if (step_on_epoch) {
+self$on_epoch_end = function() self$scheduler$step()
+} else {
+self$on_batch_end = function() self$scheduler$step()
+}
},
#' @description
#' Creates the scheduler using the optimizer from the context
on_begin = function() {
# TODO: check that the .scheduler_args do not have the cb prefix (pretty sure this is true)
self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
},
-#' @description
-#' Depending on the scheduler, step after each epoch
-on_epoch_end = function() {
-# TODO: ensure that this happens after optimizer$step()
-# https://blogs.rstudio.com/ai/posts/2020-10-19-torch-image-classification/#training
-# but for now let's hope that it does
-self$scheduler$step()
-}
-# TODO: add batches (really only for lr_scheduler_one_cycle)
-# this does not need to be exposed to the user, we can pass an additional arg
-# on_batch_end = function() {
-# if (!self$step_on_epoch) {
-# self$scheduler$step()
-# }
-# }
),
private = list(
.scheduler_args = NULL
)
)
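The step_on_epoch flag in initialize() above binds scheduler$step() to either on_epoch_end or on_batch_end. A minimal torch-only sketch of the stepping contract the callback relies on, assuming a throwaway parameter tensor; the training loop body is a placeholder:

library(torch)

opt = optim_sgd(list(torch_randn(2, 2, requires_grad = TRUE)), lr = 0.1)
sched = lr_step(opt, step_size = 1, gamma = 0.1)

for (epoch in 1:3) {
  # ... forward pass, loss$backward(), and opt$step() over the batches ...
  sched$step()  # what on_epoch_end does when step_on_epoch = TRUE
  cat("epoch", epoch, "lr", opt$param_groups[[1]]$lr, "\n")
}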

-# TODO: determine whether we should set ranges and such even when torch does not

# some of the schedulers accept lists
# so they can treat different parameter groups differently
check_class_or_list = function(x, classname) {
@@ -70,82 +61,82 @@ check_class_or_list = function(x, classname) {
}

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_cosine_annealing", function() {
mlr3torch_callbacks$add("lr_cosine_annealing", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
T_max = p_int(tags = c("train", "required")),
-eta_min = p_dbl(default = 0, lower = 0, tags = "train"),
+eta_min = p_dbl(default = 0, tags = "train"),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_cosine_annealing",
label = "Learning Rate Scheduler using Cosine Annealing",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_cosine_annealing)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_lambda", function() {
mlr3torch_callbacks$add("lr_lambda", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train"), custom_check = function(x) check_class_or_list(x, "function")), # TODO: assert fn or list of fns
-last_epoch = p_int(default = -1, lower = -1, tags = "train"),
+last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
label = "Learning Rate Scheduler using Multplication by a Function",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_lambda)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_multiplicative", function() {
mlr3torch_callbacks$add("lr_multiplicative", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train"), custom_check = function(x) check_class_or_list(x, "function")),
-last_epoch = p_int(default = -1, lower = -1, tags = "train"),
+last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_multiplicative",
label = "Learning Rate Scheduler using Multiplication by a Factor",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_multiplicative)
)
})

-# TODO: refactor to operate on batches
#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_one_cycle", function() {
mlr3torch_callbacks$add("lr_one_cycle", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
max_lr = p_dbl(tags = "train"),
-total_steps = p_int(default = NULL, tags = "train"), # set special vals to NULL
-epochs = p_int(default = NULL, tags = "train"),
-steps_per_epoch = NULL,
-pct_start = p_dbl(default = 0.3, lower = 0, upper = 1, tags = "train"),
+total_steps = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+epochs = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+steps_per_epoch = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+pct_start = p_dbl(default = 0.3, tags = "train"),
anneal_strategy = p_fct(default = "cos", levels = c("cos", "linear")), # this is a string in the torch fn
cycle_momentum = p_lgl(default = TRUE, tags = "train"),
-base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")), # float or list
-max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")), # or list
+base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
+max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
div_factor = p_dbl(default = 25, tags = "train"),
final_div_factor = p_dbl(default = 1e4, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_one_cycle",
label = "Learning Rate Scheduler using 1cycle",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
-additional_args = list(.scheduler = torch::lr_one_cycle)
+additional_args = list(.scheduler = torch::lr_one_cycle, step_on_epoch = FALSE)
)
})
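Since the 1cycle policy anneals the rate within a single pass over the data, it is the one scheduler registered with step_on_epoch = FALSE, so the callback steps it in on_batch_end rather than on_epoch_end. A hedged construction sketch; the values are illustrative:

# total_steps should cover epochs * batches-per-epoch when stepping per batch
cb = t_clbk("lr_one_cycle", max_lr = 0.1, total_steps = 100)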

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_reduce_on_plateau", function() {
mlr3torch_callbacks$add("lr_reduce_on_plateau", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
@@ -159,45 +150,43 @@ mlr3torch_callbacks$add("lr_scheduler_reduce_on_plateau", function() {
eps = p_dbl(default = 1e-08, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_reduce_on_plateau",
label = "Learning Rate Scheduler using Reduction on Plateau",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_reduce_on_plateau)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_step", function() {
mlr3torch_callbacks$add("lr_step", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
-step_size = p_int(default = 1, lower = 1, tags = "train"),
-gamma = p_dbl(default = 0.1, lower = 0, upper = 1, tags = "train"),
+step_size = p_int(default = 1, tags = "train"),
+gamma = p_dbl(default = 0.1, tags = "train"),
last_epoch = p_int(default = -1, tags = "train")
),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_step",
label = "Learning Rate Scheduler using Step Decay",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_step)
)
})

-as_lr_scheduler = function(x) {
+#' @param x (`function`)\cr
+#' The `torch` scheduler generator defined using `torch::lr_scheduler()`.
+#' @param step_on_epoch (`logical(1)`)\cr
+#' Whether the scheduler steps after every epoch.
+as_lr_scheduler = function(x, step_on_epoch) {
assert_class(x, "lr_scheduler_generator")
+assert_flag(step_on_epoch)

TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = inferps(x),
id = "lr_scheduler",
label = "Learning Rate Scheduler",
id = "lr_scheduler_custom",
label = "Learning Rate Scheduler using Custom Policy",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
-additional_args = list(.scheduler = x)
+additional_args = list(.scheduler = x, step_on_epoch = step_on_epoch)
)
+}

-# for a custom scheduler, instead of the following
-# t_clbk("lr_step", ...)
-
-# the user would write something like this
-# custom_scheduler = function()
-# as_lr_scheduler(custom_scheduler)
-}
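Echoing the deleted comment block, a hedged end-to-end sketch of the custom path: define a scheduler generator with torch::lr_scheduler() and wrap it with as_lr_scheduler(). The get_lr() body is a hypothetical completion modeled on the truncated lr_subtract test below:

library(torch)

# modeled after lr_step, but subtracting a constant delta instead of multiplying
lr_subtract = lr_scheduler(
  "lr_subtract",
  initialize = function(optimizer, step_size, delta = 0.1, last_epoch = -1) {
    self$step_size = step_size
    self$delta = delta
    super$initialize(optimizer, last_epoch)
  },
  get_lr = function() {
    # current rates of all parameter groups
    lrs = sapply(self$optimizer$param_groups, function(g) g$lr)
    if (self$last_epoch != 0 && self$last_epoch %% self$step_size == 0) {
      lrs = lrs - self$delta
    }
    lrs
  }
)

cb = as_lr_scheduler(lr_subtract, step_on_epoch = TRUE)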
6 changes: 4 additions & 2 deletions tests/testthat/test_CallbackSetLRScheduler.R
@@ -1,10 +1,11 @@
test_that("autotest", {
cb = t_clbk("lr_scheduler_cosine_annealing", T_max = 10)
cb = t_clbk("lr_cosine_annealing", T_max = 10)
+# each LR scheduler has a different paramset, so we don't test them
expect_torch_callback(cb, check_paramset = FALSE)
})

test_that("decay works", {
cb = t_clbk("lr_scheduler_step")
cb = t_clbk("ler_step")
expect_torch_callback(cb, check_paramset = FALSE)
task = tsk("iris")
n_epochs = 10
@@ -26,6 +27,7 @@ test_that("decay works", {
})

test_that("custom LR scheduler works", {
+# modeled after lr_step
lr_subtract <- lr_scheduler(
"lr_subtract",
initialize = function(optimizer, step_size, delta = 0.1, last_epoch = -1) {
