Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add conditional arg for predict #438

Merged
merged 11 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# mmrm 0.3.11.9000

### New Features

- Add parameter `conditional` for `predict` method to control whether the prediction is conditional on the observation or not.

### Bug Fixes

- Previously if the left hand side of a model formula is an expression, `predict` and `simulate` will fail. This is fixed now.

# mmrm 0.3.11

### Bug Fixes
Expand Down
32 changes: 18 additions & 14 deletions R/tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fitted.mmrm_tmb <- function(object, ...) {
#' @param interval (`string`)\cr type of interval calculation. Can be abbreviated.
#' @param level (`number`)\cr tolerance/confidence level.
#' @param nsim (`count`)\cr number of simulations to use.
#' @param conditional (`flag`)\cr indicator if the prediction is conditional on the observation or not.
#'
#' @importFrom stats predict
#' @exportS3Method
Expand All @@ -66,6 +67,7 @@ predict.mmrm_tmb <- function(object,
interval = c("none", "confidence", "prediction"),
level = 0.95,
nsim = 1000L,
conditional = TRUE,
...) {
if (missing(newdata)) {
newdata <- object$data
Expand All @@ -75,23 +77,26 @@ predict.mmrm_tmb <- function(object,
assert_flag(se.fit)
assert_number(level, lower = 0, upper = 1)
assert_count(nsim, positive = TRUE)
assert_flag(conditional)
interval <- match.arg(interval)
# make sure new data has the same levels as original data
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
newdata <- h_factor_ref_data(
newdata,
object$tmb_data$full_frame,
h_mmrm_vars(object$formula_parts)
)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
allow_na_response = TRUE,
drop_levels = FALSE
)
if (!conditional) {
tmb_data$y_vector[] <- NA_real_
}
if (any(object$tmb_data$x_cols_aliased)) {
warning(
"In fitted object there are co-linear variables and therefore dropped terms, ",
Expand Down Expand Up @@ -610,15 +615,14 @@ simulate.mmrm_tmb <- function(object,
method <- match.arg(method)

# Ensure new data has the same levels as original data.
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
newdata <- h_factor_ref_data(
newdata,
object$tmb_data$full_frame,
h_mmrm_vars(object$formula_parts)
)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
Expand Down
11 changes: 6 additions & 5 deletions R/tmb.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ h_mmrm_tmb_data <- function(formula_parts,
if (allow_na_response) {
y_original <- eval(formula_parts$full_formula[[2]], envir = data)
vn <- deparse(formula_parts$full_formula[[2]])
data[[vn]] <- seq_len(nrow(data))
} else {
h_warn_na_action()
}
Expand All @@ -147,16 +146,18 @@ h_mmrm_tmb_data <- function(formula_parts,
formula_parts$full_formula,
data = data,
weights = .(as.symbol(weights_name)),
na.action = stats::na.omit
na.action = "na.pass"
))
)
if (drop_levels) {
full_frame <- droplevels(full_frame, except = formula_parts$visit_var)
}
# If `y` is allowed to be NA, replace it with original y.
if (allow_na_response) {
full_frame[[vn]] <- y_original[full_frame[[vn]]]
keep_ind <- complete.cases(full_frame[, colnames(full_frame) != vn])
} else {
keep_ind <- complete.cases(full_frame)
}
full_frame <- full_frame[keep_ind, ]
if (drop_visit_levels && !formula_parts$is_spatial && is.factor(full_frame[[formula_parts$visit_var]])) {
old_levels <- levels(full_frame[[formula_parts$visit_var]])
full_frame[[formula_parts$visit_var]] <- droplevels(full_frame[[formula_parts$visit_var]])
Expand All @@ -166,6 +167,7 @@ h_mmrm_tmb_data <- function(formula_parts,
message("In ", formula_parts$visit_var, " there are dropped visits: ", toString(dropped))
}
}

x_matrix <- stats::model.matrix(formula_parts$model_formula, data = full_frame)
x_cols_aliased <- stats::setNames(rep(FALSE, ncol(x_matrix)), nm = colnames(x_matrix))
qr_x_mat <- qr(x_matrix)
Expand All @@ -186,7 +188,6 @@ h_mmrm_tmb_data <- function(formula_parts,
attr(x_matrix, "contrasts") <- contrasts_attr
}
}

y_vector <- as.numeric(stats::model.response(full_frame))
weights_vector <- as.numeric(stats::model.weights(full_frame))
n_subjects <- length(unique(full_frame[[formula_parts$subject_var]]))
Expand Down
34 changes: 34 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,40 @@ h_factor_ref <- function(x, ref, var_name = vname(x)) {
factor(x, levels = h_default_value(levels(ref), sort(uni_ref)))
}

#' Convert Character to Factor Following Reference in a Data Frame
#'
#' @param x (`data.frame`)\cr input.
#' @param ref (`data.frame`)\cr reference.
#' @param vars (`character`)\cr variable names to convert.
#'
#' @details Use `ref` to convert `x` into data frame whose factors
#' are of the same levels.
#'
#' @keywords internal
h_factor_ref_data <- function(x, ref, vars) {
assert_character(vars)
assert_data_frame(x)
assert_data_frame(ref)
assert_names(colnames(x), must.include = vars)
assert_names(colnames(ref), must.include = vars)
for (v in vars) {
if (is.factor(ref[[v]]) || is.character(ref[[v]])) {
x[[v]] <- h_factor_ref(x[[v]], ref[[v]])
}
}
x
}

#' Obtain Right Hand Side Variables of `mmrm_tmb_formula_parts` Object
#'
#' @param object (`mmrm_tmb_formula_parts`)\cr object.
#'
#' @keywords internal
h_mmrm_vars <- function(object) {
assert_class(object, "mmrm_tmb_formula_parts")
setdiff(all.vars(object$formula[[3]]), object$subject_var)
}

#' Warn on na.action
#' @keywords internal
h_warn_na_action <- function() {
Expand Down
23 changes: 23 additions & 0 deletions man/h_factor_ref_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions man/h_mmrm_vars.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/mmrm_tmb_methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions tests/testthat/test-tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,41 @@ test_that("predict will return NA if data contains NA in covariates", {
)
})

test_that("predict can give unconditional predictions", {
fit <- get_mmrm()
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
m <- stats::model.matrix(
fit$formula_parts$model_formula,
model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
)
expect_equal(
p,
(m %*% fit$beta_est)[, 1],
tolerance = 1e-7
)
})

test_that("predict can change based on coefficients", {
fit <- get_mmrm()
new_beta <- coef(fit) + 0.1
fit$beta_est <- new_beta
m <- stats::model.matrix(
fit$formula_parts$model_formula,
model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
)
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
expect_equal(
p,
(m %*% new_beta)[, 1],
tolerance = 1e-7
)
})

test_that("predict can work if response is an expression", {
fit <- mmrm(log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID), data = fev_data)
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
})

## integration test with SAS ----

test_that("predict gives same result with sas in unstructured satterthwaite/Kenward-Roger", {
Expand Down Expand Up @@ -865,6 +900,13 @@ test_that("response residuals helper function works as expected", {

# simulate.mmrm_tmb ----

test_that("simulate works if the model reponse is an expression", {
object <- mmrm(log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID), data = fev_data)
set.seed(1001)
sims <- simulate(object, nsim = 2, method = "conditional")
expect_data_frame(sims, any.missing = FALSE, nrows = nrow(object$data), ncols = 2)
})

test_that("simulate with conditional method returns a df of correct dimension", {
object <- get_mmrm()
set.seed(1001)
Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,22 @@ test_that("h_factor ref allows NA in x", {
expect_identical(levels(f), ref)
})

# h_factor_ref_data ----

test_that("h_factor_ref_data works", {
df1 <- data.frame(
a = c("a", "b"),
b = c("a", "b")
)
df2 <- data.frame(
a = factor(c("a", "b", "c")),
b = factor(c("a", "d", NA_character_))
)
f <- expect_silent(h_factor_ref_data(df1, df2, c("a")))
expect_identical(levels(f$a), levels(df2$a))
expect_identical(f$b, c("a", "b"))
})

# std_start ----

test_that("std_start works", {
Expand Down Expand Up @@ -314,3 +330,11 @@ test_that("emp_start works", {
h_get_theta_from_cov(emp_mat)
)
})

# h_mmrm_vars ----

test_that("h_mmrm_vars works", {
fit <- get_mmrm()
expect_silent(v <- h_mmrm_vars(fit$formula_parts))
expect_identical(v, c("RACE", "SEX", "ARMCD", "AVISIT"))
})
Loading