Skip to content

Commit

Permalink
... (#1118)
Browse files Browse the repository at this point in the history
  • Loading branch information
berndbischl authored Aug 24, 2024
1 parent 6b73434 commit da4d632
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 48 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* feat: Added new measure `mu_auc`.
* feat: Add option to calculate the mean of the true values on the train set in `msr("regr.rsq")`.
* feat: Default fallback learner is set when encapsulation is activated.
* feat: Learners classif.debug and regr.debug have new methods `$importance()` and `$selected_features()` for testing, also in downstream packages

# mlr3 0.20.2

Expand Down
32 changes: 30 additions & 2 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
},

#' @description
#' Returns 0 for each feature seen in training.
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
},

#' @description
#' Always returns character(0).
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
character(0)
}
),
active = list(
Expand Down Expand Up @@ -169,8 +190,15 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
stopf("Early stopping is only possible when a validation task is present.")
}

model = list(response = as.character(sample(task$truth(), 1L)), pid = Sys.getpid(), id = UUIDgenerate(),
random_number = sample(100000, 1), iter = if (isTRUE(pv$early_stopping)) sample(pv$iter %??% 1L, 1L) else pv$iter %??% 1L
model = list(
response = as.character(sample(task$truth(), 1L)),
pid = Sys.getpid(),
id = UUIDgenerate(),
random_number = sample(100000, 1),
iter = if (isTRUE(pv$early_stopping))
sample(pv$iter %??% 1L, 1L)
else
pv$iter %??% 1L
)

if (!is.null(valid_truth)) {
Expand Down
21 changes: 21 additions & 0 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
man = "mlr3::mlr_learners_regr.debug",
label = "Debug Learner for Regression"
)
},

#' @description
#' Returns 0 for each feature seen in training.
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
},

#' @description
#' Always returns character(0).
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
character(0)
}
),
private = list(
Expand Down
12 changes: 1 addition & 11 deletions man/Measure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 1 addition & 11 deletions man/MeasureClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 1 addition & 11 deletions man/MeasureRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 1 addition & 11 deletions man/MeasureSimilarity.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions man/mlr_learners_classif.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions man/mlr_learners_regr.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/mlr_measures_regr.pinball.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/test_mlr_learners_classif_debug.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ test_that("marshaling", {
p2 = l$marshal()$unmarshal()$predict(task)
expect_equal(p1, p2)
})

test_that("importance and selected features", {
l = lrn("classif.debug")
task = tsk("iris")
l$train(task)
expect_equal(l$selected_features(), character(0))
expect_equal(l$importance(), set_names(rep(0, task$n_features), task$feature_names))
})


25 changes: 25 additions & 0 deletions tests/testthat/test_mlr_learners_regr_debug.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# this test / files was missing, only classif.debug was unit-tested
# I added at least a few basic tests when i added methods "importance" and "selected_features"

test_that("Simple training/predict", {
task = tsk("mtcars")
learner = lrn("regr.debug")
expect_learner(learner, task)

prediction = learner$train(task)$predict(task)
expect_class(learner$model, "regr.debug_model")
expect_numeric(learner$model$response, len = 1L, any.missing = FALSE)
expect_numeric(prediction$response, any.missing = FALSE)
})


test_that("importance and selected features", {
l = lrn("regr.debug")
task = tsk("mtcars")
l$train(task)
expect_equal(l$selected_features(), character(0))
expect_equal(l$importance(), set_names(rep(0, task$n_features), task$feature_names))
})



0 comments on commit da4d632

Please sign in to comment.