Skip to content

Commit

Permalink
add conditional arg for predict (#438)
Browse files Browse the repository at this point in the history
* add conditional arg

* fix issue

* modify equivalence

* update docs

* fix additional issues

* update according to comments

* address comments

* update h_factor_ref

* [skip roxygen] [skip vbump] Roxygen Man Pages Auto Update

* Empty

* add return for many utility functions

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
clarkliming and github-actions[bot] authored Apr 24, 2024
1 parent b476fd5 commit 239fe10
Show file tree
Hide file tree
Showing 14 changed files with 168 additions and 30 deletions.
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# mmrm 0.3.11.9000

### New Features

- Add parameter `conditional` for `predict` method to control whether the prediction is conditional on the observation or not.

### Bug Fixes

- Previously, if the left-hand side of a model formula was an expression, `predict` and `simulate` would fail. This is now fixed.

# mmrm 0.3.11

### Bug Fixes
Expand Down
28 changes: 12 additions & 16 deletions R/tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fitted.mmrm_tmb <- function(object, ...) {
#' @param interval (`string`)\cr type of interval calculation. Can be abbreviated.
#' @param level (`number`)\cr tolerance/confidence level.
#' @param nsim (`count`)\cr number of simulations to use.
#' @param conditional (`flag`)\cr indicator if the prediction is conditional on the observation or not.
#'
#' @importFrom stats predict
#' @exportS3Method
Expand All @@ -66,6 +67,7 @@ predict.mmrm_tmb <- function(object,
interval = c("none", "confidence", "prediction"),
level = 0.95,
nsim = 1000L,
conditional = TRUE,
...) {
if (missing(newdata)) {
newdata <- object$data
Expand All @@ -75,23 +77,22 @@ predict.mmrm_tmb <- function(object,
assert_flag(se.fit)
assert_number(level, lower = 0, upper = 1)
assert_count(nsim, positive = TRUE)
assert_flag(conditional)
interval <- match.arg(interval)
# make sure new data has the same levels as original data
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
)
newdata <- h_factor_ref_data(object, newdata)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
allow_na_response = TRUE,
drop_levels = FALSE
)
if (!conditional) {
tmb_data$y_vector[] <- NA_real_
}
if (any(object$tmb_data$x_cols_aliased)) {
warning(
"In fitted object there are co-linear variables and therefore dropped terms, ",
Expand Down Expand Up @@ -610,15 +611,10 @@ simulate.mmrm_tmb <- function(object,
method <- match.arg(method)

# Ensure new data has the same levels as original data.
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
)
newdata <- h_factor_ref_data(object, newdata)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
Expand Down
21 changes: 11 additions & 10 deletions R/tmb.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#' - `subject_var`: `string` with the subject variable name.
#' - `group_var`: `string` with the group variable name. If no group specified,
#' this element is `NULL`.
#' - `model_var`: `character` with the variable names of the formula, except `subject_var`.
#'
#' @keywords internal
h_mmrm_tmb_formula_parts <- function(
Expand All @@ -37,7 +38,8 @@ h_mmrm_tmb_formula_parts <- function(
is_spatial = covariance$type == "sp_exp",
visit_var = covariance$visits,
subject_var = covariance$subject,
group_var = if (length(covariance$group) < 1) NULL else covariance$group
group_var = if (length(covariance$group) < 1) NULL else covariance$group,
model_var = setdiff(all.vars(formula[[3]]), covariance$subject)
),
class = "mmrm_tmb_formula_parts"
)
Expand Down Expand Up @@ -135,28 +137,27 @@ h_mmrm_tmb_data <- function(formula_parts,
# Weights is always the last column.
weights_name <- colnames(data)[ncol(data)]
# If `y` is allowed to be NA, then first replace y with 1:n, then replace it with original y.
if (allow_na_response) {
y_original <- eval(formula_parts$full_formula[[2]], envir = data)
vn <- deparse(formula_parts$full_formula[[2]])
data[[vn]] <- seq_len(nrow(data))
} else {
if (!allow_na_response) {
h_warn_na_action()
}
full_frame <- eval(
bquote(stats::model.frame(
formula_parts$full_formula,
data = data,
weights = .(as.symbol(weights_name)),
na.action = stats::na.omit
na.action = "na.pass"
))
)
if (drop_levels) {
full_frame <- droplevels(full_frame, except = formula_parts$visit_var)
}
# If `y` is allowed to be NA, replace it with original y.
if (allow_na_response) {
full_frame[[vn]] <- y_original[full_frame[[vn]]]
# response is always the first column
keep_ind <- complete.cases(full_frame[, -1L, drop = FALSE])
} else {
keep_ind <- complete.cases(full_frame)
}
full_frame <- full_frame[keep_ind, ]
if (drop_visit_levels && !formula_parts$is_spatial && is.factor(full_frame[[formula_parts$visit_var]])) {
old_levels <- levels(full_frame[[formula_parts$visit_var]])
full_frame[[formula_parts$visit_var]] <- droplevels(full_frame[[formula_parts$visit_var]])
Expand All @@ -166,6 +167,7 @@ h_mmrm_tmb_data <- function(formula_parts,
message("In ", formula_parts$visit_var, " there are dropped visits: ", toString(dropped))
}
}

x_matrix <- stats::model.matrix(formula_parts$model_formula, data = full_frame)
x_cols_aliased <- stats::setNames(rep(FALSE, ncol(x_matrix)), nm = colnames(x_matrix))
qr_x_mat <- qr(x_matrix)
Expand All @@ -186,7 +188,6 @@ h_mmrm_tmb_data <- function(formula_parts,
attr(x_matrix, "contrasts") <- contrasts_attr
}
}

y_vector <- as.numeric(stats::model.response(full_frame))
weights_vector <- as.numeric(stats::model.weights(full_frame))
n_subjects <- length(unique(full_frame[[formula_parts$subject_var]]))
Expand Down
28 changes: 28 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ h_partial_fun_args <- function(fun, ..., additional_attr = list()) {
#' For "Kenward-Roger" only, "Kenward-Roger" is returned.
#' For "Residual" only, "Empirical" is returned.
#'
#' @return String of the default covariance method.
#' @keywords internal
h_get_cov_default <- function(method = c("Satterthwaite", "Kenward-Roger", "Residual", "Between-Within")) {
assert_string(method)
Expand Down Expand Up @@ -264,6 +265,7 @@ drop_elements <- function(x, n) {
#'
#' @param x (`numeric`)\cr number of visit levels.
#'
#' @return Logical value `TRUE`.
#' @keywords internal
h_confirm_large_levels <- function(x) {
assert_count(x)
Expand Down Expand Up @@ -314,6 +316,7 @@ h_default_value <- function(x, y) {
#' This is needed even if `x` and `ref` are both `character` because
#' in `model.matrix` if `x` only has one level there could be errors.
#'
#' @return Factor vector with updated levels.
#' @keywords internal
h_factor_ref <- function(x, ref, var_name = vname(x)) {
assert_multi_class(ref, c("character", "factor"))
Expand All @@ -327,6 +330,30 @@ h_factor_ref <- function(x, ref, var_name = vname(x)) {
factor(x, levels = h_default_value(levels(ref), sort(uni_ref)))
}

#' Align Categorical Columns of a Data Frame with a Fitted `MMRM`.
#'
#' @param object (`mmrm_tmb`)\cr the fitted MMRM object.
#' @param data (`data.frame`)\cr input data.
#'
#' @details Every model variable (`object$formula_parts$model_var`) that is a
#'   factor or character column in the frame stored with the fit
#'   (`object$tmb_data$full_frame`) is looked up in `data` and replaced by the
#'   result of `h_factor_ref()`, using the fitted column as the reference.
#'   All other columns of `data` are returned unchanged.
#'
#' @return Data frame `data` with the categorical model variables replaced.
#' @keywords internal
h_factor_ref_data <- function(object, data) {
  assert_data_frame(data)
  assert_class(object, "mmrm_tmb")
  ref <- object$tmb_data$full_frame
  vars <- object$formula_parts$model_var

  # Decide up front which model variables are categorical in the reference.
  is_categorical <- vapply(
    vars,
    function(nm) is.factor(ref[[nm]]) || is.character(ref[[nm]]),
    logical(1L)
  )

  # NOTE(review): `h_factor_ref()` defaults `var_name` to `vname(x)`, which
  # presumably captures the argument expression; the call is therefore kept
  # written as `data[[v]]` so that any message built from it stays the same.
  for (v in vars[is_categorical]) {
    data[[v]] <- h_factor_ref(data[[v]], ref[[v]])
  }
  data
}

#' Warn on na.action
#' @keywords internal
h_warn_na_action <- function() {
Expand Down Expand Up @@ -446,6 +473,7 @@ emp_start <- function(data, model_formula, visit_var, subject_var, subject_group
#' If the covariance matrix has `NA` in some of the elements, they will be replaced by
#' 0 (non-diagonal) and 1 (diagonal). This ensures that the matrix is positive definite.
#'
#' @return Numeric vector of the theta values.
#' @keywords internal
h_get_theta_from_cov <- function(covariance) {
assert_matrix(covariance, mode = "numeric", ncols = nrow(covariance))
Expand Down
3 changes: 3 additions & 0 deletions man/h_confirm_large_levels.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_factor_ref.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/h_factor_ref_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_get_cov_default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_get_theta_from_cov.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/h_mmrm_tmb_formula_parts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/mmrm_tmb_methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions tests/testthat/test-tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,41 @@ test_that("predict will return NA if data contains NA in covariates", {
)
})

test_that("predict can give unconditional predictions", {
  fit <- get_mmrm()
  # Reference value: design matrix times the estimated coefficients, with the
  # model frame built using na.action = "na.pass".
  ref_frame <- model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
  design <- stats::model.matrix(fit$formula_parts$model_formula, ref_frame)
  expected <- (design %*% fit$beta_est)[, 1]

  expect_silent(preds <- predict(fit, newdata = fev_data, conditional = FALSE))
  expect_equal(preds, expected, tolerance = 1e-7)
})

test_that("predict can change based on coefficients", {
  fit <- get_mmrm()
  # Shift every coefficient and store the result back into the fit, so that
  # the prediction below has to pick up the modified values.
  new_beta <- coef(fit) + 0.1
  fit$beta_est <- new_beta

  ref_frame <- model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
  design <- stats::model.matrix(fit$formula_parts$model_formula, ref_frame)
  expected <- (design %*% new_beta)[, 1]

  expect_silent(preds <- predict(fit, newdata = fev_data, conditional = FALSE))
  expect_equal(preds, expected, tolerance = 1e-7)
})

test_that("predict can work if response is an expression", {
  # The left-hand side of the formula is an expression, not a single column.
  fit <- mmrm(
    log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID),
    data = fev_data
  )
  expect_silent(
    preds <- predict(fit, newdata = fev_data, conditional = FALSE)
  )
})

## integration test with SAS ----

test_that("predict gives same result with sas in unstructured satterthwaite/Kenward-Roger", {
Expand Down Expand Up @@ -865,6 +900,13 @@ test_that("response residuals helper function works as expected", {

# simulate.mmrm_tmb ----

test_that("simulate works if the model reponse is an expression", {
  # The left-hand side of the formula is an expression, not a single column.
  object <- mmrm(
    log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID),
    data = fev_data
  )
  set.seed(1001)
  sims <- simulate(object, nsim = 2, method = "conditional")
  # One row per row of the fitted data, one column per simulation, no NAs.
  n_rows <- nrow(object$data)
  expect_data_frame(sims, any.missing = FALSE, nrows = n_rows, ncols = 2)
})

test_that("simulate with conditional method returns a df of correct dimension", {
object <- get_mmrm()
set.seed(1001)
Expand Down
12 changes: 8 additions & 4 deletions tests/testthat/test-tmb.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand All @@ -53,7 +54,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = "ARMCD"
group_var = "ARMCD",
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand Down Expand Up @@ -115,7 +117,8 @@ test_that("h_mmrm_tmb_formula_parts works without covariates", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand All @@ -135,7 +138,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected for antedependence", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand Down
Loading

0 comments on commit 239fe10

Please sign in to comment.