Skip to content

Commit

Permalink
Merge pull request #498 from tidymodels/initial_split-attributes
Browse files Browse the repository at this point in the history
Retain split args for `initial_split()` objects
  • Loading branch information
hfrick authored May 28, 2024
2 parents f228edf + 64fed0e commit 9a9cb81
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
27 changes: 27 additions & 0 deletions R/initial_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,17 @@ initial_split <- function(data, prop = 3 / 4,
pool = pool,
times = 1
)
attrib <- .get_split_args(res, allow_strata_false = TRUE)

res <- res$splits[[1]]

attrib$times <- NULL
for (i in names(attrib)) {
attr(res, i) <- attrib[[i]]
}

class(res) <- c("initial_split", class(res))

res
}

Expand Down Expand Up @@ -83,6 +92,15 @@ initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {
rset <- new_rset(splits, ids)

res <- rset$splits[[1]]

attrib <- list(
prop = prop,
lag = lag
)
for (i in names(attrib)) {
attr(res, i) <- attrib[[i]]
}

class(res) <- c("initial_time_split", "initial_split", class(res))
res
}
Expand Down Expand Up @@ -154,7 +172,16 @@ group_initial_split <- function(data, group, prop = 3 / 4, ..., strata = NULL, p
pool = pool
)
}

attrib <- .get_split_args(res, allow_strata_false = TRUE)

res <- res$splits[[1]]

attrib$times <- NULL
for (i in names(attrib)) {
attr(res, i) <- attrib[[i]]
}
class(res) <- c("group_initial_split", "initial_split", class(res))

res
}
12 changes: 8 additions & 4 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,15 @@ non_random_classes <- c(
)

#' Get the split arguments from an rset
#' @param rset An `rset` object.
#' @param x An `rset` or `initial_split` object.
#' @param allow_strata_false A logical to specify which value to use if no
#' stratification was specified. The default is to use `strata = NULL`, the
#' alternative is `strata = FALSE`.
#' @return A list of arguments used to create the rset.
#' @keywords internal
#' @export
.get_split_args <- function(rset) {
all_attributes <- attributes(rset)
.get_split_args <- function(x, allow_strata_false = FALSE) {
all_attributes <- attributes(x)
function_used_to_create <- switch(
all_attributes$class[[1]],
"validation_set" = "initial_validation_split",
Expand All @@ -312,7 +315,8 @@ non_random_classes <- c(
args <- names(formals(function_used_to_create))
split_args <- all_attributes[args]
split_args <- split_args[!is.na(names(split_args))]
if (identical(split_args$strata, FALSE)) {

if (identical(split_args$strata, FALSE) && !allow_strata_false) {
split_args$strata <- NULL
}
split_args
Expand Down
8 changes: 6 additions & 2 deletions man/dot-get_split_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9a9cb81

Please sign in to comment.