From f5b2f6f5aa05980a38116dcda18a4fcfbc7a90cc Mon Sep 17 00:00:00 2001
From: cregouby
Date: Mon, 4 Dec 2023 09:17:07 +0100
Subject: [PATCH] Switch to zeallot `%<-%` for list output (#133)

* pre-test move to zeallot
* res is already computed and part of encoder result
* fix the snapshot update coming from new values for `num_xx_decoder`
* switch CI to ubuntu 20.04 for GPU image availability
* fix translation tests that cannot run on CI, as GPU R is not multi-lingual
---
 .github/workflows/R-CMD-check.yaml   |  2 +-
 DESCRIPTION                          |  7 ++++++-
 NAMESPACE                            |  1 +
 NEWS.md                              |  6 ++++--
 R/explain.R                          |  8 +++++---
 R/model.R                            | 28 ++++++++++++++++------------
 R/tab-network.R                      | 16 ++++++----------
 R/utils-pipe.R                       |  1 +
 tests/testthat/_snaps/pretraining.md |  4 ++--
 tests/testthat/test_translations.R   |  6 ++++--
 10 files changed, 46 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml
index 864b7f87..06ca268e 100644
--- a/.github/workflows/R-CMD-check.yaml
+++ b/.github/workflows/R-CMD-check.yaml
@@ -58,7 +58,7 @@ jobs:
     name: 'gpu'
 
     container:
-      image: 'nvidia/cuda:11.6.2-cudnn8-devel-ubuntu18.04'
+      image: 'nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04'
       options: '--gpus all --runtime=nvidia'
 
     timeout-minutes: 120
diff --git a/DESCRIPTION b/DESCRIPTION
index 1a1969f1..69545b40 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -18,6 +18,8 @@ Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.2.3
 URL: https://mlverse.github.io/tabnet/, https://github.com/mlverse/tabnet
 BugReports: https://github.com/mlverse/tabnet/issues
+Depends:
+    R (>= 3.6)
 Imports:
     torch (>= 0.4.0),
     hardhat (>= 1.3.0),
@@ -29,7 +31,8 @@ Imports:
     tibble,
     tidyr,
     coro,
-    vctrs
+    vctrs,
+    zeallot
 Suggests:
     testthat (>= 3.0.0),
     data.tree,
@@ -52,3 +55,5 @@ Suggests:
     yardstick
 VignetteBuilder: knitr
 Config/testthat/edition: 3
+Config/testthat/parallel: false
+Config/testthat/start-first: interface, explain, params
diff --git a/NAMESPACE b/NAMESPACE
index 09ccef40..bca75a10 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -50,3 +50,4 @@ importFrom(rlang,.data)
 importFrom(stats,predict)
 importFrom(stats,update)
 importFrom(tidyr,replace_na)
+importFrom(zeallot,"%<-%")
diff --git a/NEWS.md b/NEWS.md
index a47a2eee..4d47a9ce 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,10 +1,12 @@
 # tabnet (development version)
 
 ## New features
-* add FR translation (#131)
-* `tabnet_pretrain()` now allows different GLU blocks in GLU layers in encoder and in decoder through the `config()` parameters `num_idependant_decoder` and `num_shared_decoder` (#129)
+
 * {tabnet} now allows hierarchical multi-label classification through {data.tree} hierarchical `Node` dataset. (#126)
+* `tabnet_pretrain()` now allows different GLU blocks in GLU layers in encoder and in decoder through the `config()` parameters `num_idependant_decoder` and `num_shared_decoder` (#129)
 * Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120)
+* use zeallot internally with %<-% for code readability (#133)
+* add FR translation (#131)
 
 # tabnet 0.4.0
diff --git a/R/explain.R b/R/explain.R
index 2394efe7..12c26154 100644
--- a/R/explain.R
+++ b/R/explain.R
@@ -94,19 +94,21 @@ explain_impl <- function(network, x, x_na_mask) {
     network$to(device = curr_device)
   })
   network$to(device=x$device)
-  outputs <- network$forward_masks(x, x_na_mask)
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  M_explain_emb_dim <- masks_emb_dim <- NULL
+  c(M_explain_emb_dim, masks_emb_dim) %<-% network$forward_masks(x, x_na_mask)
 
   # summarize the categorical embeddedings into 1 column
   # per variable
   M_explain <- sum_embedding_masks(
-    mask = outputs[[1]],
+    mask = M_explain_emb_dim,
     input_dim = network$input_dim,
     cat_idx = network$cat_idxs,
     cat_emb_dim = network$cat_emb_dim
   )
 
   masks <- lapply(
-    outputs[[2]],
+    masks_emb_dim,
     FUN = sum_embedding_masks,
     input_dim = network$input_dim,
     cat_idx = network$cat_idxs,
diff --git a/R/model.R b/R/model.R
index b2bf32f8..c09d62f7 100644
--- a/R/model.R
+++ b/R/model.R
@@ -256,8 +256,10 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
 }
 
 train_batch <- function(network, optimizer, batch, config) {
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  out <- M_loss <- NULL
   # forward pass
-  output <- network(batch$x, batch$x_na_mask)
+  c(out, M_loss) %<-% network(batch$x, batch$x_na_mask)
 
   # if target is multi-outcome, loss has to be applied to each label-group
   if (max(batch$output_dim$shape) > 1) { # multi-outcome
@@ -266,7 +268,7 @@ train_batch <- function(network, optimizer, batch, config) {
       # hierarchical mandates use of `max_constraint_output`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
@@ -276,7 +278,7 @@ train_batch <- function(network, optimizer, batch, config) {
       # use `resolved_loss`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(.x, .y$squeeze(2))
@@ -286,13 +288,13 @@ train_batch <- function(network, optimizer, batch, config) {
   } else {
     if (batch$y$dtype == torch::torch_long()) {
       # classifier needs a squeeze for bce loss
-      loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
+      loss <- config$loss_fn(out, batch$y$squeeze(2))
     } else {
-      loss <- config$loss_fn(output[[1]], batch$y)
+      loss <- config$loss_fn(out, batch$y)
     }
   }
   # Add the overall sparsity loss
-  loss <- loss - config$lambda_sparse * output[[2]]
+  loss <- loss - config$lambda_sparse * M_loss
 
   # step of the optimization
   optimizer$zero_grad()
@@ -308,8 +310,10 @@
 }
 
 valid_batch <- function(network, batch, config) {
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  out <- M_loss <- NULL
   # forward pass
-  output <- network(batch$x, batch$x_na_mask)
+  c(out, M_loss) %<-% network(batch$x, batch$x_na_mask)
 
   # loss has to be applied to each label-group when output_dim is a vector
   if (max(batch$output_dim$shape) > 1) { # multi-outcome
@@ -318,7 +322,7 @@ valid_batch <- function(network, batch, config) {
       # hierarchical mandates use of `max_constraint_output`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
@@ -328,7 +332,7 @@ valid_batch <- function(network, batch, config) {
       # use `resolved_loss`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(.x, .y$squeeze(2))
@@ -338,13 +342,13 @@ valid_batch <- function(network, batch, config) {
   } else {
     if (batch$y$dtype == torch::torch_long()) {
       # classifier needs a squeeze for bce loss
-      loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
+      loss <- config$loss_fn(out, batch$y$squeeze(2))
     } else {
-      loss <- config$loss_fn(output[[1]], batch$y)
+      loss <- config$loss_fn(out, batch$y)
     }
   }
   # Add the overall sparsity loss
-  loss <- loss - config$lambda_sparse * output[[2]]
+  loss <- loss - config$lambda_sparse * M_loss
 
   list(
     loss = loss$item()
diff --git a/R/tab-network.R b/R/tab-network.R
index a8ffc721..81767e2d 100644
--- a/R/tab-network.R
+++ b/R/tab-network.R
@@ -305,15 +305,14 @@ tabnet_pretrainer <- torch::nn_module(
     embedded_x_na_mask <- self$embedder_na(x_na_mask)
 
     if (self$training) {
-      masker_out_lst <- self$masker(embedded_x, embedded_x_na_mask)
-      obf_vars <- masker_out_lst[[2]]
+      c(masked_x, obfuscated_vars) %<-% self$masker(embedded_x, embedded_x_na_mask)
       # set prior of encoder as !obf_mask
-      prior <- obf_vars$logical_not()
-      steps_out <- self$encoder(masker_out_lst[[1]], prior)[[3]]
+      prior <- obfuscated_vars$logical_not()
+      steps_out <- self$encoder(masked_x, prior)[[3]]
       res <- self$decoder(steps_out)
       list(res,
            embedded_x,
-           obf_vars)
+           obfuscated_vars)
     } else {
       prior <- embedded_x_na_mask$logical_not()
       steps_out <- self$encoder(embedded_x, prior)[[3]]
@@ -383,12 +382,9 @@ tabnet_no_embedding <- torch::nn_module(
   },
   forward = function(x, x_na_mask) {
     prior <- x_na_mask$logical_not()
-    self_encoder_lst <- self$encoder(x, prior)
-    steps_output <- self_encoder_lst[[1]]
-    M_loss <- self_encoder_lst[[2]]
-    res <- torch::torch_sum(torch::torch_stack(steps_output, dim=1), dim=1)
+    c(res, M_loss, steps_output) %<-% self$encoder(x, prior)
     if (self$is_multi_outcome) {
-      out <- torch::torch_stack(purrr::map(self$multi_outcome_mapping, exec, !!!res), dim=2)$squeeze(3)
+      out <- torch::torch_stack(purrr::map(self$multi_outcome_mapping, exec, !!!res), dim = 2)$squeeze(3)
     } else {
       out <- self$final_mapping(res)
     }
diff --git a/R/utils-pipe.R b/R/utils-pipe.R
index 66b84344..53c79fad 100644
--- a/R/utils-pipe.R
+++ b/R/utils-pipe.R
@@ -7,6 +7,7 @@
 #' @keywords internal
 #' @export
 #' @importFrom magrittr %>%
+#' @importFrom zeallot %<-%
 #' @usage lhs \%>\% rhs
 #'
 #' @return Returns `rhs(lhs)`.
diff --git a/tests/testthat/_snaps/pretraining.md b/tests/testthat/_snaps/pretraining.md
index 667e9983..8b3763d6 100644
--- a/tests/testthat/_snaps/pretraining.md
+++ b/tests/testthat/_snaps/pretraining.md
@@ -1,6 +1,6 @@
 # print module works
 
-    An `nn_module` containing 13,894 parameters.
+    An `nn_module` containing 13,190 parameters.
     
     -- Modules ---------------------------------------
     * initial_bn: #146 parameters
@@ -8,7 +8,7 @@
     * embedder_na: #0 parameters
     * masker: #0 parameters
     * encoder: #10,304 parameters
-    * decoder: #3,160 parameters
+    * decoder: #2,456 parameters
     
     -- Parameters ------------------------------------
     * .check: Float [1:1]
diff --git a/tests/testthat/test_translations.R b/tests/testthat/test_translations.R
index 586c9bbe..38924b80 100644
--- a/tests/testthat/test_translations.R
+++ b/tests/testthat/test_translations.R
@@ -1,5 +1,6 @@
 test_that("early stopping message get translated in french", {
-  testthat::skip_on_ci()
+  # skip on linux on ci
+  testthat::skip_if((testthat:::on_ci() & testthat:::system_os() == "linux"))
   testthat::skip_on_cran()
   withr::with_language(lang = "fr",
     expect_error(
@@ -12,7 +13,8 @@ test_that("early stopping message get translated in french", {
 })
 
 test_that("scheduler message translated in french", {
-  testthat::skip_on_ci()
+  # skip on linux on ci
+  testthat::skip_if((testthat:::on_ci() & testthat:::system_os() == "linux"))
   testthat::skip_on_cran()
   withr::with_language(lang = "fr",
     expect_error(
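
For reference, a minimal sketch of the zeallot `%<-%` pattern adopted above, runnable outside the package. `forward_masks_stub()` is a hypothetical stand-in for `network$forward_masks()`, which returns a two-element list; pre-assigning `NULL` is the same workaround used in the diff to avoid the R CMD check NOTE "no visible binding for global variable" that a bare `%<-%` target would otherwise trigger.

    library(zeallot)

    # hypothetical stand-in for network$forward_masks(), which returns a
    # two-element list (aggregated mask, list of per-step masks)
    forward_masks_stub <- function() {
      list(matrix(1, 2, 2), list(matrix(0, 2, 2)))
    }

    # pre-assign NULL so R CMD check sees a binding for both names
    M_explain_emb_dim <- masks_emb_dim <- NULL

    # destructure the returned list into both variables in one step
    c(M_explain_emb_dim, masks_emb_dim) %<-% forward_masks_stub()

    # equivalent to the pre-patch indexing style
    outputs <- forward_masks_stub()
    identical(M_explain_emb_dim, outputs[[1]])  #> TRUE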