Adding the separate and surrogate regression method classes #379

Merged · 399 commits · Apr 17, 2024
Commits (399)
0e9d773
update docs
martinju Feb 15, 2024
c350892
update test files
martinju Feb 15, 2024
87f1a53
accept also vaeac tests
martinju Feb 15, 2024
7abd4d8
Started shapr vaeac vignette skip build. Need to change its filename …
LHBO Feb 15, 2024
0ef4b65
Test if it works now
LHBO Feb 15, 2024
1f120f7
to run properly on GHA which only have 4 cores
martinju Feb 16, 2024
dbf01e5
Merge branch 'Lars/VAEAC_SHAPR' of github.com:LHBO/shapr into Lars/VA…
martinju Feb 16, 2024
72db0d8
fix long names in vignette cache, remove formalArgs ++
martinju Feb 16, 2024
b1fdd58
move long running vignette script out of R/
martinju Feb 16, 2024
bb2a85a
.
martinju Feb 16, 2024
10fb862
actually run the chunks in the vignettes
martinju Feb 16, 2024
68fb88d
vignette figures + more name fixing
martinju Feb 16, 2024
25db7b3
gitignore vignette cache. Will do it manually locally anyway
martinju Feb 16, 2024
cd0df49
more vignette file updates
martinju Feb 16, 2024
5d28b8a
put the original vignette into the long-vignette format as well
martinju Feb 16, 2024
4a3b4aa
fix typo making tests on GHA fail
martinju Feb 16, 2024
7e5bec3
forgot to update rds files
martinju Feb 16, 2024
70a3048
Added main vignette to the .Rbuildignore file too
LHBO Feb 16, 2024
b7a85b3
Added vignettes figure and cache map to buildignore
LHBO Feb 16, 2024
4060144
Updated what to remove from shapr object when not saving vaeac model.…
LHBO Feb 16, 2024
93311fa
Updated questions before publishing on CRAN
LHBO Feb 16, 2024
b61ac8f
Updated snapshots for vaeac test functions
LHBO Feb 16, 2024
96e87a0
Set separate figure and cache paths for the vignettes
LHBO Feb 16, 2024
4ce7e98
Updated broken links in vaeac vignette
LHBO Feb 16, 2024
31baa9c
Updated the figure parameters
LHBO Feb 16, 2024
6be1ee0
Run the vignettes and deleted old files
LHBO Feb 16, 2024
e4e1b58
stylr
LHBO Feb 16, 2024
d8aecdf
Updated rbuildignore. Forgot to update after I split the vignettes in…
LHBO Feb 19, 2024
7f9e81c
Adding script to show that getbatch is much faster than getitem
LHBO Feb 19, 2024
d1622bf
Updated from getitem to getbatch and updated some comments
LHBO Feb 19, 2024
ce0e39d
Updated the roxygen and comments of the torch modules
LHBO Feb 19, 2024
3849779
Typo
LHBO Feb 19, 2024
c481d02
Reran building the vaeac vignette.
LHBO Feb 19, 2024
4fe7b66
Updates needed for GPU to run
Feb 20, 2024
c68e5c4
Changed from vaeac_continue_train_model to vaeac_train_model_continue…
Feb 20, 2024
fb9b204
Update the printouts to state which unit the model is using (CPU or GPU).
Feb 20, 2024
28567ea
More GPU fixes
Feb 20, 2024
94e579a
Fixed such that GPU VAEAC can continue training.
Feb 20, 2024
8ffc5f9
Updated the vignette
Feb 20, 2024
8ad5c4a
styler+lintr
Feb 20, 2024
e79e12d
rerun vignette with messages ++ (Vignettes must be recompiled with GPU!)
martinju Feb 21, 2024
9061e6c
exclude all vignettes from lintr
martinju Feb 21, 2024
f7cc967
remove purl call
martinju Feb 21, 2024
f429cb7
Fixed typo and too wide lines
LHBO Feb 21, 2024
da6fe6a
Removed unnecessary `keep_samp_for_vS`
LHBO Feb 22, 2024
3de36ac
Adding logical indicating if we are to use regression and regression …
LHBO Feb 22, 2024
3d83ff1
Split `batch_compute_vS()` based on regression or MC
LHBO Feb 22, 2024
0c7ac84
Added `batch_prepare_vS_regression()` which ensures that the grand co…
LHBO Feb 22, 2024
e373f6b
Added the regression_separate approach. Works without CV. Need to che…
LHBO Feb 22, 2024
e4fdad2
Added tidymodels to DESCRIPTION. NEED TO FIND OUT WHAT TO KEEP
LHBO Feb 23, 2024
1534ef8
ADDED TO regression stuff to ignore
LHBO Feb 23, 2024
15034bd
Added tidymodels to references
LHBO Feb 23, 2024
141a8da
Use general prepare_data function
LHBO Feb 23, 2024
3fbc300
Updated cross-validation
LHBO Feb 23, 2024
95888cc
Updated gitignore
LHBO Feb 23, 2024
0470815
Added manuals
LHBO Feb 23, 2024
2d00e67
Added vignette
LHBO Feb 23, 2024
086ce5b
Updated rebuild_long_running_vginette.R
LHBO Feb 23, 2024
cdea349
Added vignette figure folder
LHBO Feb 23, 2024
1d93d1d
Small vignette update
LHBO Feb 23, 2024
627614e
styler + lintr
LHBO Feb 23, 2024
e31bf77
styler + lintr
LHBO Feb 23, 2024
d91460a
Typo in package names in Description
LHBO Feb 26, 2024
bbf4a7d
Added new dependencies in DESCRIPTION
LHBO Feb 26, 2024
bcc3c2c
Fixed non-working approach check related to that regression approach …
LHBO Feb 26, 2024
3413738
Updated to support surrogate regression too
LHBO Feb 27, 2024
a0e497b
Missing comma in DESCRIPTION
LHBO Feb 27, 2024
ac9acc1
Tidied up reg sep
LHBO Feb 27, 2024
70fa894
Added first version of reg sur
LHBO Feb 27, 2024
7b9ac8e
Shortened the code
LHBO Feb 27, 2024
4cfa622
Renamed objects to separate sur and sep
LHBO Feb 27, 2024
794887b
Shortened code in explain
LHBO Feb 27, 2024
415450a
Added roxygen documentation
LHBO Feb 27, 2024
6befc84
Added check for regression in setup
LHBO Feb 28, 2024
4191c94
Fixed bug and made code shorter
LHBO Feb 28, 2024
8993845
Refactored surrogate
LHBO Feb 28, 2024
fd9ea1d
Started refactor separate
LHBO Feb 28, 2024
0e00829
Small visual differences
LHBO Feb 28, 2024
e89952a
Updated manuals
LHBO Feb 29, 2024
38a6a8e
Updated separate regression works now
LHBO Feb 29, 2024
69f6c29
Changed printing
LHBO Feb 29, 2024
b666bd2
Starting to MAYBE add NN_OLSEN
LHBO Feb 29, 2024
eeb6c26
Fixed bug where it did not work for grand comb in separate batch
LHBO Feb 29, 2024
5612fdc
Stopped using pipes %>% to make the code work in parallel using futur…
LHBO Feb 29, 2024
6e69e22
Created own function in explain that removes the unneeded stuff from t…
LHBO Feb 29, 2024
e38d4dc
Prettier printouts
LHBO Feb 29, 2024
faedbb5
Removed comments
LHBO Feb 29, 2024
b828ef7
Moved from right to left. I do not think I will ever be completely ha…
LHBO Feb 29, 2024
901e05c
Forgot to change function name
LHBO Mar 1, 2024
be24df5
Updated vignette
LHBO Mar 1, 2024
92458be
Added testing of ppr. DELETE THIS FILE IN THE END
LHBO Mar 4, 2024
963855b
Added PPR as a new method that gets loaded together with SHAPR. TEST …
LHBO Mar 5, 2024
6ba16ae
Started on potentially making tidymodels/workflow as a potential mode…
LHBO Mar 5, 2024
9c4d68a
Removed old file
LHBO Mar 5, 2024
16f24e1
Update such that PPR gets loaded together with shapr
LHBO Mar 5, 2024
3211178
Moved where to load ppr
LHBO Mar 5, 2024
73306b9
Find out if I should remove the ppr approach
LHBO Mar 5, 2024
5bbe272
Added example of how to include new model in vignette
LHBO Mar 5, 2024
e3fcdf5
Merge master into branch part one
LHBO Mar 6, 2024
08ec72e
Updated regression train for surrogate regression to improve the cros…
LHBO Mar 8, 2024
15fcf03
Changed to always include the grand coalition when training surrogate…
LHBO Mar 8, 2024
ac65a20
Updated the vignette
LHBO Mar 8, 2024
6be6eb4
Vignette works. Need some polishing though
LHBO Mar 8, 2024
e946fd8
Styler + lintr
LHBO Mar 8, 2024
781c025
Added option for `regression_tune_values` to be a function. Needed fo…
LHBO Mar 11, 2024
c1b2fd3
Fix bug in vaeac
LHBO Mar 11, 2024
8677c6a
Started looking at NN_Olsen
LHBO Mar 11, 2024
67e85d8
Added mixed data to the vignette
LHBO Mar 11, 2024
46b0910
Fixed bug related to separate/surrogate regression when using several…
LHBO Mar 12, 2024
6a91ce7
Added file that tests all the parameters for the regression approaches.
LHBO Mar 12, 2024
8b024fb
Added file with test functions for the output of the regression methods
LHBO Mar 12, 2024
ab6170c
Added more test and fixed prefix to `tune()` function
LHBO Mar 12, 2024
d9135f1
Added test for surrogate regression on `regresserion_surr_n_comb`
LHBO Mar 12, 2024
633ab92
Added vignette figures
LHBO Mar 12, 2024
9c54741
Added more examples to the vignette
LHBO Mar 12, 2024
97f2385
Include prefix of tune() and remove the pipe-function
LHBO Mar 12, 2024
6f5d3e1
Removed files
LHBO Mar 12, 2024
a3bc72e
Updated description
LHBO Mar 12, 2024
a50f659
Added regression to rebuild long running vignette
LHBO Mar 12, 2024
3d96fad
styler + lintr
LHBO Mar 12, 2024
c02e5bf
Added TOC to vignette
LHBO Mar 12, 2024
15a2367
Removed make_ppr_reg() from onLoad
LHBO Mar 12, 2024
8133885
Update vignette plots
LHBO Mar 12, 2024
62c64c8
Updated the vignette to the .orig style
LHBO Mar 12, 2024
8e3a727
Updated manuals 1
LHBO Mar 12, 2024
e1a9e15
Fixed "Found if() conditions comparing class() to string:".
LHBO Mar 12, 2024
06ae057
Changed namespace
LHBO Mar 12, 2024
0751063
added manuals 2
LHBO Mar 12, 2024
cc3b109
Renamed model to workflow and finalized the model functions. NEED…
LHBO Mar 13, 2024
e76a550
Updated the manuals
LHBO Mar 13, 2024
ba16a3d
Added workflows:::fit.workflow, but I am unsure if I should.
LHBO Mar 13, 2024
47402c4
Fixed warnings in check
LHBO Mar 13, 2024
fa34acd
Updated the vignette
LHBO Mar 13, 2024
84687e4
Updated model_workflow to work with recipes and fixed predict function
LHBO Mar 14, 2024
f4bdd0d
Added function where I experimented with the model workflow.
LHBO Mar 14, 2024
5b266ab
Updated vignettes
LHBO Mar 14, 2024
d0adcb2
Updated libraries
LHBO Mar 14, 2024
98cbf76
Fixed ::: check message
LHBO Mar 14, 2024
7a10b40
Updated Manuals
LHBO Mar 14, 2024
f53ce16
Fixed namespace check in the regression approaches
LHBO Mar 14, 2024
ad4312e
Fixed Undefined global functions or variables
LHBO Mar 14, 2024
0151f11
Added tests for regression
LHBO Mar 14, 2024
70cef48
Fixed bug with surrogate when all features were categorical
LHBO Mar 14, 2024
3e31a55
Fixed bug with regression_recipe_func
LHBO Mar 14, 2024
e0ff09d
Fixed vignette
LHBO Mar 14, 2024
c590a58
Updated the test snapshots for regression methods
LHBO Mar 14, 2024
827a4a5
Remove `regression` from output for the MC methods
LHBO Mar 14, 2024
1515247
Surrogate save the augmented training data
LHBO Mar 14, 2024
c3bfa23
Make it optional to store the augmented training data
LHBO Mar 14, 2024
e22df68
Removed x_augment from the snapfiles (took too much space)
LHBO Mar 14, 2024
6bc4de7
Shorten the test names
LHBO Mar 14, 2024
f05f5f8
Go back to the way it was before with approach error message
LHBO Mar 14, 2024
93fac47
Manuals
LHBO Mar 14, 2024
0c9c0e0
Merged master into branch
LHBO Mar 14, 2024
fa9642c
Styler
LHBO Mar 14, 2024
afc758e
Typo
LHBO Mar 14, 2024
c86f50b
Remove regression from explain_forecast too
LHBO Mar 14, 2024
fb2feba
Added check for forecast
LHBO Mar 14, 2024
20334db
Update the snaps with correct answers
LHBO Mar 14, 2024
3352928
Update explain forecast to remove objects from output
LHBO Mar 14, 2024
4488193
Update printout for approaches
LHBO Mar 14, 2024
bfec276
Manuals
LHBO Mar 14, 2024
92051eb
Delete man/MCAR_mask_generator.Rd
LHBO Mar 14, 2024
cb6554d
Delete man/Specified_masks_mask_generator.Rd
LHBO Mar 14, 2024
bba2d05
Delete man/Specified_prob_mask_generator.Rd
LHBO Mar 14, 2024
e596ac1
Renamed manuals
LHBO Mar 14, 2024
c6d4f99
Fixed manual file names now?
LHBO Mar 14, 2024
28a5bdc
Enabling all tests
LHBO Mar 14, 2024
3047b09
Update R-CMD-check.yaml
martinju Mar 15, 2024
2acbe61
Fixed such that we get the same surrogate model independent if the CV…
LHBO Mar 18, 2024
e766320
Ensure that we can parallelize both training and prediction step in su…
LHBO Mar 18, 2024
8ba57a8
Updated the vignette
LHBO Mar 18, 2024
6d504cd
Update 1 snap as it changed with the extra set.seed
LHBO Mar 18, 2024
4336f80
Update manuals
LHBO Mar 18, 2024
84c171e
Styler + lintr
LHBO Mar 18, 2024
16139d9
Fixed R-CMD to match how Martin did it
LHBO Mar 18, 2024
1b59fac
Roxygen
LHBO Mar 18, 2024
b13d747
Specify that the regression approaches are only applicable when calle…
LHBO Mar 18, 2024
ef82612
Added test related to regression approaches and python
LHBO Mar 18, 2024
e35cdd2
set bash shell on windows
martinju Mar 20, 2024
9ca5424
delete unwanted cache files
martinju Apr 2, 2024
05ac0a5
Changed from `regression_` to `regression.` in function parameters.
LHBO Apr 3, 2024
cd0a09d
Removed option to save x_augmented for regression_surrogate
LHBO Apr 3, 2024
fef0f6f
Changed so all regression functions start with `regression.`
LHBO Apr 3, 2024
3fd76f4
`regression_` to `regression.` inside functions too for consistency
LHBO Apr 3, 2024
3d1a1e5
Went back to old code design (several lines vs one line)
LHBO Apr 4, 2024
dc0ecff
Refactored `batch_compute_vS()`
LHBO Apr 4, 2024
96b22ad
Updated warning message
LHBO Apr 4, 2024
f678bb4
Added `@inheritDotParams setup_approach.regression_`
LHBO Apr 4, 2024
522707a
Added missing parameter in `batch_prepare_vS_MC`.
LHBO Apr 4, 2024
3d6ca2c
Fixed names
LHBO Apr 4, 2024
1664bca
Fixed `tune::select_best()` error with `metric` not being specified.
LHBO Apr 4, 2024
c5dd174
Add `metric = 'remse'` to `tune::show_best()` to not get a warning.
LHBO Apr 4, 2024
176261b
Changed main vignette such that `explanation` is not overwritten, and…
LHBO Apr 4, 2024
6732326
Added workflows example in the main vignette
LHBO Apr 8, 2024
2a7f0e1
Update the vignette based on Martin's feedback
LHBO Apr 9, 2024
8aaff55
Styler
LHBO Apr 9, 2024
6dcb264
Typo
LHBO Apr 9, 2024
bba4757
Typo
LHBO Apr 9, 2024
a04f388
Fixed the error?
LHBO Apr 9, 2024
556f85b
Fixed the error now?
LHBO Apr 9, 2024
5cb1f4f
Fixed bug in compute_vS.R where `dt` was not returned.
LHBO Apr 10, 2024
cebd3d7
Updated setup test snaps (changed text since we now use `regression.`…
LHBO Apr 10, 2024
c1182ed
Update snaps. Changes due to changing names.
LHBO Apr 10, 2024
c9e9804
Accept changes due to changed warning message.
LHBO Apr 10, 2024
7b804bc
Updated such that the `regression.model`,
LHBO Apr 10, 2024
41f0a90
Add string section in the Vignette
LHBO Apr 10, 2024
4d87fbf
styler
LHBO Apr 10, 2024
1fffa33
Manual updates
LHBO Apr 10, 2024
5f553ad
Rerun reg vignette
LHBO Apr 10, 2024
f4df345
Merged master into branch
LHBO Apr 11, 2024
e95d939
Update when `regression.get_y_hat` is called.
LHBO Apr 11, 2024
1d157ed
Added to documentation in model.R that we can also explain workflows/…
LHBO Apr 11, 2024
e8842d2
Removed check that said that regression did not work from python
LHBO Apr 11, 2024
50c3eed
changed `batch_prepare_vS` to `batch_prepare_vS_MC_auxiliary` to make…
LHBO Apr 12, 2024
1e9d85b
Fixed the last places that were named `regression_` instead of `regre…
LHBO Apr 12, 2024
b7228af
Added some comments
LHBO Apr 12, 2024
2c0fe7d
Changed several things in python's explain:
LHBO Apr 12, 2024
b22da04
update test snaps as I changed one parameter name
LHBO Apr 12, 2024
d59ebe5
Missed the `MSEv_g` object
LHBO Apr 12, 2024
53686b0
Small typo
LHBO Apr 12, 2024
4ff04aa
Fixed such that the `regression.vfold_cv_para` dict is converted to R…
LHBO Apr 12, 2024
c3bf277
Typo
LHBO Apr 12, 2024
d163e57
Added info about regression-based approaches in the readme in python
LHBO Apr 12, 2024
0ae63af
Added example file demonstrating how to use the regression paradigm f…
LHBO Apr 12, 2024
8bb9926
Pycharm added requirements file
LHBO Apr 12, 2024
c98bb08
Added MSEv output of the methods
LHBO Apr 12, 2024
f7e4ac6
Updated the manuals
LHBO Apr 12, 2024
b7f7fe9
Added section in main vignette about regression paradigm. Not compile…
LHBO Apr 12, 2024
5803e2b
Updated reference to the comparative study paper that is now published
LHBO Apr 14, 2024
4acc193
Martin: added the title of the regression vignette.
LHBO Apr 15, 2024
5cb94da
Martin: reverted to 2 indents where it was like that before.
LHBO Apr 15, 2024
24e76bf
Martin: better comments in `explain.R`.
LHBO Apr 15, 2024
d596189
Martin: deleted python/requirements.txt as both of us were unsure if …
LHBO Apr 15, 2024
048e1f7
bugfix and rerun main vignette
martinju Apr 15, 2024
1c56989
fix and rerun regression vignette
martinju Apr 15, 2024
1db3e5e
Added `explain_tripledot_docs` function to make documentation from Py…
LHBO Apr 16, 2024
3088ba6
Changed comment in py-explain
LHBO Apr 16, 2024
cff264d
Added sep and sur example in `explain` and added updated manuals.
LHBO Apr 16, 2024
de3e9b9
Updated kwargs docu in python.
LHBO Apr 16, 2024
b9b7f0e
Update manuals
LHBO Apr 16, 2024
5ba6bb9
Martin: merged his updates into mine
LHBO Apr 16, 2024
c9fd319
Typo vignette
LHBO Apr 16, 2024
cd3e69a
Fixed conflicts in reg vignette
LHBO Apr 16, 2024
8b54165
Martin: mistake in which dataset to use. Gave NA NA
LHBO Apr 16, 2024
ca2516c
Built the regression vignette (was broken after Martin commit)
LHBO Apr 16, 2024
fe1b756
Updated the main vignette too
LHBO Apr 16, 2024
79b1ee1
Added title of the regression vignette in the main vignette
LHBO Apr 16, 2024
fe9c4a5
Made shaprpy-explain easier for Martin to look at. `timedelta` is not…
LHBO Apr 16, 2024
4e751dd
Same as previous
LHBO Apr 16, 2024
3 changes: 3 additions & 0 deletions .Rbuildignore
@@ -25,7 +25,10 @@ inst/compare_lundberg\.xgb\.obj
^rebuild-long-running-vignette\.R$
^vignettes/understanding_shapr_vaeac\.Rmd\.orig$
^vignettes/understanding_shapr\.Rmd\.orig$
^vignettes/understanding_shapr_regression\.Rmd\.orig$
^vignettes/figure_main/*$
^vignettes/cache_main/*$
^vignettes/figure_vaeac/*$
^vignettes/cache_vaeac/*$
^vignettes/figure_regression/*$
^vignettes/cache_regression/*$
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
@@ -35,6 +35,7 @@ jobs:
fail-fast: false
matrix:
config:
# Temporarily disable all but ubuntu release to reduce compute while debugging
- {os: macOS-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-20.04, r: 'devel', http-user-agent: 'release'}
4 changes: 4 additions & 0 deletions .gitignore
@@ -24,3 +24,7 @@ docs
/Meta/
.idea
.DS_Store

vignettes/cache_main/
vignettes/cache_vaeac/
vignettes/cache_regression/
11 changes: 10 additions & 1 deletion DESCRIPTION
@@ -57,7 +57,16 @@ Suggests:
torch,
GGally,
progress,
coro
coro,
parsnip,
recipes,
workflows,
tune,
dials,
yardstick,
hardhat,
rsample,
rlang
LinkingTo:
RcppArmadillo,
Rcpp
8 changes: 8 additions & 0 deletions NAMESPACE
@@ -7,6 +7,7 @@ S3method(get_model_specs,gam)
S3method(get_model_specs,glm)
S3method(get_model_specs,lm)
S3method(get_model_specs,ranger)
S3method(get_model_specs,workflow)
S3method(get_model_specs,xgb.Booster)
S3method(model_checker,Arima)
S3method(model_checker,ar)
@@ -16,6 +17,7 @@ S3method(model_checker,gam)
S3method(model_checker,glm)
S3method(model_checker,lm)
S3method(model_checker,ranger)
S3method(model_checker,workflow)
S3method(model_checker,xgb.Booster)
S3method(plot,shapr)
S3method(predict_model,Arima)
@@ -26,13 +28,16 @@ S3method(predict_model,gam)
S3method(predict_model,glm)
S3method(predict_model,lm)
S3method(predict_model,ranger)
S3method(predict_model,workflow)
S3method(predict_model,xgb.Booster)
S3method(prepare_data,categorical)
S3method(prepare_data,copula)
S3method(prepare_data,ctree)
S3method(prepare_data,empirical)
S3method(prepare_data,gaussian)
S3method(prepare_data,independence)
S3method(prepare_data,regression_separate)
S3method(prepare_data,regression_surrogate)
S3method(prepare_data,timeseries)
S3method(prepare_data,vaeac)
S3method(print,shapr)
@@ -43,6 +48,8 @@ S3method(setup_approach,ctree)
S3method(setup_approach,empirical)
S3method(setup_approach,gaussian)
S3method(setup_approach,independence)
S3method(setup_approach,regression_separate)
S3method(setup_approach,regression_surrogate)
S3method(setup_approach,timeseries)
S3method(setup_approach,vaeac)
export(aicc_full_single_cpp)
@@ -68,6 +75,7 @@ export(predict_model)
export(prepare_data)
export(prepare_data_copula_cpp)
export(prepare_data_gaussian_cpp)
export(regression.train_model)
export(rss_cpp)
export(setup)
export(setup_approach)
516 changes: 516 additions & 0 deletions R/approach_regression_separate.R

Large diffs are not rendered by default.

245 changes: 245 additions & 0 deletions R/approach_regression_surrogate.R
@@ -0,0 +1,245 @@
# Shapr functions ======================================================================================================
#' @rdname setup_approach
#'
#' @inheritParams default_doc_explain
#' @inheritParams setup_approach.regression_separate
#' @param regression.surrogate_n_comb Integer (default is `internal$parameters$used_n_combinations` - 2) specifying
#' the number of unique combinations/coalitions to apply to each training observation. The maximum allowed value is
#' `internal$parameters$used_n_combinations` - 2. By default, we use all coalitions, but this can require a lot of
#' memory in higher dimensions. Note that by "all", we mean all coalitions chosen by `shapr` to be used. This is all
#' \eqn{2^{n_{\text{features}}}} coalitions (minus the empty and grand coalitions) when `shapr` is in exact mode. If
#' the user sets a lower value than `internal$parameters$used_n_combinations`, then we sample this number of unique
#' coalitions separately for each training observation. That is, on average, all coalitions should be equally trained.
#'
#' @export
#' @author Lars Henry Berge Olsen
setup_approach.regression_surrogate <- function(internal,
                                                regression.model = parsnip::linear_reg(),
                                                regression.tune_values = NULL,
                                                regression.vfold_cv_para = NULL,
                                                regression.recipe_func = NULL,
                                                regression.surrogate_n_comb =
                                                  internal$parameters$used_n_combinations - 2,
                                                ...) {
  # Check that required libraries are installed
  regression.check_namespaces()

  # Small printout to the user
  if (internal$parameters$verbose == 2) message("Starting 'setup_approach.regression_surrogate'.")

  # Add the default parameter values for the non-user-specified parameters for the surrogate regression approach
  defaults <- mget(c(
    "regression.model", "regression.tune_values", "regression.vfold_cv_para",
    "regression.recipe_func", "regression.surrogate_n_comb"
  ))
  internal <- insert_defaults(internal, defaults)

  # Check the parameters to the regression approach
  internal <- regression.check_parameters(internal)

  # Augment the training data
  x_train_augmented <- regression.surrogate_aug_data(
    internal = internal, x = internal$data$x_train, y_hat = internal$data$x_train_y_hat, augment_include_grand = TRUE
  )

  # Fit the surrogate regression model and store it in the internal list
  if (internal$parameters$verbose == 2) message("Start training the surrogate model.")
  internal$objects$regression.surrogate_model <- regression.train_model(
    x = x_train_augmented,
    seed = internal$parameters$seed,
    verbose = internal$parameters$verbose,
    regression.model = internal$parameters$regression.model,
    regression.tune = internal$parameters$regression.tune,
    regression.tune_values = internal$parameters$regression.tune_values,
    regression.vfold_cv_para = internal$parameters$regression.vfold_cv_para,
    regression.recipe_func = internal$parameters$regression.recipe_func,
    regression.surrogate_n_comb = regression.surrogate_n_comb + 1 # Add 1 as augment_include_grand = TRUE above
  )

  # Small printout to the user
  if (internal$parameters$verbose == 2) message("Done with 'setup_approach.regression_surrogate'.")

  return(internal) # Return the updated internal list
}

#' @inheritParams default_doc
#' @rdname prepare_data
#' @export
#' @author Lars Henry Berge Olsen
prepare_data.regression_surrogate <- function(internal, index_features = NULL, ...) {
  # Load `workflows`, needed when parallelized as we call predict with a workflow object. Checked installed above.
  requireNamespace("workflows", quietly = TRUE)

  # Small printout to the user about which batch is currently being worked on
  if (internal$parameters$verbose == 2) regression.prep_message_batch(internal, index_features)

  # Augment the explicand data
  x_explain_aug <- regression.surrogate_aug_data(internal, x = internal$data$x_explain, index_features = index_features)

  # Compute the predicted response for the explicands, i.e., v(S, x_i) for all explicands x_i and S in index_features
  pred_explicand <- predict(internal$objects$regression.surrogate_model, new_data = x_explain_aug)$.pred

  # Insert the predicted contribution function values into a data table with the correct structure
  dt_res <- data.table(as.integer(index_features), matrix(pred_explicand, nrow = length(index_features)))
  data.table::setnames(dt_res, c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain))))
  data.table::setkey(dt_res, id_combination) # Set id_combination to be the key

  return(dt_res)
}
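The reshaping at the end of `prepare_data.regression_surrogate` relies on R's column-major `matrix()` fill: the predictions arrive grouped per explicand (each explicand's coalition block is contiguous), so each explicand ends up as one column, with one row per coalition. A minimal Python sketch of the same bookkeeping (function and variable names are illustrative, not part of the package):

```python
def predictions_to_vS_table(pred, index_features, n_explain):
    """Reshape a flat prediction vector into a v(S) table.

    `pred` is ordered explicand-by-explicand: the first len(index_features)
    entries belong to explicand 0, the next block to explicand 1, and so on.
    The result maps each coalition id to the list of v(S, x_i) values, one
    per explicand -- the same layout as shapr's `dt_res` data.table.
    """
    n_coal = len(index_features)
    assert len(pred) == n_coal * n_explain
    table = {}
    for j, comb_id in enumerate(index_features):
        # Row for coalition `comb_id`: pick entry j from each explicand's block
        table[comb_id] = [pred[i * n_coal + j] for i in range(n_explain)]
    return table

# Coalitions 2 and 3, two explicands: explicand 0 yields (10, 11), explicand 1 yields (20, 21)
table = predictions_to_vS_table([10.0, 11.0, 20.0, 21.0], index_features=[2, 3], n_explain=2)
```

This mirrors why `nrow = length(index_features)` is the right shape in the R code: with column-major filling, "one column per explicand" falls out for free.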

# Augment function =====================================================================================================
#' Augment the training data and the explicands
#'
#' @inheritParams default_doc
#' @inheritParams regression.train_model
#' @param y_hat Vector of numerics (optional) containing the predicted responses for the observations in `x`.
#' @param index_features Array of integers (optional) containing which coalitions to consider. Must be provided if
#' `x` is the explicands.
#' @param augment_add_id_comb Logical (default is `FALSE`). If `TRUE`, an additional column is added indicating
#' which coalition was applied.
#' @param augment_include_grand Logical (default is `FALSE`). If `TRUE`, then the grand coalition is included.
#' If `index_features` is provided, then `augment_include_grand` has no effect. Note that if we sample the
#' combinations, then the grand coalition is equally likely to be sampled as the other coalitions (or weighted if
#' `augment_comb_prob` is provided).
#' @param augment_masks_as_factor Logical (default is `FALSE`). If `TRUE`, then the binary masks are converted
#' to factors. If `FALSE`, then the binary masks are numerics.
#' @param augment_comb_prob Array of numerics (default is `NULL`). The length of the array must match the number of
#' combinations being considered, where each entry specifies the probability of sampling the corresponding coalition.
#' This is useful if we want to generate more training data for some specific coalitions. One possible choice would be
#' `augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_combinations] else NULL`.
#' @param augment_weights String (optional). Specifies which type of weights to add to the observations.
#' If `NULL` (default), then no weights are added. If `"Shapley"`, then the Shapley weights for the different
#' combinations are added to the corresponding observations where the coalitions were applied. If `"uniform"`, then
#' all observations get an equal weight of one.
#'
#' @return A data.table containing the augmented data.
#' @author Lars Henry Berge Olsen
#' @keywords internal
regression.surrogate_aug_data <- function(internal,
                                          x,
                                          y_hat = NULL,
                                          index_features = NULL,
                                          augment_masks_as_factor = FALSE,
                                          augment_include_grand = FALSE,
                                          augment_add_id_comb = FALSE,
                                          augment_comb_prob = NULL,
                                          augment_weights = NULL) {
  # Get some of the parameters
  S <- internal$objects$S
  actual_n_combinations <- internal$parameters$used_n_combinations - 2 # Remove empty and grand coalitions
  regression.surrogate_n_comb <- internal$parameters$regression.surrogate_n_comb
  if (!is.null(index_features)) regression.surrogate_n_comb <- length(index_features) # Applicable from prep_data()
  if (augment_include_grand) {
    actual_n_combinations <- actual_n_combinations + 1 # Add 1 to include the grand coalition
    regression.surrogate_n_comb <- regression.surrogate_n_comb + 1
  }
  if (regression.surrogate_n_comb > actual_n_combinations) regression.surrogate_n_comb <- actual_n_combinations

  # Small checks
  if (!is.null(augment_weights)) augment_weights <- match.arg(augment_weights, c("Shapley", "uniform"))

  if (!is.null(augment_comb_prob) && length(augment_comb_prob) != actual_n_combinations) {
    stop(paste("`augment_comb_prob` must be of length", actual_n_combinations, "."))
  }

  if (!is.null(augment_weights) && augment_include_grand && augment_weights == "Shapley") {
    stop(paste(
      "`augment_include_grand = TRUE` and `augment_weights = 'Shapley'` cannot occur together",
      "because this entails too large a weight for the grand coalition."
    ))
  }

  # Get the number of observations (either the same as n_train or n_explain)
  n_obs <- nrow(x)

  # Get the names of the categorical/factor features and the continuous/non-categorical/numeric features
  feature_classes <- internal$objects$feature_specs$classes
  feature_cat <- names(feature_classes)[feature_classes == "factor"]
  feature_cont <- names(feature_classes)[feature_classes != "factor"]

  # Get the indices of the categorical and continuous features
  feature_cat_idx <- which(names(feature_classes) %in% feature_cat)
  feature_cont_idx <- which(names(feature_classes) %in% feature_cont)

  # Check whether we are augmenting the training data or the explicands
  if (is.null(index_features)) {
    # Training: get matrix (regression.surrogate_n_comb x n_obs) containing the indices of the active coalitions
    if (regression.surrogate_n_comb >= actual_n_combinations) { # Start from two to exclude the empty set
      comb_active_idx <- matrix(rep(seq(2, actual_n_combinations + 1), times = n_obs), ncol = n_obs)
    } else {
      comb_active_idx <- sapply(seq(n_obs), function(x) { # Add 1 as we want to exclude the empty set
        sample.int(n = actual_n_combinations, size = regression.surrogate_n_comb, prob = augment_comb_prob) + 1
      })
    }
  } else {
    # Explicands: get matrix of dimension #index_features x n_obs containing the indices of the active coalitions
    comb_active_idx <- matrix(rep(index_features, times = n_obs), ncol = n_obs)
  }

  # Extract the active coalitions for each observation. The number of rows is n_obs * regression.surrogate_n_comb,
  # where the first regression.surrogate_n_comb rows belong to the first observation and so on. Set the column names.
  id_comb <- as.vector(comb_active_idx)
  comb_active <- S[id_comb, , drop = FALSE]
  colnames(comb_active) <- names(feature_classes)

  # Repeat the feature values as many times as there are active coalitions
  x_augmented <- x[rep(seq_len(n_obs), each = regression.surrogate_n_comb), ]

  # Mask the categorical features. Add a new level called "level_masked" when the value is masked.
  x_augmented[, (feature_cat) := lapply(seq_along(.SD), function(col) {
    levels(.SD[[col]]) <- c(levels(.SD[[col]]), "level_masked")
    .SD[[col]][comb_active[, feature_cat_idx[col]] == 0] <- "level_masked"
    return(.SD[[col]])
  }), .SDcols = feature_cat]

  # Mask the continuous/non-categorical features
  x_augmented[, (feature_cont) :=
    lapply(seq_along(.SD), function(col) .SD[[col]] * comb_active[, feature_cont_idx[col]]),
  .SDcols = feature_cont
  ]

  # Add new columns indicating when the continuous features are masked
  if (length(feature_cont) > 0) {
    masked_columns <- paste0("mask_", feature_cont)
    x_augmented <- cbind(x_augmented, setNames(data.table(1 * (comb_active[, feature_cont_idx] == 0)), masked_columns))
  }

  # Convert the binary masks to factors if the user has specified so
  if (augment_masks_as_factor) x_augmented[, (masked_columns) := lapply(.SD, as.factor), .SDcols = masked_columns]

  # Add either uniform weights or Shapley kernel weights
  if (!is.null(augment_weights)) {
    x_augmented[, "weight" := if (augment_weights == "Shapley") internal$objects$X$shapley_weight[id_comb] else 1]
  }

  # Add the id_comb as a factor
  if (augment_add_id_comb) x_augmented[, "id_comb" := factor(id_comb)]

  # Add repeated responses if provided
  if (!is.null(y_hat)) x_augmented[, "y_hat" := rep(y_hat, each = regression.surrogate_n_comb)]

  # Return the augmented data
  return(x_augmented)
}
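The core transformation in `regression.surrogate_aug_data` — repeat each observation once per coalition, zero out the continuous features outside the coalition, and append mask-indicator columns — can be sketched in a few lines of plain Python. This is an illustrative simplification only (it skips factor levels, Shapley weights, and coalition sampling, and all names are hypothetical):

```python
def augment_with_coalitions(x, S):
    """Build the surrogate training/prediction design matrix.

    x: list of observations, each a list of continuous feature values.
    S: list of coalition masks (1 = feature known, 0 = feature masked).
    Returns one row per (observation, coalition) pair: the masked feature
    values followed by per-feature mask indicators (1 where masked).
    """
    rows = []
    for obs in x:                # the first len(S) rows belong to observation 0, etc.
        for mask in S:
            masked = [v * m for v, m in zip(obs, mask)]      # zero out features not in S
            indicators = [1 - m for m in mask]               # 1 where the feature is masked
            rows.append(masked + indicators)
    return rows

# Two observations, two coalitions: {feature 1} and {feature 2}
x = [[1.0, 2.0], [3.0, 4.0]]
S = [[1, 0], [0, 1]]
aug = augment_with_coalitions(x, S)   # 4 rows, each of length 4
```

A single regression model fit on such augmented rows can then predict v(S, x) for any coalition, which is exactly what makes the surrogate approach need only one model rather than one per coalition.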


# Check function =======================================================================================================
#' Check the `regression.surrogate_n_comb` parameter
#'
#' Check that `regression.surrogate_n_comb` is either NULL or a valid integer.
#'
#' @inheritParams setup_approach.regression_surrogate
#' @param used_n_combinations Integer. The number of used combinations (including the empty and grand coalitions).
#'
#' @author Lars Henry Berge Olsen
#' @keywords internal
regression.check_sur_n_comb <- function(regression.surrogate_n_comb, used_n_combinations) {
  if (!is.null(regression.surrogate_n_comb)) {
    if (regression.surrogate_n_comb < 1 || used_n_combinations - 2 < regression.surrogate_n_comb) {
      stop(paste0(
        "`regression.surrogate_n_comb` (", regression.surrogate_n_comb, ") must be a positive integer less than or ",
        "equal to `used_n_combinations` minus two (", used_n_combinations - 2, ")."
      ))
    }
  }
}
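The valid range enforced here is [1, `used_n_combinations` - 2], since the empty and grand coalitions are excluded. A hedged Python restatement of the same bounds logic (hypothetical function name, not the package's API):

```python
def check_surrogate_n_comb(n_comb, used_n_combinations):
    """Mirror of the bounds check in shapr's regression.check_sur_n_comb:
    a user-supplied value must lie in [1, used_n_combinations - 2]."""
    if n_comb is None:
        return  # None means "use the default", which is always valid
    if n_comb < 1 or n_comb > used_n_combinations - 2:
        raise ValueError(
            f"`n_comb` ({n_comb}) must be a positive integer less than or "
            f"equal to used_n_combinations minus two ({used_n_combinations - 2})."
        )

check_surrogate_n_comb(6, 8)      # valid: 1 <= 6 <= 8 - 2
check_surrogate_n_comb(None, 8)   # also valid
```

Failing early with an explicit range in the message, as the R code does, saves the user from a cryptic error deep inside model training.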
2 changes: 1 addition & 1 deletion R/approach_vaeac.R
@@ -1355,7 +1355,7 @@ vaeac_check_mask_gen <- function(mask_gen_coalitions, mask_gen_coalitions_prob,
}
}

#' Function the checks the verbose parameter
#' Function that checks the verbose parameter
#'
#' @inheritParams vaeac_train_model
#'