diff --git a/NEWS.md b/NEWS.md index 1df88197..7d888ef7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,7 @@ Minor spelling / mathematical updates to Scoring rule vignette. (#969) - A bug was fixed where `crps_sample()` could fail in edge cases. - Implemented a new forecast class, `forecast_ordinal` with appropriate metrics. Ordinal forecasts are a form of categorical forecasts. The main difference between ordinal and nominal forecasts is that the outcome is ordered, rather than unordered. +- Refactored the way that columns get internally renamed in `as_forecast_()` functions (#980) # scoringutils 2.0.0 diff --git a/R/class-forecast-binary.R b/R/class-forecast-binary.R index 2f1b3eb0..09fc1704 100644 --- a/R/class-forecast-binary.R +++ b/R/class-forecast-binary.R @@ -34,7 +34,12 @@ as_forecast_binary <- function(data, forecast_unit = NULL, observed = NULL, predicted = NULL) { - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted + ) data <- new_forecast(data, "forecast_binary") assert_forecast(data) return(data) diff --git a/R/class-forecast-nominal.R b/R/class-forecast-nominal.R index a457d964..c589b585 100644 --- a/R/class-forecast-nominal.R +++ b/R/class-forecast-nominal.R @@ -45,13 +45,13 @@ as_forecast_nominal <- function(data, observed = NULL, predicted = NULL, predicted_label = NULL) { - assert_character(predicted_label, len = 1, null.ok = TRUE) - assert_subset(predicted_label, names(data), empty.ok = TRUE) - if (!is.null(predicted_label)) { - setnames(data, old = predicted_label, new = "predicted_label") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted, + predicted_label = predicted_label + ) data <- new_forecast(data, "forecast_nominal") assert_forecast(data) return(data) diff --git 
a/R/class-forecast-ordinal.R b/R/class-forecast-ordinal.R index 0926cd5d..d6dda020 100644 --- a/R/class-forecast-ordinal.R +++ b/R/class-forecast-ordinal.R @@ -45,13 +45,13 @@ as_forecast_ordinal <- function(data, observed = NULL, predicted = NULL, predicted_label = NULL) { - assert_character(predicted_label, len = 1, null.ok = TRUE) - assert_subset(predicted_label, names(data), empty.ok = TRUE) - if (!is.null(predicted_label)) { - setnames(data, old = predicted_label, new = "predicted_label") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted, + predicted_label = predicted_label + ) data <- new_forecast(data, "forecast_ordinal") assert_forecast(data) return(data) diff --git a/R/class-forecast-point.R b/R/class-forecast-point.R index 0dfc87ab..7e2aa9eb 100644 --- a/R/class-forecast-point.R +++ b/R/class-forecast-point.R @@ -30,7 +30,12 @@ as_forecast_point.default <- function(data, observed = NULL, predicted = NULL, ...) { - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted + ) data <- new_forecast(data, "forecast_point") assert_forecast(data) return(data) diff --git a/R/class-forecast-quantile.R b/R/class-forecast-quantile.R index caf301fa..ac341ee2 100644 --- a/R/class-forecast-quantile.R +++ b/R/class-forecast-quantile.R @@ -46,13 +46,13 @@ as_forecast_quantile.default <- function(data, predicted = NULL, quantile_level = NULL, ...) 
{ - assert_character(quantile_level, len = 1, null.ok = TRUE) - assert_subset(quantile_level, names(data), empty.ok = TRUE) - if (!is.null(quantile_level)) { - setnames(data, old = quantile_level, new = "quantile_level") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted, + quantile_level = quantile_level + ) data <- new_forecast(data, "forecast_quantile") assert_forecast(data) return(data) diff --git a/R/class-forecast-sample.R b/R/class-forecast-sample.R index 352c139c..8252ca4d 100644 --- a/R/class-forecast-sample.R +++ b/R/class-forecast-sample.R @@ -29,13 +29,13 @@ as_forecast_sample <- function(data, observed = NULL, predicted = NULL, sample_id = NULL) { - assert_character(sample_id, len = 1, null.ok = TRUE) - assert_subset(sample_id, names(data), empty.ok = TRUE) - if (!is.null(sample_id)) { - setnames(data, old = sample_id, new = "sample_id") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) + data <- as_forecast_generic( + data, + forecast_unit, + observed = observed, + predicted = predicted, + sample_id = sample_id + ) data <- new_forecast(data, "forecast_sample") assert_forecast(data) return(data) diff --git a/R/class-forecast.R b/R/class-forecast.R index 09a279df..697e94fa 100644 --- a/R/class-forecast.R +++ b/R/class-forecast.R @@ -4,25 +4,29 @@ #' It renames the required columns, where appropriate, and sets the forecast #' unit. #' @inheritParams as_forecast_doc_template +#' @param ... Named arguments that are used to rename columns. The name of +#' each argument is the new column name. The value is the name of the existing +#' column in `data` that should be renamed (or `NULL` to rename nothing). #' @keywords as_forecast as_forecast_generic <- function(data, forecast_unit = NULL, - observed = NULL, - predicted = NULL) { - # check inputs - general + ...) 
{ data <- ensure_data.table(data) - assert_character(observed, len = 1, null.ok = TRUE) - assert_subset(observed, names(data), empty.ok = TRUE) - - assert_character(predicted, len = 1, null.ok = TRUE) - assert_subset(predicted, names(data), empty.ok = TRUE) - - # rename columns - general - if (!is.null(observed)) { - setnames(data, old = observed, new = "observed") - } - if (!is.null(predicted)) { - setnames(data, old = predicted, new = "predicted") + oldnames <- list(...) + newnames <- names(oldnames) + provided <- !sapply(oldnames, is.null) + + lapply(seq_along(oldnames), function(i) { + var <- oldnames[[i]] + varname <- names(oldnames)[i] + assert_character(var, len = 1, null.ok = TRUE, .var.name = varname) + assert_subset(var, names(data), empty.ok = TRUE, .var.name = varname) + }) + + oldnames <- unlist(oldnames[provided]) + newnames <- unlist(newnames[provided]) + if (!is.null(oldnames) && length(oldnames) > 0) { + setnames(data, old = oldnames, new = newnames) } # set forecast unit (error handling is done in `set_forecast_unit()`) diff --git a/man/as_forecast_generic.Rd b/man/as_forecast_generic.Rd index 9cb52e5b..cfbba7c2 100644 --- a/man/as_forecast_generic.Rd +++ b/man/as_forecast_generic.Rd @@ -4,12 +4,7 @@ \alias{as_forecast_generic} \title{Common functionality for \verb{as_forecast_} functions} \usage{ -as_forecast_generic( - data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL -) +as_forecast_generic(data, forecast_unit = NULL, ...) } \arguments{ \item{data}{A data.frame (or similar) with predicted and observed values. @@ -23,11 +18,9 @@ If \code{NULL} (the default), all columns that are not required columns are assumed to form the unit of a single forecast. If specified, all columns that are not part of the forecast unit (or required columns) will be removed.} -\item{observed}{(optional) Name of the column in \code{data} that contains the -observed values. 
This column will be renamed to "observed".} - -\item{predicted}{(optional) Name of the column in \code{data} that contains the -predicted values. This column will be renamed to "predicted".} +\item{...}{Named arguments that are used to rename columns. The name of +each argument is the new column name. The value is the name of the existing +column in \code{data} that should be renamed (or \code{NULL} to rename nothing).} } \description{ Common functionality for \verb{as_forecast_} functions diff --git a/man/assert_input_categorical.Rd b/man/assert_input_categorical.Rd new file mode 100644 index 00000000..cc203e8f --- /dev/null +++ b/man/assert_input_categorical.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metrics-nominal.R +\name{assert_input_categorical} +\alias{assert_input_categorical} +\title{Assert that inputs are correct for categorical forecasts} +\usage{ +assert_input_categorical(observed, predicted, predicted_label, ordered = NA) +} +\arguments{ +\item{observed}{Input to be checked. Should be a factor of length n with +N levels holding the observed values. n is the number of observations and +N is the number of possible outcomes the observed values can assume.} + +\item{predicted}{Input to be checked. Should be nxN matrix of predicted +probabilities, n (number of rows) being the number of data points and N +(number of columns) the number of possible outcomes the observed values +can assume. +If \code{observed} is just a single number, then predicted can just be a +vector of size N. +Values represent the probability that the corresponding value +in \code{observed} will be equal to the factor level referenced in +\code{predicted_label}.} + +\item{predicted_label}{Factor of length N with N levels, where N is the +number of possible outcomes the observed values can assume.} + +\item{ordered}{Value indicating whether factors have to be ordered or not. 
+Defaults to \code{NA}, which means that the check is not performed.} +} +\value{ +Returns NULL invisibly if the assertion was successful and throws an +error otherwise. +} +\description{ +Function assesses whether the inputs correspond to the +requirements for scoring categorical, i.e. either nominal or ordinal +forecasts. +} +\keyword{internal_input_check} diff --git a/scoringutils.Rproj b/scoringutils.Rproj index a96991f9..bb245469 100644 --- a/scoringutils.Rproj +++ b/scoringutils.Rproj @@ -1,4 +1,5 @@ Version: 1.0 +ProjectId: 008f911e-df6e-4218-825c-db1095ac43c4 RestoreWorkspace: No SaveWorkspace: No diff --git a/tests/testthat/test-class-forecast-quantile.R b/tests/testthat/test-class-forecast-quantile.R index 3711cf05..fda26428 100644 --- a/tests/testthat/test-class-forecast-quantile.R +++ b/tests/testthat/test-class-forecast-quantile.R @@ -157,6 +157,25 @@ test_that("as_forecast_quantiles issue 557 fix", { expect_equal(any(is.na(out$interval_coverage_deviation)), FALSE) }) +test_that("as_forecast_quantile doesn't modify column names in place", { + quantile_data <- data.table( + my_quantile = c(0.25, 0.5), + forecast_value = c(1, 2), + observed_value = c(5, 5) + ) + pre <- names(quantile_data) + + quantile_forecast <- quantile_data %>% + as_forecast_quantile( + predicted = "forecast_value", + observed = "observed_value", + quantile_level = "my_quantile" + ) + + post <- names(quantile_data) + expect_equal(pre, post) +}) + # ============================================================================== # is_forecast_quantile()