Skip to content

Commit

Permalink
refactor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Dec 7, 2024
1 parent a27b87a commit 0219d10
Show file tree
Hide file tree
Showing 17 changed files with 129 additions and 260 deletions.
35 changes: 18 additions & 17 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Authors@R: c(
family = "Abbott",
role = c("aut"),
email = "[email protected]",
comment = c(ORCID = "0000-0001-8057-8037")),
comment = c(ORCID = "0000-0001-8057-8037")),
person(given = "Hugo",
family = "Gruson",
role = c("aut"),
Expand All @@ -22,7 +22,7 @@ Authors@R: c(
family = "Bracher",
role = c("ctb"),
email = "[email protected]",
comment = c(ORCID = "0000-0002-3777-1410")),
comment = c(ORCID = "0000-0002-3777-1410")),
person(given = "Toshiaki Asakura",
role = c("ctb"),
email = "[email protected]",
Expand All @@ -32,53 +32,54 @@ Authors@R: c(
role = c("ctb"),
email = "[email protected]",
comment = c(ORCID = "0000-0001-5782-7330")),
person("Sebastian", "Funk",
email = "[email protected]",
person("Sebastian", "Funk",
email = "[email protected]",
role = c("aut")),
person(given = "Michael",
family = "Chirico",
role = c("ctb"),
email = "[email protected]",
comment = c(ORCID = "0000-0003-0787-087X")))
Description:
Facilitate the evaluation of forecasts in a convenient
framework based on data.table. It allows user to to check their forecasts
and diagnose issues, to visualise forecasts and missing data, to transform
data before scoring, to handle missing forecasts, to aggregate scores, and
to visualise the results of the evaluation. The package mostly focuses on
the evaluation of probabilistic forecasts and allows evaluating several
different forecast types and input formats. Find more information about the
package in the Vignettes as well as in the accompanying paper,
Description:
Facilitate the evaluation of forecasts in a convenient
framework based on data.table. It allows user to to check their forecasts
and diagnose issues, to visualise forecasts and missing data, to transform
data before scoring, to handle missing forecasts, to aggregate scores, and
to visualise the results of the evaluation. The package mostly focuses on
the evaluation of probabilistic forecasts and allows evaluating several
different forecast types and input formats. Find more information about the
package in the Vignettes as well as in the accompanying paper,
<doi:10.48550/arXiv.2205.07090>.
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
Imports:
Imports:
checkmate,
cli,
data.table,
ggplot2 (>= 3.4.0),
lifecycle,
methods,
Metrics,
purrr,
scoringRules,
stats
Suggests:
Suggests:
ggdist,
kableExtra,
knitr,
magrittr,
rmarkdown,
testthat (>= 3.1.9),
vdiffr
Config/Needs/website:
Config/Needs/website:
r-lib/pkgdown,
amirmasoudabdol/preferably
Config/testthat/edition: 3
RoxygenNote: 7.3.2
URL: https://doi.org/10.48550/arXiv.2205.07090, https://epiforecasts.io/scoringutils/, https://github.com/epiforecasts/scoringutils
BugReports: https://github.com/epiforecasts/scoringutils/issues
VignetteBuilder: knitr
Depends:
Depends:
R (>= 4.0)
Roxygen: list(markdown = TRUE)
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ export(is_forecast_quantile)
export(is_forecast_sample)
export(log_shift)
export(logs_binary)
export(logs_nominal)
export(logs_ordinal)
export(logs_categorical)
export(logs_sample)
export(mad_sample)
export(new_forecast)
Expand Down Expand Up @@ -185,6 +184,7 @@ importFrom(ggplot2,theme_minimal)
importFrom(ggplot2,unit)
importFrom(ggplot2,xlab)
importFrom(ggplot2,ylab)
importFrom(lifecycle,deprecate_warn)
importFrom(methods,hasArg)
importFrom(purrr,partial)
importFrom(scoringRules,crps_sample)
Expand Down
4 changes: 2 additions & 2 deletions R/class-forecast-nominal.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ score.forecast_nominal <- function(forecast, metrics = get_metrics(forecast), ..
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For nominal forecasts, the default scoring rule is:
#' - "log_score" = [logs_nominal()]
#' - "log_score" = [logs_categorical()]
#' @export
#' @family get_metrics functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_nominal)
get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
log_score = logs_nominal
log_score = logs_categorical
)
select_metrics(all, select, exclude)
}
Expand Down
4 changes: 2 additions & 2 deletions R/class-forecast-ordinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ score.forecast_ordinal <- function(forecast, metrics = get_metrics(forecast), ..
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For ordinal forecasts, the default scoring rules are:
#' - "log_score" = [logs_nominal()]
#' - "log_score" = [logs_categorical()]
#' - "rps" = [rps_ordinal()]
#' @export
#' @family get_metrics functions
Expand All @@ -160,7 +160,7 @@ score.forecast_ordinal <- function(forecast, metrics = get_metrics(forecast), ..
#' get_metrics(example_ordinal)
get_metrics.forecast_ordinal <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
log_score = logs_nominal,
log_score = logs_categorical,
rps = rps_ordinal
)
select_metrics(all, select, exclude)
Expand Down
74 changes: 57 additions & 17 deletions R/metrics-nominal.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
#' @title Assert that inputs are correct for nominal forecasts
#' @description Function assesses whether the inputs correspond to the
#' requirements for scoring nominal forecasts.
#' @param observed Input to be checked. Should be an unordered factor of length
#' n with N levels holding the observed values. n is the number of
#' observations and N is the number of possible outcomes the observed values
#' can assume.
#' @param predicted_label Unordered factor of length N with N levels, where N
#' is the number of possible outcomes the observed values can assume.
#' @inheritParams assert_input_categorical
#' @importFrom checkmate assert_factor assert_numeric assert_set_equal
#' @inherit document_assert_functions return
#' @keywords internal_input_check
assert_input_nominal <- function(observed, predicted, predicted_label) {
assert_input_categorical(
observed, predicted, predicted_label, ordered = FALSE
)
return(invisible(NULL))
}


#' @title Assert that inputs are correct for categorical forecasts
#' @description Function assesses whether the inputs correspond to the
#' requirements for scoring categorical, i.e. either nominal or ordinal
#' forecasts.
#' @param observed Input to be checked. Should be a factor of length n with
#' N levels holding the observed values. n is the number of observations and
#' N is the number of possible outcomes the observed values can assume.
Expand All @@ -11,23 +33,29 @@
#' If `observed` is just a single number, then predicted can just be a
#' vector of size N.
#' Values represent the probability that the corresponding value
#' in `observed` will be equal to the highest available factor level.
#' in `observed` will be equal to the factor level referenced in
#' `predicted_label`.
#' @param predicted_label Factor of length N with N levels, where N is the
#' number of possible outcomes the observed values can assume.
#' @param ordered Value indicating whether factors have to be ordered or not.
#' Defaults to `NA`, which means that the check is not performed.
#' @importFrom checkmate assert_factor assert_numeric assert_set_equal
#' @inherit document_assert_functions return
#' @keywords internal_input_check
assert_input_nominal <- function(observed, predicted, predicted_label) {
assert_input_categorical <- function(
observed, predicted, predicted_label, ordered = NA
) {
# observed
assert_factor(observed, min.len = 1, min.levels = 2)
assert_factor(observed, min.len = 1, min.levels = 2, ordered = ordered)
levels <- levels(observed)
n <- length(observed)
N <- length(levels)

# predicted label
assert_factor(
predicted_label, len = N,
any.missing = FALSE, empty.levels.ok = FALSE
any.missing = FALSE, empty.levels.ok = FALSE,
ordered = ordered
)
assert_set_equal(levels(observed), levels(predicted_label))

Expand Down Expand Up @@ -59,21 +87,22 @@ assert_input_nominal <- function(observed, predicted, predicted_label) {
}


#' Log score for nominal outcomes
#' Log score for categorical outcomes
#'
#' @description
#' **Log score for nominal outcomes**
#' **Log score for categorical (nominal or ordinal) outcomes**
#'
#' The Log Score is the negative logarithm of the probability
#' assigned to the observed value. It is a proper scoring rule. Small values
#' are better (best is zero, worst is infinity).
#' @param observed A factor of length n with N levels holding the observed
#' values.
#'
#' @param observed Factor of length n with N levels holding the
#' observed values.
#' @param predicted nxN matrix of predictive probabilities, n (number of rows)
#' being the number of observations and N (number of columns) the number of
#' possible outcomes.
#' @param predicted_label A factor of length N, denoting the outcome that the
#' probabilities in `predicted` correspond to.
#' @param predicted_label Factor of length N, denoting the outcome
#' that the probabilities in `predicted` correspond to.
#' @returns A numeric vector of size n with log scores
#' @inheritSection illustration-input-metric-nominal Input format
#' @importFrom methods hasArg
Expand All @@ -85,13 +114,16 @@ assert_input_nominal <- function(observed, predicted, predicted_label) {
#' factor_levels <- c("one", "two", "three")
#' predicted_label <- factor(c("one", "two", "three"), levels = factor_levels)
#' observed <- factor(c("one", "three", "two"), levels = factor_levels)
#' predicted <- matrix(c(0.8, 0.1, 0.4,
#' 0.1, 0.2, 0.4,
#' 0.1, 0.7, 0.2),
#' nrow = 3)
#' logs_nominal(observed, predicted, predicted_label)
logs_nominal <- function(observed, predicted, predicted_label) {
assert_input_nominal(observed, predicted, predicted_label)
#' predicted <- matrix(
#' c(0.8, 0.1, 0.1,
#' 0.1, 0.2, 0.7,
#' 0.4, 0.4, 0.2),
#' nrow = 3,
#' byrow = TRUE
#' )
#' logs_categorical(observed, predicted, predicted_label)
logs_categorical <- function(observed, predicted, predicted_label) {
assert_input_categorical(observed, predicted, predicted_label)
n <- length(observed)
if (n == 1) {
predicted <- matrix(predicted, nrow = 1)
Expand All @@ -101,3 +133,11 @@ logs_nominal <- function(observed, predicted, predicted_label) {
logs <- -log(pred_for_observed)
return(logs)
}

#' @importFrom lifecycle deprecate_warn
logs_nominal <- function(observed, predicted, predicted_label) {
deprecate_warn(
when = "2.1.0", what = "logs_nominal()", with = "logs_categorical()"
)
logs_categorical(observed, predicted, predicted_label)
}
82 changes: 1 addition & 81 deletions R/metrics-ordinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,99 +18,19 @@
#' @inherit document_assert_functions return
#' @keywords internal_input_check
assert_input_ordinal <- function(observed, predicted, predicted_label) {
# observed
assert_factor(observed, min.len = 1, min.levels = 2, ordered = TRUE)
levels <- levels(observed)
n <- length(observed)
N <- length(levels)
assert_input_categorical(observed, predicted, predicted_label, ordered = TRUE)

# predicted label
assert_factor(
predicted_label,
len = N,
any.missing = FALSE, empty.levels.ok = FALSE, ordered = TRUE
)
if (!identical(levels(predicted_label), levels(observed))) {
cli_abort(
"Levels of `predicted_label` and `observed` must be identical
and in the same order. Found levels {.val {levels(predicted_label)}}
and {.val {levels(observed)}}."
)
}

# predicted
assert_numeric(predicted, min.len = 1, lower = 0, upper = 1)
if (n == 1) {
assert(
# allow one of two options
check_vector(predicted, len = N),
check_matrix(predicted, nrows = n, ncols = N)
)
summed_predictions <- .rowSums(predicted, m = 1, n = N, na.rm = TRUE)
} else {
assert_matrix(predicted, nrows = n)
summed_predictions <- round(rowSums(predicted, na.rm = TRUE), 10) # avoid numeric errors
}
if (!all(summed_predictions == 1)) {
#nolint start: keyword_quote_linter object_usage_linter
row_indices <- as.character(which(summed_predictions != 1))
cli_abort(
c(
`!` = "Probabilities belonging to a single forecast must sum to one",
`i` = "Found issues in row{?s} {row_indices} of {.var predicted}"
)
)
#nolint end
}
return(invisible(NULL))
}


#' Log score for ordinal outcomes
#'
#' @description
#' **Log score for ordinal outcomes**
#'
#' The Log Score is the negative logarithm of the probability
#' assigned to the observed value. It is a proper scoring rule. Small values
#' are better (best is zero, worst is infinity).
#' @param observed A factor of length n with N levels holding the observed
#' values.
#' @param predicted nxN matrix of predictive probabilities, n (number of rows)
#' being the number of observations and N (number of columns) the number of
#' possible outcomes. If `observed` is just a single number, then predicted
#' can just be a vector of size N.
#' Values represent the probability that the corresponding value in `observed`
#' will be equal to factor level referenced in `predicted_label`.
#' @param predicted_label A factor of length N, denoting the outcome that the
#' probabilities in `predicted` correspond to.
#' @returns A numeric vector of size n with log scores
#' @inheritSection illustration-input-metric-ordinal Input format
#' @importFrom methods hasArg
#' @export
#' @keywords metric
#' @family log score functions
#' @examples
#' factor_levels <- c("one", "two", "three")
#' predicted_label <- factor(c("one", "two", "three"), levels = factor_levels)
#' observed <- factor(c("one", "three", "two"), levels = factor_levels)
#' predicted <- matrix(c(0.8, 0.1, 0.4,
#' 0.1, 0.2, 0.4,
#' 0.1, 0.7, 0.2),
#' nrow = 3)
#' logs_nominal(observed, predicted, predicted_label)
logs_ordinal <- function(observed, predicted, predicted_label) {
assert_input_ordinal(observed, predicted, predicted_label)
n <- length(observed)
if (n == 1) {
predicted <- matrix(predicted, nrow = 1)
}
observed_indices <- as.numeric(observed)
pred_for_observed <- predicted[cbind(1:n, observed_indices)]
logs <- -log(pred_for_observed)
return(logs)
}

#' Ranked Probability Score for ordinal outcomes
#'
#' @description
Expand Down
14 changes: 8 additions & 6 deletions man/assert_input_nominal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0219d10

Please sign in to comment.