Can't fit final mlp model with multiple hidden layers using tidymodels #963
Hello @naokiohno 👋 Thanks for filing this issue. I rewrote the reprex for speed and to use {modeldata}.

```r
# Load libraries ----------------------------------------------------------
library(tidyverse)
library(reticulate)
#> Warning: package 'reticulate' was built under R version 4.4.1
library(tidymodels)

# Load data ---------------------------------------------------------------
grant_train <- modeldata::grants_other %>%
  slice(1:1000) %>%
  replace(is.na(.), 0) %>%
  select(class, where(is.numeric))
grant_test <- modeldata::grants_test %>%
  slice(1:1000) %>%
  replace(is.na(.), 0) %>%
  select(class, where(is.numeric))

# Build neural net --------------------------------------------------------
# Create recipe
nnet_rec <- recipe(class ~ ., data = grant_train) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Create hyperparameter tuning values
grid_hidden_units <- tribble(
  ~hidden_units,
  c(4, 2),
  c(6, 4),
  c(8, 4),
  c(10, 5),
  c(7, 4)
)
grid_penalty <- tibble(penalty = c(0.01, 0.02))
grid <- grid_hidden_units %>%
  crossing(grid_penalty)

# Create model
nnet_mod <- mlp(
  hidden_units = tune(),
  penalty = tune()
) %>%
  set_engine("brulee") %>%
  set_mode("classification")

folds <- vfold_cv(grant_train, v = 2, repeats = 1)

nnet_wflow <-
  workflow() %>%
  add_recipe(nnet_rec) %>%
  add_model(nnet_mod)

nnet_tune_mod <- tune_grid(nnet_wflow, resamples = folds, grid = grid)
#> → A | warning: Loss is NaN at epoch 3. Training is stopped.
#> There were issues with some computations A: x1
#> There were issues with some computations A: x1
#>

nnet_best <- nnet_tune_mod %>%
  select_best(metric = "roc_auc")
nnet_wflow <- finalize_workflow(nnet_wflow, nnet_best)
nnet_wflow
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: mlp()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 2 Recipe Steps
#>
#> • step_zv()
#> • step_normalize()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Single Layer Neural Network Model Specification (classification)
#>
#> Main Arguments:
#> hidden_units = list(c(8, 4))
#> penalty = 0.01
#>
#> Computational engine: brulee

final_fit <- fit(nnet_wflow, grant_train)
#> Error in `check_integer()`:
#> ! brulee_mlp() expected 'hidden_units' to be integer.
```

What appears to be happening is that …
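The printed specification shows `hidden_units = list(c(8, 4))`, i.e. the tuned value still arrives wrapped in a list. One plausible reading is that this comes from the grid itself: `tribble()` stores one numeric vector per candidate in a list-column, and selecting a row keeps each cell wrapped. The base-R sketch below imitates that grid shape with a plain `data.frame` and `I()`; it does not call tune or brulee, and picking row 3 as the "winner" is only for illustration.

```r
# Base-R imitation of the reprex's tuning grid: a list-column of
# per-layer unit counts crossed with two penalty values.
layers <- list(c(4, 2), c(6, 4), c(8, 4), c(10, 5), c(7, 4))
penalties <- c(0.01, 0.02)

grid <- data.frame(
  hidden_units = I(rep(layers, times = length(penalties))),
  penalty = rep(penalties, each = length(layers))
)
nrow(grid)  # [1] 10  (5 layer settings x 2 penalties)

# Pretend row 3 won the tuning (illustrative choice, not a real result).
best <- grid[3, ]

# Selecting the row keeps the cell wrapped in a list ...
is.list(best$hidden_units)           # [1] TRUE
deparse(unclass(best$hidden_units))  # [1] "list(c(8, 4))"
# ... while [[1]] recovers the plain numeric vector.
best$hidden_units[[1]]               # [1] 8 4
best$penalty                         # [1] 0.01
```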
In the meantime, I managed to work around this issue by specifying the model hyperparameters manually and training the model with `brulee_mlp()` directly, outside of a workflow. This requires a separate recipe containing only the preprocessing steps. However, this solution is not nearly as elegant as fitting the finalized workflow on the training dataset.

```r
library(brulee)

# Create recipe of only the preprocessing steps
nnet_rec_preproc <- recipe(class ~ ., data = grant_train) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Specify and fit the best model manually
final_mod <- brulee_mlp(
  nnet_rec_preproc,
  hidden_units = c(8, 4),
  penalty = 0.01,
  data = grant_train
)
```
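To avoid retyping the winning values in the manual workaround, they could be pulled from the one-row result of `select_best()`. Below is a sketch of a hypothetical helper, `unwrap_params()` (not part of tune, workflows, or brulee), that turns such a row into a named list, unwraps list-column cells, and drops bookkeeping columns whose names start with a dot (such as `.config`). The example row is made up to mirror the shape of `nnet_best`, using the values reported in this thread.

```r
# Hypothetical helper (not part of tune, workflows, or brulee).
unwrap_params <- function(params) {
  stopifnot(is.data.frame(params), nrow(params) == 1)
  # Drop bookkeeping columns such as `.config`.
  keep <- !startsWith(names(params), ".")
  params <- params[, keep, drop = FALSE]
  # Unwrap list-column cells so c(8, 4) is returned instead of list(c(8, 4)).
  lapply(as.list(params), function(col) if (is.list(col)) col[[1]] else col)
}

# Made-up row mirroring the shape of `nnet_best`; ".config" is a placeholder.
best <- data.frame(
  hidden_units = I(list(c(8, 4))),
  penalty = 0.01,
  .config = "placeholder"
)

args <- unwrap_params(best)
str(args)

# Untested: with brulee (and torch) installed, the list could be spliced into
# the manual call from the workaround, e.g.
# final_mod <- do.call(
#   brulee::brulee_mlp,
#   c(list(nnet_rec_preproc, data = grant_train), args)
# )
```

The helper deliberately rejects multi-row input, since a full grid has no single set of values to pass on.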
The problem
I'm having trouble fitting the best cross-validated mlp model on my entire training set when the model has multiple hidden layers. I'm using the tidymodels framework. This didn't seem to be a problem during cross-validation, but when training the final model I get the following error:
```
Error in `check_integer()`:
! brulee_mlp() expected 'hidden_units' to be integer.
Run `rlang::last_trace()` to see where the error occurred.
```

Indeed, `hidden_units` is not an integer but a vector.
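Note that the manual workaround passes `hidden_units = c(8, 4)` and succeeds, so a numeric vector with one entry per hidden layer appears to be accepted; what differs in the finalized workflow is that the printed value is `list(c(8, 4))`. The sketch below uses a hypothetical check, `is_whole_number_vector()`, written only for illustration (it is not brulee's internal `check_integer()`), to show how a test along those lines accepts the bare vector but rejects the list wrapping it.

```r
# Hypothetical stand-in for an "integer-like" argument check; NOT brulee's
# internal check_integer(), just an illustration of the distinction.
is_whole_number_vector <- function(x) {
  is.numeric(x) && length(x) > 0 && !anyNA(x) && all(x == round(x))
}

is_whole_number_vector(c(8, 4))        # TRUE: one unit count per hidden layer
is_whole_number_vector(list(c(8, 4)))  # FALSE: a list wrapping that vector
is_whole_number_vector(c(8.5, 4))      # FALSE: not whole numbers
```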
Example