Commit

properly manage dials parameters in parsnip,
rename 2 `tabnet()` params to `rate_decay` and `rate_step_size`
add tests for that
cregouby committed Jul 27, 2024
1 parent 6ae7160 commit 9732d02
Showing 7 changed files with 179 additions and 69 deletions.
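A minimal sketch of how the renamed arguments are intended to be used from parsnip, assuming the `tabnet()` model specification with the `"torch"` engine; the `tune()` placeholders and the classification mode are illustrative:

```r
library(parsnip)
library(tune)
library(tabnet)

# Mark TabNet hyperparameters for tuning, including the two renamed
# learning-rate scheduler arguments from this commit.
spec <- tabnet(
  decision_width  = tune(),
  attention_width = tune(),
  rate_decay      = tune(),
  rate_step_size  = tune()
) %>%
  set_engine("torch") %>%
  set_mode("classification")
```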
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -21,6 +21,7 @@ BugReports: https://github.com/mlverse/tabnet/issues
Depends:
R (>= 3.6)
Imports:
cli,
coro,
data.tree,
dials,
13 changes: 13 additions & 0 deletions NAMESPACE
@@ -21,23 +21,36 @@ S3method(tabnet_pretrain,recipe)
S3method(update,tabnet)
export("%>%")
export(attention_width)
export(cat_emb_dim)
export(check_compliant_node)
export(checkpoint_epochs)
export(decision_width)
export(drop_last)
export(encoder_activation)
export(feature_reusage)
export(lr_scheduler)
export(mask_type)
export(mlp_activation)
export(mlp_hidden_multiplier)
export(momentum)
export(nn_prune_head.tabnet_fit)
export(nn_prune_head.tabnet_pretrain)
export(node_to_df)
export(num_independent)
export(num_independent_decoder)
export(num_shared)
export(num_shared_decoder)
export(num_steps)
export(optimizer)
export(penalty)
export(tabnet)
export(tabnet_config)
export(tabnet_explain)
export(tabnet_fit)
export(tabnet_nn)
export(tabnet_pretrain)
export(verbose)
export(virtual_batch_size)
importFrom(dplyr,filter)
importFrom(dplyr,last_col)
importFrom(dplyr,mutate)
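The newly exported constructors are plain `dials` parameter objects, so they can be combined into a grid on their own. A short sketch using a few of the exports above with their default ranges; the `grid_regular()` call and the `levels` choice are only illustrative:

```r
library(dials)
library(tabnet)

# Combine a few of the exported parameter objects into a small regular grid;
# each constructor carries the default range declared in R/dials.R.
grid <- grid_regular(
  decision_width(),
  attention_width(),
  num_steps(),
  mask_type(),
  levels = 2
)
grid
```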
97 changes: 76 additions & 21 deletions R/dials.R
@@ -17,56 +17,70 @@ check_dials <- function() {
#' @rdname tabnet_params
#' @return A `dials` parameter to be used when tuning TabNet models.
#' @export
decision_width <- function(range = c(8L, 64L), trans = NULL) {
attention_width <- function(range = c(8L, 64L), trans = NULL) {
check_dials()
dials::new_quant_param(
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(decision_width = "Width of the decision prediction layer"),
label = c(attention_width = "Width of the attention embedding for each mask"),
finalize = NULL
)
}

#' @rdname tabnet_params
#' @export
attention_width <- function(range = c(8L, 64L), trans = NULL) {
decision_width <- function(range = c(8L, 64L), trans = NULL) {
check_dials()
dials::new_quant_param(
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(attention_width = "Width of the attention embedding for each mask"),
label = c(decision_width = "Width of the decision prediction layer"),
finalize = NULL
)
}


#' @rdname tabnet_params
#' @export
num_steps <- function(range = c(3L, 10L), trans = NULL) {
feature_reusage <- function(range = c(1, 2), trans = NULL) {
check_dials()
dials::new_quant_param(
type = "integer",
type = "double",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(num_steps = "Number of steps in the architecture"),
label = c(feature_reusage = "Coefficient for feature reusage in the masks"),
finalize = NULL
)
}

#' @rdname tabnet_params
#' @export
feature_reusage <- function(range = c(1, 2), trans = NULL) {
momentum <- function(range = c(0.01, 0.4), trans = NULL) {
check_dials()
dials::new_quant_param(
type = "double",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(feature_reusage = "Coefficient for feature reusage in the masks"),
label = c(momentum = "Momentum for batch normalization"),
finalize = NULL
)
}


#' @rdname tabnet_params
#' @export
mask_type <- function(values = c("sparsemax", "entmax")) {
check_dials()
dials::new_qual_param(
type = "character",
values = values,
label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"),
finalize = NULL
)
}
@@ -101,28 +115,69 @@ num_shared <- function(range = c(1L, 5L), trans = NULL) {

#' @rdname tabnet_params
#' @export
momentum <- function(range = c(0.01, 0.4), trans = NULL) {
num_steps <- function(range = c(3L, 10L), trans = NULL) {
check_dials()
dials::new_quant_param(
type = "double",
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(momentum = "Momentum for batch normalization"),
label = c(num_steps = "Number of steps in the architecture"),
finalize = NULL
)
}


#' @rdname tabnet_params
#' @noRd
#' @export
mask_type <- function(values = c("sparsemax", "entmax")) {
cat_emb_dim <- function(range = NULL, trans = NULL) {
check_dials()
dials::new_qual_param(
type = "character",
values = values,
label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"),
finalize = NULL
)
cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.")
}

#' @noRd
#' @export
checkpoint_epochs <- cat_emb_dim

#' @noRd
#' @export
drop_last <- cat_emb_dim

#' @noRd
#' @export
encoder_activation <- cat_emb_dim

#' @noRd
#' @export
lr_scheduler <- cat_emb_dim

#' @noRd
#' @export
mlp_activation <- cat_emb_dim

#' @noRd
#' @export
mlp_hidden_multiplier <- cat_emb_dim

#' @noRd
#' @export
num_independent_decoder <- cat_emb_dim

#' @noRd
#' @export
num_shared_decoder <- cat_emb_dim

#' @noRd
#' @export
optimizer <- cat_emb_dim

#' @noRd
#' @export
penalty <- cat_emb_dim

#' @noRd
#' @export
verbose <- cat_emb_dim

#' @noRd
#' @export
virtual_batch_size <- cat_emb_dim
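The file now separates parameters that are genuinely tunable (quantitative or qualitative `dials` objects whose ranges or values can be overridden) from engine arguments that are exported only so parsnip can resolve their names. A rough sketch of that split from the user side; the exact error message is not reproduced here:

```r
library(tabnet)

# Tunable parameters return ordinary dials objects, optionally with a
# narrowed range or a restricted value set.
feature_reusage(range = c(1, 1.5))
mask_type(values = "entmax")

# Parameters aliased to cat_emb_dim() exist only so parsnip can look them up;
# calling one directly aborts because it cannot be tuned yet.
try(optimizer())
```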
