Skip to content

Commit

Permalink
feat: throw warning when prediction and measure type do not match (#1188
Browse files Browse the repository at this point in the history
)

* feat: throw warning when prediction and measure type do not match

* ...
  • Loading branch information
be-marc authored Oct 18, 2024
1 parent e22dbe4 commit 2487013
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 4 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3 (development version)

* feat: Throw warning when prediction and measure type do not match.
* fix: The `mlr_reflections` were broken when an extension package was not loaded on the workers.
Extension packages must now register themselves in the `mlr_reflections$loaded_packages` field.

Expand Down
2 changes: 1 addition & 1 deletion R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Measure = R6Class("Measure",
#'
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_measure(self, task = task, learner = learner)
assert_measure(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)

if ("requires_task" %in% self$properties && is.null(task)) {
Expand Down
10 changes: 9 additions & 1 deletion R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,9 @@ assert_predictable = function(task, learner) {

#' @export
#' @param measure ([Measure]).
#' @param prediction ([Prediction]).
#' @rdname mlr_assertions
assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vname(measure)) {
assert_measure = function(measure, task = NULL, learner = NULL, prediction = NULL, .var.name = vname(measure)) {
assert_class(measure, "Measure", .var.name = .var.name)

if (!is.null(task)) {
Expand Down Expand Up @@ -236,6 +237,13 @@ assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vnam
}
}

if (!is.null(prediction)) {
# same as above but works without learner e.g. measure$score(prediction)
if (measure$check_prerequisites != "ignore" && measure$predict_type %nin% prediction$predict_types) {
warningf("Measure '%s' is missing predict type '%s' of prediction", measure$id, measure$predict_type)
}
}

invisible(measure)
}

Expand Down
5 changes: 3 additions & 2 deletions man/mlr_assertions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/test_Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,13 @@ test_that("checks on predict_sets", {
expect_error({m$predict_sets = "imaginary"}, "Must be a subset")
})

test_that("measure and prediction type is checked", {
learner = lrn("classif.rpart")
task = tsk("pima")
learner$train(task)
pred = learner$predict(task)

measure = msr("classif.logloss")
expect_warning(measure$score(pred), "is missing predict type")
})

0 comments on commit 2487013

Please sign in to comment.