diff --git a/DESCRIPTION b/DESCRIPTION index 781eae4..a2ba3fe 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,12 +19,10 @@ Suggests: jsonlite, kknn, parsnip, - recipes (>= 1.0.10.9000), + recipes, testthat (>= 3.0.0), tidypredict, workflows -Remotes: - tidymodels/recipes Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) diff --git a/R/orbital.R b/R/orbital.R index 91a9411..69c70f6 100644 --- a/R/orbital.R +++ b/R/orbital.R @@ -33,6 +33,11 @@ orbital.default <- function(x, ...) { ) } +new_orbital_class <- function(x) { + class(x) <- "orbital_class" + x +} + #' @export print.orbital_class <- function(x, ...) { x <- unclass(x) diff --git a/R/recipes.R b/R/recipes.R index b21e109..024fe5b 100644 --- a/R/recipes.R +++ b/R/recipes.R @@ -6,11 +6,8 @@ orbital.recipe <- function(x, eqs = NULL, ...) { } if (is.null(eqs)) { - ptype <- recipes::recipes_ptype(x, stage = "bake") - if (is.null(ptype)) { - cli::cli_abort("recipe must be created using version 1.1.0 or later.") - } - all_vars <- names(ptype) + terms <- x$term_info + all_vars <- terms$variable[terms$role == "predictor"] } else { all_vars <- all.vars(rlang::parse_expr(eqs)) } @@ -46,54 +43,9 @@ orbital.recipe <- function(x, eqs = NULL, ...) { } } - new_orbital_class(out) -} - -#' @export -orbital.step_pca <- function(x, all_vars, ...) { - rot <- x$res$rotation - colnames(rot) <- recipes::names0(ncol(rot), x$prefix) - - used_vars <- colnames(rot) %in% all_vars - - rot <- rot[, used_vars] - - row_nms <- rownames(rot) - - out <- character(length(all_vars)) - for (i in seq_along(all_vars)) { - out[i] <- paste(row_nms, "*", rot[, i], collapse = " + ") + if (is.null(out)) { + out <- character() } - names(out) <- all_vars - out -} - -#' @export -orbital.step_normalize <- function(x, all_vars, ...) { - means <- x$means - sds <- x$sds - - used_vars <- names(means) %in% all_vars - means <- means[used_vars] - sds <- sds[used_vars] - - out <- paste0("(", names(means), " - ", means ,") / ", sds) - names(out) <- names(means) - out -} - -#' @export -orbital.step_nzv <- function(x, all_vars, ...) { - NULL -} - -#' @export -orbital.step_corr <- function(x, all_vars, ...) { - NULL -} - -new_orbital_class <- function(x) { - class(x) <- "orbital_class" - x + new_orbital_class(out) } diff --git a/R/step_corr.R b/R/step_corr.R new file mode 100644 index 0000000..51916d9 --- /dev/null +++ b/R/step_corr.R @@ -0,0 +1,4 @@ +#' @export +orbital.step_corr <- function(x, all_vars, ...) { + NULL +} diff --git a/R/step_normalize.R b/R/step_normalize.R new file mode 100644 index 0000000..84ead4e --- /dev/null +++ b/R/step_normalize.R @@ -0,0 +1,13 @@ +#' @export +orbital.step_normalize <- function(x, all_vars, ...) { + means <- x$means + sds <- x$sds + + used_vars <- names(means) %in% all_vars + means <- means[used_vars] + sds <- sds[used_vars] + + out <- paste0("(", names(means), " - ", means ,") / ", sds) + names(out) <- names(means) + out +} \ No newline at end of file diff --git a/R/step_nzv.R b/R/step_nzv.R new file mode 100644 index 0000000..2d4111a --- /dev/null +++ b/R/step_nzv.R @@ -0,0 +1,4 @@ +#' @export +orbital.step_nzv <- function(x, all_vars, ...) { + NULL +} diff --git a/R/step_pca.R b/R/step_pca.R new file mode 100644 index 0000000..bb8af18 --- /dev/null +++ b/R/step_pca.R @@ -0,0 +1,24 @@ +#' @export +orbital.step_pca <- function(x, all_vars, ...) { + rot <- x$res$rotation + colnames(rot) <- recipes::names0(ncol(rot), x$prefix) + + used_vars <- pca_naming(colnames(rot), x$prefix) %in% + pca_naming(all_vars, x$prefix) + + rot <- rot[, used_vars] + + row_nms <- rownames(rot) + + out <- character(length(all_vars)) + for (i in seq_along(all_vars)) { + out[i] <- paste(row_nms, "*", rot[, i], collapse = " + ") + } + + names(out) <- all_vars + out +} + +pca_naming <- function(x, prefix) { + gsub(paste0(prefix, "0"), prefix, x) +} diff --git a/tests/testthat/test-step_corr.R b/tests/testthat/test-step_corr.R new file mode 100644 index 0000000..0081051 --- /dev/null +++ b/tests/testthat/test-step_corr.R @@ -0,0 +1,20 @@ +test_that("step_corr works", { + skip_if_not_installed("recipes") + + mtcars0 <- mtcars + mtcars0$disp1 <- mtcars$disp + + rec_exp <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_corr(recipes::all_predictors()) %>% + recipes::prep() + + expect_null(orbital(rec_exp$steps[[1]])) + + rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>% + recipes::step_corr(recipes::all_predictors()) %>% + recipes::prep() + + expect_null(orbital(rec$steps[[1]])) + + expect_identical(orbital(rec), orbital(rec_exp)) +}) diff --git a/tests/testthat/test-step_normalize.R b/tests/testthat/test-step_normalize.R new file mode 100644 index 0000000..b3b94bd --- /dev/null +++ b/tests/testthat/test-step_normalize.R @@ -0,0 +1,16 @@ +test_that("step_normalize works", { + skip_if_not_installed("recipes") + + mtcars <- dplyr::as_tibble(mtcars) + + rec <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_normalize(recipes::all_predictors()) %>% + recipes::prep() + + res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec))) + + exp <- recipes::bake(rec, new_data = mtcars) + exp <- exp[names(res)] + + expect_equal(res, exp) +}) diff --git a/tests/testthat/test-step_nzv.R b/tests/testthat/test-step_nzv.R new file mode 100644 index 0000000..1bf3bb8 --- /dev/null +++ b/tests/testthat/test-step_nzv.R @@ -0,0 +1,20 @@ +test_that("step_nzv works", { + skip_if_not_installed("recipes") + + mtcars0 <- mtcars + mtcars0$zv <- 0 + + rec_exp <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_nzv(recipes::all_predictors()) %>% + recipes::prep() + + expect_null(orbital(rec_exp$steps[[1]])) + + rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>% + recipes::step_nzv(recipes::all_predictors()) %>% + recipes::prep() + + expect_null(orbital(rec$steps[[1]])) + + expect_identical(orbital(rec), orbital(rec_exp)) +}) diff --git a/tests/testthat/test-step_pca.R b/tests/testthat/test-step_pca.R new file mode 100644 index 0000000..3962443 --- /dev/null +++ b/tests/testthat/test-step_pca.R @@ -0,0 +1,34 @@ +test_that("step_pca works", { + skip_if_not_installed("recipes") + + mtcars <- dplyr::as_tibble(mtcars) + mtcars$hp <- NULL + + rec <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_pca(recipes::all_predictors()) %>% + recipes::prep() + + exp <- recipes::bake(rec, new_data = mtcars) + + res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec))) + res <- res[names(exp)] + + expect_equal(res, exp) +}) + +test_that("step_pca works with more than 9 PCs", { + skip_if_not_installed("recipes") + + mtcars <- dplyr::as_tibble(mtcars) + + rec <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_pca(recipes::all_predictors()) %>% + recipes::prep() + + exp <- recipes::bake(rec, new_data = mtcars) + + res <- dplyr::mutate(mtcars, !!!orbital_inline(orbital(rec))) + res <- res[names(exp)] + + expect_equal(res, exp) +})