Skip to content

Commit

Permalink
Add interval argument for augment.lm
Browse files Browse the repository at this point in the history
- Also fix coeficient plot example in documentation
  • Loading branch information
grantmcdermott committed Aug 5, 2020
1 parent f4295f7 commit aed92af
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
30 changes: 20 additions & 10 deletions R/stats-lm-tidiers.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,32 @@
#' glance(mod)
#'
#' # coefficient plot
#' d <- tidy(mod) %>%
#' mutate(
#' low = estimate - std.error,
#' high = estimate + std.error
#' )
#' d <- tidy(mod, conf.int = TRUE)
#'
#' ggplot(d, aes(estimate, term, xmin = low, xmax = high, height = 0)) +
#' ggplot(d, aes(estimate, term, xmin = conf.low, xmax = conf.high, height = 0)) +
#' geom_point() +
#' geom_vline(xintercept = 0) +
#' geom_vline(xintercept = 0, lty = 4) +
#' geom_errorbarh()
#'
#' augment(mod)
#' augment(mod, mtcars)
#' augment(mod, mtcars, interval = "confidence")
#'
#' # predict on new data
#' newdata <- mtcars %>%
#' head(6) %>%
#' mutate(wt = wt + 1)
#' augment(mod, newdata = newdata)
#'
#' # ggplot2 example where we also construct 95% prediction interval
#' mod2 <- lm(mpg ~ wt, data = mtcars) ## simpler bivariate model since we're plotting in 2D
#'
#' au <- augment(mod2, newdata = newdata, interval = "prediction")
#'
#' ggplot(au, aes(wt, mpg)) +
#' geom_point() +
#' geom_line(aes(y = .fitted)) +
#' geom_ribbon(aes(ymin = .conf.low, ymax = .conf.high), col = NA, alpha = 0.3)
#'
#' # predict on new data without outcome variable. Output does not include .resid
#' newdata <- newdata %>%
#' select(-mpg)
Expand Down Expand Up @@ -103,9 +109,12 @@ tidy.lm <- function(x, conf.int = FALSE, conf.level = 0.95, ...) {
#' @template param_data
#' @template param_newdata
#' @template param_se_fit
#' @template param_interval
#'
#' @evalRd return_augment(
#' ".hat",
#' ".conf.low",
#' ".conf.high",
#' ".sigma",
#' ".cooksd",
#' ".se.fit",
Expand All @@ -122,8 +131,9 @@ tidy.lm <- function(x, conf.int = FALSE, conf.level = 0.95, ...) {
#' @seealso [augment()], [stats::predict.lm()]
#' @family lm tidiers
augment.lm <- function(x, data = model.frame(x), newdata = NULL,
se_fit = FALSE, ...) {
df <- augment_newdata(x, data, newdata, se_fit)
se_fit = FALSE, interval = c("none", "confidence", "prediction"), ...) {
interval <- match.arg(interval)
df <- augment_newdata(x, data, newdata, se_fit, interval)

if (is.null(newdata)) {
tryCatch({
Expand Down
42 changes: 34 additions & 8 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ add_hat_sigma_cols <- function(df, x, infl) {
# deal with rownames and convert to tibble as necessary
# add .se.fit column if present
# be *incredibly* careful that the ... are passed correctly
augment_newdata <- function(x, data, newdata, .se_fit, ...) {
augment_newdata <- function(x, data, newdata, .se_fit, interval, ...) {
passed_newdata <- !is.null(newdata)
df <- if (passed_newdata) newdata else data
df <- as_augment_tibble(df)
Expand All @@ -384,20 +384,46 @@ augment_newdata <- function(x, data, newdata, .se_fit, ...) {
# an na.pass argument

if (.se_fit) {
pred_obj <- predict(x, newdata = newdata, na.action = na.pass, se.fit = TRUE, ...)
df$.fitted <- pred_obj$fit %>% unname()

pred_obj <- predict(x, newdata = newdata, na.action = na.pass, se.fit = .se_fit, interval = interval, ...)
if (interval=="none") {
df$.fitted <- pred_obj$fit %>% unname()
} else {
df$.fitted <- pred_obj$fit[, "fit"]
df$.conf.low <- pred_obj$fit[, "lwr"]
df$.conf.high <- pred_obj$fit[, "upr"]
}

# a couple possible names for the standard error element of the list
# se.fit: lm, glm
# se: loess
se_idx <- which(names(pred_obj) %in% c("se.fit", "se"))
df$.se.fit <- pred_obj[[se_idx]]

} else if (interval!="none") {
pred_obj <- predict(x, newdata = newdata, na.action = na.pass, se.fit = FALSE, interval = interval, ...)
df$.fitted <- pred_obj[, "fit"]
df$.conf.low <- pred_obj[, "lwr"]
df$.conf.high <- pred_obj[, "upr"]
} else if (passed_newdata) {
df$.fitted <- predict(x, newdata = newdata, na.action = na.pass, ...) %>%
unname()
if (interval=="none") {
df$.fitted <- predict(x, newdata = newdata, na.action = na.pass, ...) %>%
unname()
} else {
pred_obj <- predict(x, newdata = newdata, na.action = na.pass, interval = interval, ...)
df$.fitted <- pred_obj$fit[, "fit"]
df$.conf.low <- pred_obj$fit[, "lwr"]
df$.conf.high <- pred_obj$fit[, "upr"]
}
} else {
df$.fitted <- predict(x, na.action = na.pass, ...) %>%
unname()
if (interval=="none") {
df$.fitted <- predict(x, na.action = na.pass, ...) %>%
unname()
} else {
pred_obj <- predict(x, newdata = newdata, na.action = na.pass, interval = interval, ...)
df$.fitted <- pred_obj$fit[, "fit"]
df$.conf.low <- pred_obj$fit[, "lwr"]
df$.conf.high <- pred_obj$fit[, "upr"]
}
}

# If response variable is not included in newdata, remove response variable
Expand Down
4 changes: 4 additions & 0 deletions man-roxygen/param_interval.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#' @param interval Character indicating the type of confidence interval columns
#' to be added to the augmented output. Passed on to `predict()` and defaults
#' to "none".
#' @md

0 comments on commit aed92af

Please sign in to comment.