Skip to content

Commit

Permalink
Merge pull request #6 from tidymodels/test-steps
Browse files Browse the repository at this point in the history
Test all implemented recipes steps
  • Loading branch information
EmilHvitfeldt authored Jun 25, 2024
2 parents 18997f1 + 14b42d4 commit 07f093f
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 56 deletions.
4 changes: 1 addition & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ Suggests:
jsonlite,
kknn,
parsnip,
recipes (>= 1.0.10.9000),
recipes,
testthat (>= 3.0.0),
tidypredict,
workflows
Remotes:
tidymodels/recipes
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
Expand Down
5 changes: 5 additions & 0 deletions R/orbital.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ orbital.default <- function(x, ...) {
)
}

new_orbital_class <- function(x) {
class(x) <- "orbital_class"
x
}

#' @export
print.orbital_class <- function(x, ...) {
x <- unclass(x)
Expand Down
58 changes: 5 additions & 53 deletions R/recipes.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ orbital.recipe <- function(x, eqs = NULL, ...) {
}

if (is.null(eqs)) {
ptype <- recipes::recipes_ptype(x, stage = "bake")
if (is.null(ptype)) {
cli::cli_abort("recipe must be created using version 1.1.0 or later.")
}
all_vars <- names(ptype)
terms <- x$term_info
all_vars <- terms$variable[terms$role == "predictor"]
} else {
all_vars <- all.vars(rlang::parse_expr(eqs))
}
Expand Down Expand Up @@ -46,54 +43,9 @@ orbital.recipe <- function(x, eqs = NULL, ...) {
}
}

new_orbital_class(out)
}

#' @export
orbital.step_pca <- function(x, all_vars, ...) {
rot <- x$res$rotation
colnames(rot) <- recipes::names0(ncol(rot), x$prefix)

used_vars <- colnames(rot) %in% all_vars

rot <- rot[, used_vars]

row_nms <- rownames(rot)

out <- character(length(all_vars))
for (i in seq_along(all_vars)) {
out[i] <- paste(row_nms, "*", rot[, i], collapse = " + ")
if (is.null(out)) {
out <- character()
}

names(out) <- all_vars
out
}

#' @export
orbital.step_normalize <- function(x, all_vars, ...) {
means <- x$means
sds <- x$sds

used_vars <- names(means) %in% all_vars
means <- means[used_vars]
sds <- sds[used_vars]

out <- paste0("(", names(means), " - ", means ,") / ", sds)
names(out) <- names(means)
out
}

#' @export
orbital.step_nzv <- function(x, all_vars, ...) {
NULL
}

#' @export
orbital.step_corr <- function(x, all_vars, ...) {
NULL
}

new_orbital_class <- function(x) {
class(x) <- "orbital_class"
x
new_orbital_class(out)
}
4 changes: 4 additions & 0 deletions R/step_corr.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#' @export
orbital.step_corr <- function(x, all_vars, ...) {
NULL
}
13 changes: 13 additions & 0 deletions R/step_normalize.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#' @export
orbital.step_normalize <- function(x, all_vars, ...) {
means <- x$means
sds <- x$sds

used_vars <- names(means) %in% all_vars
means <- means[used_vars]
sds <- sds[used_vars]

out <- paste0("(", names(means), " - ", means ,") / ", sds)
names(out) <- names(means)
out
}
4 changes: 4 additions & 0 deletions R/step_nzv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#' @export
orbital.step_nzv <- function(x, all_vars, ...) {
NULL
}
24 changes: 24 additions & 0 deletions R/step_pca.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#' @export
orbital.step_pca <- function(x, all_vars, ...) {
rot <- x$res$rotation
colnames(rot) <- recipes::names0(ncol(rot), x$prefix)

used_vars <- pca_naming(colnames(rot), x$prefix) %in%
pca_naming(all_vars, x$prefix)

rot <- rot[, used_vars]

row_nms <- rownames(rot)

out <- character(length(all_vars))
for (i in seq_along(all_vars)) {
out[i] <- paste(row_nms, "*", rot[, i], collapse = " + ")
}

names(out) <- all_vars
out
}

pca_naming <- function(x, prefix) {
gsub(paste0(prefix, "0"), prefix, x)
}
20 changes: 20 additions & 0 deletions tests/testthat/test-step_corr.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
test_that("step_corr works", {
skip_if_not_installed("recipes")

mtcars0 <- mtcars
mtcars0$disp1 <- mtcars$disp

rec_exp <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_corr(recipes::all_predictors()) %>%
recipes::prep()

expect_null(orbital(rec_exp$steps[[1]]))

rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>%
recipes::step_corr(recipes::all_predictors()) %>%
recipes::prep()

expect_null(orbital(rec$steps[[1]]))

expect_identical(orbital(rec), orbital(rec_exp))
})
16 changes: 16 additions & 0 deletions tests/testthat/test-step_normalize.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
test_that("step_normalize works", {
skip_if_not_installed("recipes")

mtcars <- dplyr::as_tibble(mtcars)

rec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_normalize(recipes::all_predictors()) %>%
recipes::prep()

res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec)))

exp <- recipes::bake(rec, new_data = mtcars)
exp <- exp[names(res)]

expect_equal(res, exp)
})
20 changes: 20 additions & 0 deletions tests/testthat/test-step_nzv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
test_that("step_nzv works", {
skip_if_not_installed("recipes")

mtcars0 <- mtcars
mtcars0$zv <- 0

rec_exp <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_nzv(recipes::all_predictors()) %>%
recipes::prep()

expect_null(orbital(rec_exp$steps[[1]]))

rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>%
recipes::step_nzv(recipes::all_predictors()) %>%
recipes::prep()

expect_null(orbital(rec$steps[[1]]))

expect_identical(orbital(rec), orbital(rec_exp))
})
34 changes: 34 additions & 0 deletions tests/testthat/test-step_pca.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
test_that("step_pca works", {
skip_if_not_installed("recipes")

mtcars <- dplyr::as_tibble(mtcars)
mtcars$hp <- NULL

rec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_pca(recipes::all_predictors()) %>%
recipes::prep()

exp <- recipes::bake(rec, new_data = mtcars)

res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec)))
res <- res[names(exp)]

expect_equal(res, exp)
})

test_that("step_pca works with more than 9 PCs", {
skip_if_not_installed("recipes")

mtcars <- dplyr::as_tibble(mtcars)

rec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_pca(recipes::all_predictors()) %>%
recipes::prep()

exp <- recipes::bake(rec, new_data = mtcars)

res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec)))
res <- res[names(exp)]

expect_equal(res, exp)
})

0 comments on commit 07f093f

Please sign in to comment.