Switch to zeallot %<-% for list output (#133)
* pre-test move to zeallot

* res is already computed and part of encoder result

* fix the snapshot update coming from new values for `num_xx_decoder`

* switch CI to ubuntu 20.04 for GPU image availability

* fix translation tests that cannot run on CI, as GPU R is not multi-lingual
cregouby authored Dec 4, 2023
1 parent 8beb690 commit f5b2f6f
Showing 10 changed files with 46 additions and 33 deletions.
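For context, `%<-%` is zeallot's destructuring assignment operator. A minimal sketch of the pattern this commit adopts throughout (the toy `fit_stats()` stands in for the package's multi-value returns; it is not part of tabnet):

```r
library(zeallot)

# toy multi-value return, standing in for calls such as
# network$forward_masks(x, x_na_mask) in the diffs below
fit_stats <- function() list(2.5, 0.31)

# before: keep the whole list and index it element by element
res <- fit_stats()
coef <- res[[1]]
rss <- res[[2]]

# after: unpack both values in one self-describing assignment
c(coef, rss) %<-% fit_stats()
```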
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
@@ -58,7 +58,7 @@ jobs:
     name: 'gpu'
 
     container:
-      image: 'nvidia/cuda:11.6.2-cudnn8-devel-ubuntu18.04'
+      image: 'nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04'
       options: '--gpus all --runtime=nvidia'
 
     timeout-minutes: 120
7 changes: 6 additions & 1 deletion DESCRIPTION
@@ -18,6 +18,8 @@ Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.2.3
 URL: https://mlverse.github.io/tabnet/, https://github.com/mlverse/tabnet
 BugReports: https://github.com/mlverse/tabnet/issues
+Depends:
+    R (>= 3.6)
 Imports:
     torch (>= 0.4.0),
     hardhat (>= 1.3.0),
@@ -29,7 +31,8 @@ Imports:
     tibble,
     tidyr,
     coro,
-    vctrs
+    vctrs,
+    zeallot
 Suggests:
     testthat (>= 3.0.0),
     data.tree,
@@ -52,3 +55,5 @@ Suggests:
     yardstick
 VignetteBuilder: knitr
 Config/testthat/edition: 3
+Config/testthat/parallel: false
+Config/testthat/start-first: interface, explain, params
1 change: 1 addition & 0 deletions NAMESPACE
@@ -50,3 +50,4 @@ importFrom(rlang,.data)
 importFrom(stats,predict)
 importFrom(stats,update)
 importFrom(tidyr,replace_na)
+importFrom(zeallot,"%<-%")
6 changes: 4 additions & 2 deletions NEWS.md
@@ -1,10 +1,12 @@
 # tabnet (development version)
 
 ## New features
-* add FR translation (#131)
-* `tabnet_pretrain()` now allows different GLU blocks in GLU layers in encoder and in decoder through the `config()` parameters `num_idependant_decoder` and `num_shared_decoder` (#129)
 
 * {tabnet} now allows hierarchical multi-label classification through {data.tree} hierarchical `Node` dataset. (#126)
+* `tabnet_pretrain()` now allows different GLU blocks in GLU layers in encoder and in decoder through the `config()` parameters `num_idependant_decoder` and `num_shared_decoder` (#129)
+* Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120)
+* use zeallot internally with %<-% for code readability (#133)
+* add FR translation (#131)
 
 # tabnet 0.4.0
8 changes: 5 additions & 3 deletions R/explain.R
@@ -94,19 +94,21 @@ explain_impl <- function(network, x, x_na_mask) {
     network$to(device = curr_device)
   })
   network$to(device=x$device)
-  outputs <- network$forward_masks(x, x_na_mask)
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  M_explain_emb_dim <- masks_emb_dim <- NULL
+  c(M_explain_emb_dim, masks_emb_dim) %<-% network$forward_masks(x, x_na_mask)
 
   # summarize the categorical embeddedings into 1 column
   # per variable
   M_explain <- sum_embedding_masks(
-    mask = outputs[[1]],
+    mask = M_explain_emb_dim,
     input_dim = network$input_dim,
     cat_idx = network$cat_idxs,
     cat_emb_dim = network$cat_emb_dim
   )
 
   masks <- lapply(
-    outputs[[2]],
+    masks_emb_dim,
     FUN = sum_embedding_masks,
     input_dim = network$input_dim,
     cat_idx = network$cat_idxs,
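The `NULL` pre-assignments introduced above work around an `R CMD check` limitation: the checker's static analysis cannot see that `%<-%` creates the bindings, so without them it reports "no visible binding for global variable". A minimal sketch of the idiom (the stand-in `forward_masks_like()` is hypothetical):

```r
library(zeallot)

forward_masks_like <- function() list(matrix(1, 2, 2), list(matrix(0, 2, 2)))

# pre-declare the names so R CMD check sees an ordinary assignment
M_explain_emb_dim <- masks_emb_dim <- NULL
c(M_explain_emb_dim, masks_emb_dim) %<-% forward_masks_like()
```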
28 changes: 16 additions & 12 deletions R/model.R
@@ -256,8 +256,10 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
 }
 
 train_batch <- function(network, optimizer, batch, config) {
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  out <- M_loss <- NULL
   # forward pass
-  output <- network(batch$x, batch$x_na_mask)
+  c(out, M_loss) %<-% network(batch$x, batch$x_na_mask)
   # if target is multi-outcome, loss has to be applied to each label-group
   if (max(batch$output_dim$shape) > 1) {
     # multi-outcome
@@ -266,7 +268,7 @@
       # hierarchical mandates use of `max_constraint_output`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
@@ -276,7 +278,7 @@
       # use `resolved_loss`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(.x, .y$squeeze(2))
@@ -286,13 +288,13 @@
   } else {
     if (batch$y$dtype == torch::torch_long()) {
       # classifier needs a squeeze for bce loss
-      loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
+      loss <- config$loss_fn(out, batch$y$squeeze(2))
     } else {
-      loss <- config$loss_fn(output[[1]], batch$y)
+      loss <- config$loss_fn(out, batch$y)
     }
   }
   # Add the overall sparsity loss
-  loss <- loss - config$lambda_sparse * output[[2]]
+  loss <- loss - config$lambda_sparse * M_loss
 
   # step of the optimization
   optimizer$zero_grad()
@@ -308,8 +310,10 @@ train_batch <- function(network, optimizer, batch, config) {
 }
 
 valid_batch <- function(network, batch, config) {
+  # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
+  out <- M_loss <- NULL
   # forward pass
-  output <- network(batch$x, batch$x_na_mask)
+  c(out, M_loss) %<-% network(batch$x, batch$x_na_mask)
   # loss has to be applied to each label-group when output_dim is a vector
   if (max(batch$output_dim$shape) > 1) {
     # multi-outcome
@@ -318,7 +322,7 @@
       # hierarchical mandates use of `max_constraint_output`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
@@ -328,7 +332,7 @@
       # use `resolved_loss`
       loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
         list(
-          torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
+          torch::torch_split(out, outcome_nlevels, dim = 2),
           torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
         ),
         ~config$loss_fn(.x, .y$squeeze(2))
@@ -338,13 +342,13 @@
   } else {
     if (batch$y$dtype == torch::torch_long()) {
       # classifier needs a squeeze for bce loss
-      loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
+      loss <- config$loss_fn(out, batch$y$squeeze(2))
     } else {
-      loss <- config$loss_fn(output[[1]], batch$y)
+      loss <- config$loss_fn(out, batch$y)
     }
   }
   # Add the overall sparsity loss
-  loss <- loss - config$lambda_sparse * output[[2]]
+  loss <- loss - config$lambda_sparse * M_loss
 
   list(
     loss = loss$item()
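All of the hunks above make the same substitution inside the multi-outcome loss: the unpacked `out` replaces `output[[1]]` as the logits tensor that gets split per label group. A self-contained sketch of that split-and-sum pattern, with toy shapes and `nnf_cross_entropy` standing in for `config$loss_fn`:

```r
library(torch)
library(purrr)

outcome_nlevels <- c(3, 2)                   # two label groups
out <- torch_randn(8, sum(outcome_nlevels))  # [batch, 3 + 2] logits
y <- torch_tensor(cbind(sample(3, 8, replace = TRUE),
                        sample(2, 8, replace = TRUE)),
                  dtype = torch_long())      # [batch, 2] class ids

# split the logits column-wise per group, pair each block with its
# label column, apply the loss per group, then sum the group losses
loss <- torch_sum(torch_stack(pmap(
  list(
    torch_split(out, outcome_nlevels, dim = 2),
    torch_split(y, rep(1, length(outcome_nlevels)), dim = 2)
  ),
  ~ nnf_cross_entropy(.x, .y$squeeze(2))
)))
loss$item()
```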
16 changes: 6 additions & 10 deletions R/tab-network.R
@@ -305,15 +305,14 @@ tabnet_pretrainer <- torch::nn_module(
     embedded_x_na_mask <- self$embedder_na(x_na_mask)
 
     if (self$training) {
-      masker_out_lst <- self$masker(embedded_x, embedded_x_na_mask)
-      obf_vars <- masker_out_lst[[2]]
+      c(masked_x, obfuscated_vars) %<-% self$masker(embedded_x, embedded_x_na_mask)
       # set prior of encoder as !obf_mask
-      prior <- obf_vars$logical_not()
-      steps_out <- self$encoder(masker_out_lst[[1]], prior)[[3]]
+      prior <- obfuscated_vars$logical_not()
+      steps_out <- self$encoder(masked_x, prior)[[3]]
       res <- self$decoder(steps_out)
       list(res,
            embedded_x,
-           obf_vars)
+           obfuscated_vars)
     } else {
       prior <- embedded_x_na_mask$logical_not()
       steps_out <- self$encoder(embedded_x, prior)[[3]]
@@ -383,12 +382,9 @@ tabnet_no_embedding <- torch::nn_module(
   },
   forward = function(x, x_na_mask) {
     prior <- x_na_mask$logical_not()
-    self_encoder_lst <- self$encoder(x, prior)
-    steps_output <- self_encoder_lst[[1]]
-    M_loss <- self_encoder_lst[[2]]
-    res <- torch::torch_sum(torch::torch_stack(steps_output, dim=1), dim=1)
+    c(res, M_loss, steps_output) %<-% self$encoder(x, prior)
     if (self$is_multi_outcome) {
-      out <- torch::torch_stack(purrr::map(self$multi_outcome_mapping, exec, !!!res), dim=2)$squeeze(3)
+      out <- torch::torch_stack(purrr::map(self$multi_outcome_mapping, exec, !!!res), dim = 2)$squeeze(3)
     } else {
       out <- self$final_mapping(res)
     }
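One subtlety in the `tabnet_no_embedding` hunk: the encoder returns three values, and the third (`steps_output`) is itself a list. `%<-%` binds it to a single name rather than splicing it, as this toy sketch shows (`encoder_like()` is a stand-in, not the real encoder):

```r
library(zeallot)

# stand-in for self$encoder(x, prior): the third element is a list
encoder_like <- function() list(1.5, 0.02, list("step1", "step2"))

c(res, M_loss, steps_output) %<-% encoder_like()
steps_output  # still a two-element list, bound wholesale
```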
1 change: 1 addition & 0 deletions R/utils-pipe.R
@@ -7,6 +7,7 @@
 #' @keywords internal
 #' @export
 #' @importFrom magrittr %>%
+#' @importFrom zeallot %<-%
 #' @usage lhs \%>\% rhs
 #'
 #' @return Returns `rhs(lhs)`.
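Adding `@importFrom zeallot %<-%` to this roxygen block is what generates the `importFrom(zeallot,"%<-%")` NAMESPACE entry shown earlier. A sketch of the documentation-driven workflow:

```r
# in any package R file: an import-only roxygen block
#' @importFrom zeallot %<-%
#' @keywords internal
NULL
```

Running `devtools::document()` (or `roxygen2::roxygenise()`) regenerates NAMESPACE from these tags, so the operator is available inside the package without attaching zeallot.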
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/pretraining.md
@@ -1,14 +1,14 @@
 # print module works
 
-    An `nn_module` containing 13,894 parameters.
+    An `nn_module` containing 13,190 parameters.
 
     -- Modules ---------------------------------------
     * initial_bn: <nn_batch_norm1d> #146 parameters
     * embedder: <embedding_generator> #283 parameters
     * embedder_na: <na_embedding_generator> #0 parameters
     * masker: <random_obfuscator> #0 parameters
     * encoder: <tabnet_encoder> #10,304 parameters
-    * decoder: <tabnet_decoder> #3,160 parameters
+    * decoder: <tabnet_decoder> #2,456 parameters
 
     -- Parameters ------------------------------------
     * .check: Float [1:1]
6 changes: 4 additions & 2 deletions tests/testthat/test_translations.R
@@ -1,5 +1,6 @@
 test_that("early stopping message get translated in french", {
-  testthat::skip_on_ci()
+  # skip on linux on ci
+  testthat::skip_if((testthat:::on_ci() & testthat:::system_os() == "linux"))
   testthat::skip_on_cran()
   withr::with_language(lang = "fr",
     expect_error(
@@ -12,7 +13,8 @@ test_that("early stopping message get translated in french", {
 })
 
 test_that("scheduler message translated in french", {
-  testthat::skip_on_ci()
+  # skip on linux on ci
+  testthat::skip_if((testthat:::on_ci() & testthat:::system_os() == "linux"))
   testthat::skip_on_cran()
   withr::with_language(lang = "fr",
     expect_error(
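The replacement guard narrows the skip from all CI to Linux CI only, matching the commit note that the GPU image is not multi-lingual. A sketch of the same predicate wrapped as a reusable helper (the helper name is hypothetical; `testthat:::on_ci()` and `testthat:::system_os()` are the unexported helpers the diff itself calls):

```r
# hypothetical helper, not part of the package
skip_if_linux_ci <- function() {
  testthat::skip_if(
    testthat:::on_ci() && testthat:::system_os() == "linux",
    "GPU CI image ships English-only locales"
  )
}
```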
