Skip to content

Commit

Permalink
Restructured torch modules to support shapr installation without torch (
Browse files Browse the repository at this point in the history
#393)

* Updated the mask generators

* Fixed the Neural Network Modules

* Fixed Dataset Utility Functions

* Added check that progressr is installed, otherwise we proceed without a progress_bar

* Fixed vaeac and memory layer

* Missed default value in `mcar_mask_generator`.

* Manuals

* Typos in the vaeac vignettes

* Updated the documentation to clearly state that a vaeac model cannot be moved from the folder it was trained in if one want to continue to train it. This is a limitation that I should consider fixing. But I am unsure how often this will occur.

Also made sure that continue train works if the explanation object was trained by giving a path.

* styler + lintr

* Added self as global variable
  • Loading branch information
LHBO authored Apr 18, 2024
1 parent 5b15935 commit ddd32c7
Show file tree
Hide file tree
Showing 8 changed files with 1,270 additions and 1,174 deletions.
22 changes: 16 additions & 6 deletions R/approach_vaeac.R
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,11 @@ vaeac_train_model <- function(x_train,
best_vlb <- -Inf

# Create a `progressr::progressor()` to keep track of the overall training time of the vaeac approach
progressr_bar <- progressr::progressor(steps = epochs_initiation_phase * (n_vaeacs_initialize - 1) + epochs)
if (requireNamespace("progressr", quietly = TRUE)) {
progressr_bar <- progressr::progressor(steps = epochs_initiation_phase * (n_vaeacs_initialize - 1) + epochs)
} else {
progressr_bar <- NULL
}

# Iterate over the initializations.
initialization_idx <- 1
Expand Down Expand Up @@ -835,9 +839,10 @@ vaeac_train_model_continue <- function(explanation,
# Set seed for reproducibility
set.seed(seed)

# Extract the vaeac list and load the model at the last epoch
# Extract the vaeac list and load the model at the last epoch or the best (default 'best' when path is provided)
vaeac_model <- explanation$internal$parameters$vaeac
checkpoint <- torch::torch_load(vaeac_model$models$last)
vaeac_model_path <- if (!is.null(vaeac_model$models$last)) vaeac_model$models$last else vaeac_model$models$best
checkpoint <- torch::torch_load(vaeac_model_path)

# Get which device we are to continue to train the model
device <- ifelse(checkpoint$cuda, "cuda", "cpu")
Expand Down Expand Up @@ -939,7 +944,11 @@ vaeac_train_model_continue <- function(explanation,
state_list$epochs <- epochs

# Create a `progressr::progressor()` to keep track of the new training
progressr_bar <- progressr::progressor(steps = epochs_new)
if (requireNamespace("progressr", quietly = TRUE)) {
progressr_bar <- progressr::progressor(steps = epochs_new)
} else {
progressr_bar <- NULL
}

# Train the vaeac model for `epochs_new` number of epochs
vaeac_tmp <- vaeac_train_model_auxiliary(
Expand Down Expand Up @@ -1617,8 +1626,9 @@ vaeac_check_parameters <- function(x_train,
#' then a name will be generated based on [base::Sys.time()] to ensure a unique name. We use [base::make.names()] to
#' ensure a valid file name for all operating systems.
#' @param vaeac.folder_to_save_model String (default is [base::tempdir()]). String specifying a path to a folder where
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`.
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`. Furthermore, the model cannot be moved from its
#' original folder if we are to use the [shapr::vaeac_train_model_continue()] function to continue training the model.
#' @param vaeac.pretrained_vaeac_model List or String (default is `NULL`). 1) Either a list of class
#' `vaeac`, i.e., the list stored in `explanation$internal$parameters$vaeac` where `explanation` is the returned list
#' from an earlier call to the [shapr::explain()] function. 2) A string containing the path to where the `vaeac`
Expand Down
Loading

0 comments on commit ddd32c7

Please sign in to comment.