From 9732d021972d328d558cfe04be974de36e7fcad5 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sat, 27 Jul 2024 17:07:44 +0200 Subject: [PATCH] properly manage dials parameters in parsnip, rename 2 `tabnet()` params to `rate_decay` and `rate_step_size` add tests for that --- DESCRIPTION | 1 + NAMESPACE | 13 +++++ R/dials.R | 97 +++++++++++++++++++++++++++++-------- R/parsnip.R | 56 ++++++++++----------- man/tabnet.Rd | 11 +---- man/tabnet_params.Rd | 22 ++++----- tests/testthat/test-dials.R | 48 ++++++++++++++++++ 7 files changed, 179 insertions(+), 69 deletions(-) create mode 100644 tests/testthat/test-dials.R diff --git a/DESCRIPTION b/DESCRIPTION index fcf265ca..c2bb6bf4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,6 +21,7 @@ BugReports: https://github.com/mlverse/tabnet/issues Depends: R (>= 3.6) Imports: + cli, coro, data.tree, dials, diff --git a/NAMESPACE b/NAMESPACE index 495de1c2..f99eb7c5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,23 +21,36 @@ S3method(tabnet_pretrain,recipe) S3method(update,tabnet) export("%>%") export(attention_width) +export(cat_emb_dim) export(check_compliant_node) +export(checkpoint_epochs) export(decision_width) +export(drop_last) +export(encoder_activation) export(feature_reusage) +export(lr_scheduler) export(mask_type) +export(mlp_activation) +export(mlp_hidden_multiplier) export(momentum) export(nn_prune_head.tabnet_fit) export(nn_prune_head.tabnet_pretrain) export(node_to_df) export(num_independent) +export(num_independent_decoder) export(num_shared) +export(num_shared_decoder) export(num_steps) +export(optimizer) +export(penalty) export(tabnet) export(tabnet_config) export(tabnet_explain) export(tabnet_fit) export(tabnet_nn) export(tabnet_pretrain) +export(verbose) +export(virtual_batch_size) importFrom(dplyr,filter) importFrom(dplyr,last_col) importFrom(dplyr,mutate) diff --git a/R/dials.R b/R/dials.R index 0c49d2d1..cc0d8e6e 100644 --- a/R/dials.R +++ b/R/dials.R @@ -17,56 +17,70 @@ check_dials <- function() { #' @rdname tabnet_params #' @return A `dials` parameter to be used when tuning TabNet models. #' @export -decision_width <- function(range = c(8L, 64L), trans = NULL) { +attention_width <- function(range = c(8L, 64L), trans = NULL) { check_dials() dials::new_quant_param( type = "integer", range = range, inclusive = c(TRUE, TRUE), trans = trans, - label = c(decision_width = "Width of the decision prediction layer"), + label = c(attention_width = "Width of the attention embedding for each mask"), finalize = NULL ) } #' @rdname tabnet_params #' @export -attention_width <- function(range = c(8L, 64L), trans = NULL) { +decision_width <- function(range = c(8L, 64L), trans = NULL) { check_dials() dials::new_quant_param( type = "integer", range = range, inclusive = c(TRUE, TRUE), trans = trans, - label = c(attention_width = "Width of the attention embedding for each mask"), + label = c(decision_width = "Width of the decision prediction layer"), finalize = NULL ) } + #' @rdname tabnet_params #' @export -num_steps <- function(range = c(3L, 10L), trans = NULL) { +feature_reusage <- function(range = c(1, 2), trans = NULL) { check_dials() dials::new_quant_param( - type = "integer", + type = "double", range = range, inclusive = c(TRUE, TRUE), trans = trans, - label = c(num_steps = "Number of steps in the architecture"), + label = c(feature_reusage = "Coefficient for feature reusage in the masks"), finalize = NULL ) } #' @rdname tabnet_params #' @export -feature_reusage <- function(range = c(1, 2), trans = NULL) { +momentum <- function(range = c(0.01, 0.4), trans = NULL) { check_dials() dials::new_quant_param( type = "double", range = range, inclusive = c(TRUE, TRUE), trans = trans, - label = c(feature_reusage = "Coefficient for feature reusage in the masks"), + label = c(momentum = "Momentum for batch normalization"), + finalize = NULL + ) +} + + +#' @rdname tabnet_params +#' @export +mask_type <- function(values = c("sparsemax", "entmax")) { + check_dials() + dials::new_qual_param( + type = "character", + values = values, + label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"), finalize = NULL ) } @@ -101,28 +115,69 @@ num_shared <- function(range = c(1L, 5L), trans = NULL) { #' @rdname tabnet_params #' @export -momentum <- function(range = c(0.01, 0.4), trans = NULL) { +num_steps <- function(range = c(3L, 10L), trans = NULL) { check_dials() dials::new_quant_param( - type = "double", + type = "integer", range = range, inclusive = c(TRUE, TRUE), trans = trans, - label = c(momentum = "Momentum for batch normalization"), + label = c(num_steps = "Number of steps in the architecture"), finalize = NULL ) } - -#' @rdname tabnet_params +#' @noRd #' @export -mask_type <- function(values = c("sparsemax", "entmax")) { +cat_emb_dim <- function(range = NULL, trans = NULL) { check_dials() - dials::new_qual_param( - type = "character", - values = values, - label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"), - finalize = NULL - ) + cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.") } +#' @noRd +#' @export +checkpoint_epochs <- cat_emb_dim + +#' @noRd +#' @export +drop_last <- cat_emb_dim + +#' @noRd +#' @export +encoder_activation <- cat_emb_dim + +#' @noRd +#' @export +lr_scheduler <- cat_emb_dim + +#' @noRd +#' @export +mlp_activation <- cat_emb_dim + +#' @noRd +#' @export +mlp_hidden_multiplier <- cat_emb_dim + +#' @noRd +#' @export +num_independent_decoder <- cat_emb_dim + +#' @noRd +#' @export +num_shared_decoder <- cat_emb_dim + +#' @noRd +#' @export +optimizer <- cat_emb_dim + +#' @noRd +#' @export +penalty <- cat_emb_dim + +#' @noRd +#' @export +verbose <- cat_emb_dim + +#' @noRd +#' @export +virtual_batch_size <- cat_emb_dim diff --git a/R/parsnip.R b/R/parsnip.R index de1d9a33..4b4e48a1 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -85,7 +85,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "cat_emb_dim", original = "cat_emb_dim", - func = list(pkg = "dials", fun = "cat_emb_dim"), + func = list(pkg = "tabnet", fun = "cat_emb_dim"), has_submodel = FALSE ) @@ -130,7 +130,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "mlp_hidden_multiplier", original = "mlp_hidden_multiplier", - func = list(pkg = "dials", fun = "mlp_hidden_multiplier"), + func = list(pkg = "tabnet", fun = "mlp_hidden_multiplier"), has_submodel = FALSE ) @@ -139,7 +139,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "mlp_activation", original = "mlp_activation", - func = list(pkg = "dials", fun = "mlp_activation"), + func = list(pkg = "tabnet", fun = "mlp_activation"), has_submodel = FALSE ) @@ -148,7 +148,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "encoder_activation", original = "encoder_activation", - func = list(pkg = "dials", fun = "encoder_activation"), + func = list(pkg = "tabnet", fun = "encoder_activation"), has_submodel = FALSE ) @@ -175,7 +175,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "num_independent_decoder", original = "num_independent_decoder", - func = list(pkg = "dials", fun = "num_independent_decoder"), + func = list(pkg = "tabnet", fun = "num_independent_decoder"), has_submodel = FALSE ) @@ -184,7 +184,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "num_shared_decoder", original = "num_shared_decoder", - func = list(pkg = "dials", fun = "num_shared_decoder"), + func = list(pkg = "tabnet", fun = "num_shared_decoder"), has_submodel = FALSE ) @@ -211,7 +211,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "momentum", original = "momentum", - func = list(pkg = "tabnet", fun = "momentum"), + func = list(pkg = "dials", fun = "momentum"), has_submodel = FALSE ) @@ -238,7 +238,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "virtual_batch_size", original = "virtual_batch_size", - func = list(pkg = "dials", fun = "virtual_batch_size"), + func = list(pkg = "tabnet", fun = "virtual_batch_size"), has_submodel = FALSE ) @@ -256,7 +256,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "optimizer", original = "optimizer", - func = list(pkg = "dials", fun = "optimizer"), + func = list(pkg = "tabnet", fun = "optimizer"), has_submodel = FALSE ) @@ -265,7 +265,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "loss", original = "loss", - func = list(pkg = "dials", fun = "loss"), + func = list(pkg = "tabnet", fun = "loss"), has_submodel = FALSE ) @@ -274,7 +274,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "clip_value", original = "clip_value", - func = list(pkg = "dials", fun = "clip_value"), + func = list(pkg = "tabnet", fun = "clip_value"), has_submodel = FALSE ) @@ -283,7 +283,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "drop_last", original = "drop_last", - func = list(pkg = "dials", fun = "drop_last"), + func = list(pkg = "tabnet", fun = "drop_last"), has_submodel = FALSE ) @@ -292,25 +292,25 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "lr_scheduler", original = "lr_scheduler", - func = list(pkg = "dials", fun = "lr_scheduler"), + func = list(pkg = "tabnet", fun = "lr_scheduler"), has_submodel = FALSE ) parsnip::set_model_arg( model = "tabnet", eng = "torch", - parsnip = "lr_decay", + parsnip = "rate_decay", original = "lr_decay", - func = list(pkg = "dials", fun = "lr_decay"), + func = list(pkg = "dials", fun = "rate_decay"), has_submodel = FALSE ) parsnip::set_model_arg( model = "tabnet", eng = "torch", - parsnip = "step_size", + parsnip = "rate_step_size", original = "step_size", - func = list(pkg = "dials", fun = "step_size"), + func = list(pkg = "dials", fun = "rate_step_size"), has_submodel = FALSE ) @@ -319,7 +319,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "checkpoint_epochs", original = "checkpoint_epochs", - func = list(pkg = "dials", fun = "checkpoint_epochs"), + func = list(pkg = "tabnet", fun = "checkpoint_epochs"), has_submodel = FALSE ) @@ -328,7 +328,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "verbose", original = "verbose", - func = list(pkg = "dials", fun = "verbose"), + func = list(pkg = "tabnet", fun = "verbose"), has_submodel = FALSE ) @@ -337,7 +337,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "importance_sample_size", original = "importance_sample_size", - func = list(pkg = "dials", fun = "importance_sample_size"), + func = list(pkg = "tabnet", fun = "importance_sample_size"), has_submodel = FALSE ) @@ -346,7 +346,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "early_stopping_monitor", original = "early_stopping_monitor", - func = list(pkg = "dials", fun = "early_stopping_monitor"), + func = list(pkg = "tabnet", fun = "early_stopping_monitor"), has_submodel = FALSE ) @@ -355,7 +355,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "early_stopping_tolerance", original = "early_stopping_tolerance", - func = list(pkg = "dials", fun = "early_stopping_tolerance"), + func = list(pkg = "tabnet", fun = "early_stopping_tolerance"), has_submodel = FALSE ) @@ -364,7 +364,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "early_stopping_patience", original = "early_stopping_patience", - func = list(pkg = "dials", fun = "early_stopping_patience"), + func = list(pkg = "tabnet", fun = "early_stopping_patience"), has_submodel = FALSE ) @@ -382,7 +382,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "tabnet_model", original = "tabnet_model", - func = list(pkg = "dials", fun = "tabnet_model"), + func = list(pkg = "tabnet", fun = "tabnet_model"), has_submodel = FALSE ) @@ -391,7 +391,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "from_epoch", original = "from_epoch", - func = list(pkg = "dials", fun = "from_epoch"), + func = list(pkg = "tabnet", fun = "from_epoch"), has_submodel = FALSE ) @@ -449,7 +449,7 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL, num_independent_decoder = NULL, num_shared_decoder = NULL, penalty = NULL, feature_reusage = NULL, momentum = NULL, epochs = NULL, batch_size = NULL, virtual_batch_size = NULL, learn_rate = NULL, optimizer = NULL, loss = NULL, - clip_value = NULL, drop_last = NULL, lr_scheduler = NULL, lr_decay = NULL, step_size = NULL, + clip_value = NULL, drop_last = NULL, lr_scheduler = NULL, rate_decay = NULL, rate_step_size = NULL, checkpoint_epochs = NULL, verbose = NULL, importance_sample_size = NULL, early_stopping_monitor = NULL, early_stopping_tolerance = NULL, early_stopping_patience = NULL, skip_importance = NULL, @@ -488,8 +488,8 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL, clip_value = rlang::enquo(clip_value), drop_last = rlang::enquo(drop_last), lr_scheduler = rlang::enquo(lr_scheduler), - lr_decay = rlang::enquo(lr_decay), - step_size = rlang::enquo(step_size), + lr_decay = rlang::enquo(rate_decay), + step_size = rlang::enquo(rate_step_size), checkpoint_epochs = rlang::enquo(checkpoint_epochs), verbose = rlang::enquo(verbose), importance_sample_size = rlang::enquo(importance_sample_size), diff --git a/man/tabnet.Rd b/man/tabnet.Rd index 8184a81d..0705d667 100644 --- a/man/tabnet.Rd +++ b/man/tabnet.Rd @@ -27,8 +27,8 @@ tabnet( clip_value = NULL, drop_last = NULL, lr_scheduler = NULL, - lr_decay = NULL, - step_size = NULL, + rate_decay = NULL, + rate_step_size = NULL, checkpoint_epochs = NULL, verbose = NULL, importance_sample_size = NULL, @@ -113,13 +113,6 @@ decays the learning rate by \code{lr_decay} when no improvement after \code{step It can also be a \link[torch:lr_scheduler]{torch::lr_scheduler} function that only takes the optimizer as parameter. The \code{step} method is called once per epoch.} -\item{lr_decay}{multiplies the initial learning rate by \code{lr_decay} every -\code{step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler} -or \code{NULL}.} - -\item{step_size}{the learning rate scheduler step size. Unused if -\code{lr_scheduler} is a \code{torch::lr_scheduler} or \code{NULL}.} - \item{checkpoint_epochs}{checkpoint model weights and architecture every \code{checkpoint_epochs}. (default is 10). This may cause large memory usage. Use \code{0} to disable checkpoints.} diff --git a/man/tabnet_params.Rd b/man/tabnet_params.Rd index 239eb0b9..cdcccfea 100644 --- a/man/tabnet_params.Rd +++ b/man/tabnet_params.Rd @@ -1,31 +1,31 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/dials.R -\name{decision_width} -\alias{decision_width} +\name{attention_width} \alias{attention_width} -\alias{num_steps} +\alias{decision_width} \alias{feature_reusage} -\alias{num_independent} -\alias{num_shared} \alias{momentum} \alias{mask_type} +\alias{num_independent} +\alias{num_shared} +\alias{num_steps} \title{Parameters for the tabnet model} \usage{ -decision_width(range = c(8L, 64L), trans = NULL) - attention_width(range = c(8L, 64L), trans = NULL) -num_steps(range = c(3L, 10L), trans = NULL) +decision_width(range = c(8L, 64L), trans = NULL) feature_reusage(range = c(1, 2), trans = NULL) +momentum(range = c(0.01, 0.4), trans = NULL) + +mask_type(values = c("sparsemax", "entmax")) + num_independent(range = c(1L, 5L), trans = NULL) num_shared(range = c(1L, 5L), trans = NULL) -momentum(range = c(0.01, 0.4), trans = NULL) - -mask_type(values = c("sparsemax", "entmax")) +num_steps(range = c(3L, 10L), trans = NULL) } \arguments{ \item{range}{the default range for the parameter value} diff --git a/tests/testthat/test-dials.R b/tests/testthat/test-dials.R new file mode 100644 index 00000000..e4200dff --- /dev/null +++ b/tests/testthat/test-dials.R @@ -0,0 +1,48 @@ +test_that("Check we can use hardhat:::extract_parameter_set_dials() with {dial} tune()ed parameter", { + + model <- tabnet(batch_size = tune(), learn_rate = tune(), epochs = tune(), + momentum = tune(), penalty = tune(), rate_step_size = tune()) %>% + parsnip::set_mode("regression") %>% + parsnip::set_engine("torch") + + wf <- workflows::workflow() %>% + workflows::add_model(model) %>% + workflows::add_formula(Sale_Price ~ .) + + expect_no_error( + wf %>% hardhat::extract_parameter_set_dials() + ) +}) + +test_that("Check we can use hardhat:::extract_parameter_set_dials() with {tabnet} tune()ed parameter", { + + model <- tabnet(num_steps = tune(), num_shared = tune(), mask_type = tune(), + feature_reusage = tune(), attention_width = tune()) %>% + parsnip::set_mode("regression") %>% + parsnip::set_engine("torch") + + wf <- workflows::workflow() %>% + workflows::add_model(model) %>% + workflows::add_formula(Sale_Price ~ .) + + expect_no_error( + wf %>% hardhat::extract_parameter_set_dials() + ) +}) + +test_that("Check non supported tune()ed parameter raise an explicit error", { + + model <- tabnet(cat_emb_dim = tune(), checkpoint_epochs = 0) %>% + parsnip::set_mode("regression") %>% + parsnip::set_engine("torch") + + wf <- workflows::workflow() %>% + workflows::add_model(model) %>% + workflows::add_formula(Sale_Price ~ .) + + expect_error( + wf %>% hardhat::extract_parameter_set_dials(), + regexp = "cannot be used as a .* parameter yet" + ) +}) +