Skip to content

Commit

Permalink
refactor: rename lag arg to lags
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 25, 2025
1 parent 2a037f2 commit aff5f76
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 59 deletions.
20 changes: 10 additions & 10 deletions R/ForecastLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ ForecastLearner = R6::R6Class("ForecastLearner",
#' The learner
learner = NULL,

#' @field lag (`integer()`)\cr
#' The lag
lag = NULL,
#' @field lags (`integer()`)\cr
#' The lags
lags = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param task ([Task])\cr
#' @param learner ([Learner])\cr
#' @param lag (`integer(1)`)\cr
initialize = function(learner, lag) {
#' @param lags (`integer(1)`)\cr
initialize = function(learner, lags) {
self$learner = assert_learner(as_learner(learner, clone = TRUE))
self$lag = assert_integerish(lag, lower = 1L, any.missing = FALSE, coerce = TRUE)
self$lags = assert_integerish(lags, lower = 1L, any.missing = FALSE, coerce = TRUE)

super$initialize(
id = learner$id,
Expand Down Expand Up @@ -80,14 +80,14 @@ ForecastLearner = R6::R6Class("ForecastLearner",
},

.lag_transform = function(dt, target) {
lag = self$lag
nms = sprintf("%s_lag_%i", target, lag)
lags = self$lags
nms = sprintf("%s_lag_%i", target, lags)
dt = copy(dt)
key_cols = private$.task$col_roles$key
if (length(key_cols) > 0L) {
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
} else {
dt[, (nms) := shift(.SD, lag), .SDcols = target]
dt[, (nms) := shift(.SD, lags), .SDcols = target]
}
dt
},
Expand Down
32 changes: 16 additions & 16 deletions R/PipeOpFcstLag.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' @section Parameters:
#' The parameters are the parameters inherited from [mlr3pipelines::PipeOpTaskPreproc],
#' as well as the following parameters:
#' * `lag` :: `integer()`\cr
#' * `lags` :: `integer()`\cr
#' The lags to create.
#'
#' @export
Expand All @@ -24,9 +24,9 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
#' otherwise be set during construction. Default `list()`.
initialize = function(id = "fcst.lags", param_vals = list()) {
param_set = ps(
lag = p_uty(tags = c("train", "predict"), custom_check = check_integerish)
lags = p_uty(tags = c("train", "predict"), custom_check = check_integerish)
)
param_set$set_values(lag = 1L)
param_set$set_values(lags = 1L)

super$initialize(
id = id,
Expand All @@ -41,37 +41,37 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
private = list(
.train_task = function(task) {
pv = self$param_set$get_values(tags = "train")
lag = pv$lag
lags = pv$lags
target = task$target_names
key_cols = task$col_roles$key
order_cols = task$col_roles$order
dt = task$data()
self$state = list(dt = dt[(.N - max(lag)):.N])
nms = sprintf("%s_lag_%i", target, lag)
self$state = list(dt = dt[(.N - max(lags)):.N])
nms = sprintf("%s_lag_%i", target, lags)
if (length(key_cols) > 0L) {
setorderv(dt, c(key_cols, order_cols))
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
} else {
setorderv(dt, order_cols)
dt[, (nms) := shift(.SD, lag), .SDcols = target]
dt[, (nms) := shift(.SD, lags), .SDcols = target]
}
task$select(task$feature_names)$cbind(dt)
},

.predict_task = function(task) {
pv = self$param_set$get_values(tags = "predict")
lag = pv$lag
lags = pv$lags
target = task$target_names
key_cols = task$col_roles$key
order_cols = task$col_roles$order
dt = rbind(self$state$dt, task$data())
nms = sprintf("%s_lag_%i", target, lag)
nms = sprintf("%s_lag_%i", target, lags)
if (length(key_cols) > 0L) {
setorderv(dt, c(key_cols, order_cols))
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
} else {
setorderv(dt, order_cols)
dt[, (nms) := shift(.SD, lag), .SDcols = target]
dt[, (nms) := shift(.SD, lags), .SDcols = target]
}
dt = dt[(.N - task$nrow + 1L):.N]
task$select(task$feature_names)$cbind(dt)
Expand Down Expand Up @@ -110,10 +110,10 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
# this wouldn't allow sorting since we don't get the task here,
# as well as getting the target name
pv = self$param_set$get_values()
lag = pv$lag
nms = sprintf("target_lag_%i", lag)
lags = pv$lags
nms = sprintf("target_lag_%i", lags)
dt[, target := target]
dt[, (nms) := shift(.SD, lag), .SDcols = "target"]
dt[, (nms) := shift(.SD, lags), .SDcols = "target"]
dt[, target := NULL]
dt
},
Expand All @@ -125,4 +125,4 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
)

#' @include zzz.R
register_po("fcst.lag", PipeOpFcstLag)
register_po("fcst.lags", PipeOpFcstLag)
10 changes: 5 additions & 5 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ library(mlr3learners)
library(mlr3pipelines)
task = tsk("airpassengers")
pop = po("fcst.lag", lag = 1:12)
pop = po("fcst.lag", lags = 1:12)
new_task = pop$train(list(task))[[1L]]
new_task$data()
task = tsk("airpassengers")
graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand Down Expand Up @@ -295,7 +295,7 @@ trafo = po("targetmutate",
)
)
graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand All @@ -315,7 +315,7 @@ prediction$score(msr("regr.rmse"))
```

```{r, eval = FALSE}
graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand All @@ -328,7 +328,7 @@ graph = po("fcst.lag", lag = 1:12) %>>%
task = tsk("airpassengers")
flrn = ForecastRecursiveLearner$new(lrn("regr.ranger"))
glrn = as_learner(graph %>>% flrn)
trafo = po("fcst.targetdiff", lag = 12L)
trafo = po("fcst.targetdiff", lags = 12L)
pipeline = ppl("targettrafo", graph = glrn, trafo_pipeop = trafo)
glrn = as_learner(pipeline)$train(task)
prediction = glrn$predict(task, 142:144)
Expand Down
46 changes: 23 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,32 +99,32 @@ prediction = flrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 436.1867
#> 2 NA 437.4089
#> 3 NA 456.8410
#> 1 NA 438.6738
#> 2 NA 438.2207
#> 3 NA 457.2237
prediction = flrn$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 459.1495
#> 2 390 414.8433
#> 3 432 430.2693
#> 1 461 456.8032
#> 2 390 412.9617
#> 3 432 432.0672
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 14.41767
#> 13.4766

flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 48.97126
#> 48.4789

resampling = rsmp("forecast_cv")
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 25.19211
#> 25.08963
```

Or with some feature engineering using mlr3pipelines:
Expand All @@ -146,7 +146,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 15.58057
#> 14.22429
```

### Example: forecasting electricity demand
Expand Down Expand Up @@ -176,13 +176,13 @@ prediction = glrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 14 observations:
#> row_ids truth response
#> 1 NA 187595.7
#> 2 NA 196608.6
#> 3 NA 189152.0
#> 1 NA 189375.9
#> 2 NA 199550.0
#> 3 NA 188647.1
#> --- --- ---
#> 12 NA 222400.3
#> 13 NA 226494.8
#> 14 NA 226568.4
#> 12 NA 221192.0
#> 13 NA 225456.5
#> 14 NA 227090.1
```

### Example: global forecasting (longitudinal data)
Expand Down Expand Up @@ -220,14 +220,14 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = flrn$predict(task, 4460:4464)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 22604.48
#> 22055.26

flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 92125.26
#> 92992
```

### Example: global vs local forecasting
Expand Down Expand Up @@ -293,12 +293,12 @@ library(mlr3learners)
library(mlr3pipelines)

task = tsk("airpassengers")
pop = po("fcst.lag", lag = 1:12)
pop = po("fcst.lag", lags = 1:12)
new_task = pop$train(list(task))[[1L]]
new_task$data()

task = tsk("airpassengers")
graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand Down Expand Up @@ -338,7 +338,7 @@ trafo = po("targetmutate",
)
)

graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand All @@ -358,7 +358,7 @@ prediction$score(msr("regr.rmse"))
```

``` r
graph = po("fcst.lag", lag = 1:12) %>>%
graph = po("fcst.lag", lags = 1:12) %>>%
ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
Expand All @@ -371,7 +371,7 @@ graph = po("fcst.lag", lag = 1:12) %>>%
task = tsk("airpassengers")
flrn = ForecastRecursiveLearner$new(lrn("regr.ranger"))
glrn = as_learner(graph %>>% flrn)
trafo = po("fcst.targetdiff", lag = 12L)
trafo = po("fcst.targetdiff", lags = 12L)
pipeline = ppl("targettrafo", graph = glrn, trafo_pipeop = trafo)
glrn = as_learner(pipeline)$train(task)
prediction = glrn$predict(task, 142:144)
Expand Down
8 changes: 4 additions & 4 deletions man/ForecastLearner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_pipeops_fcst.lag.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit aff5f76

Please sign in to comment.