Summarizing HTE Outputs in a Multi-Arm Experiment #1479

Open
shafayetShafee opened this issue Jan 12, 2025 · 1 comment


shafayetShafee commented Jan 12, 2025

Hello, I need some guidance/direction/suggestions on how I can use the HTE estimates from multi_arm_causal_forest to create insightful summaries. After going through this paper, I can think of some approaches. But I am a bit confused, since these resources discuss only binary treatments, whereas my use case involves a multi-arm treatment.

Let's consider a reproducible example to discuss the approaches:

Setup

library(grf)
library(dplyr)
library(ggplot2)

set.seed(1344)

Helper functions

predict_effect_and_ci <- function(multi_arm_causal_forest_model, newdata = NULL) {

  if (!inherits(multi_arm_causal_forest_model, "multi_arm_causal_forest")) {
    stop('This function only supports model objects of class "multi_arm_causal_forest".')
  }

  tau_hat <- predict(
    multi_arm_causal_forest_model,
    newdata = newdata,
    estimate.variance = TRUE,
    drop = TRUE
  )

  effect_estimate_df <- as.data.frame(tau_hat$predictions)
  contrasts_name <- colnames(effect_estimate_df)
  contrast_generic_name <- paste0("contrast_", seq(1, length(contrasts_name)))
  contrast_info <- setNames(contrasts_name, contrast_generic_name)
  colnames(effect_estimate_df) <- paste0(contrast_generic_name, "_estimate")

  effect_estimate_var_df <- as.data.frame(tau_hat$variance.estimates)
  colnames(effect_estimate_var_df) <- paste0(contrast_generic_name, "_var")

  effect_est_df <- bind_cols(effect_estimate_df, effect_estimate_var_df)

  return(list(
    contrast_info = contrast_info,
    data = effect_est_df
  ))
}

get_top_n_vars <- function(forest, X, n = 3) {
  varimp <- grf::variable_importance(forest)
  ranked_variables <- order(varimp, decreasing = TRUE)
  top_varnames <- colnames(X)[ranked_variables[1:n]]
  return(top_varnames)
}
n <- 3000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

exp_df <- data.frame(Y = Y, W = W, X)

Splitting Data into Train-Test

train = sample(nrow(X), 0.6 * nrow(X))
test = -train

Fit Forest Model on Training Set

mc.forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train], seed = 1344)

Predict HTEs on Test Set

tau_hat_est <- predict_effect_and_ci(mc.forest, newdata = X[test, ])
# Attach test-set data and add normal-approximation 95% CIs for each contrast
tau_hat_est_df <- bind_cols(tau_hat_est$data, exp_df[test, ]) %>% 
  mutate(
    c1_ci_low = contrast_1_estimate - 1.96 * sqrt(contrast_1_var),
    c1_ci_high = contrast_1_estimate + 1.96 * sqrt(contrast_1_var),
    c2_ci_low = contrast_2_estimate - 1.96 * sqrt(contrast_2_var),
    c2_ci_high = contrast_2_estimate + 1.96 * sqrt(contrast_2_var)
  )

head(tau_hat_est_df, 3)
  contrast_1_estimate contrast_2_estimate contrast_1_var contrast_2_var
1         -0.06310611          -0.1661815     0.01636501     0.01322940
2         -0.56899703           0.9466801     0.02781945     0.05492243
3         -0.49811420           0.9974250     0.02565718     0.06470612
           Y W          X1         X2          X3         X4          X5
1  0.9849406 A  0.54756844 -0.1014569  0.21716754 -2.0556520 -0.04809347
2 -2.1574977 A  0.08431498 -0.6160837 -0.46033781 -0.1537932  0.08784540
3  1.0177696 A -0.50059754 -0.6376908 -0.08594392  0.4529726 -1.98854317
          X6          X7         X8         X9        X10  c1_ci_low c1_ci_high
1  1.8194223 -0.04598789 -0.3885001 0.45111597 -1.9751646 -0.3138407  0.1876284
2 -0.8248944 -1.42140442 -0.8348958 0.06918902  0.8410156 -0.8959086 -0.2420854
3 -0.5929628  0.08853166  0.1790741 0.92633845  0.8261464 -0.8120642 -0.1841642
   c2_ci_low c2_ci_high
1 -0.3916190 0.05925604
2  0.4873436 1.40601655
3  0.4988520 1.49599794
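The c1_* and c2_* interval columns above are written out by hand for each contrast. A small base-R helper could instead add normal-approximation intervals for every contrast_<k>_estimate / contrast_<k>_var pair. This is only a sketch: add_normal_ci is a hypothetical name (not a grf function), and it assumes the column naming produced by predict_effect_and_ci.

```r
# Hypothetical helper (not part of grf): add normal-approximation confidence
# intervals for every "contrast_<k>_estimate" / "contrast_<k>_var" column pair.
add_normal_ci <- function(df, level = 0.95) {
  z <- qnorm(1 - (1 - level) / 2)  # about 1.96 for level = 0.95
  est_cols <- grep("^contrast_[0-9]+_estimate$", names(df), value = TRUE)
  for (est_col in est_cols) {
    k <- sub("^contrast_([0-9]+)_estimate$", "\\1", est_col)
    se <- sqrt(df[[paste0("contrast_", k, "_var")]])
    df[[paste0("c", k, "_ci_low")]] <- df[[est_col]] - z * se
    df[[paste0("c", k, "_ci_high")]] <- df[[est_col]] + z * se
  }
  df
}

# Usage with toy numbers shaped like tau_hat_est$data
toy <- data.frame(
  contrast_1_estimate = c(0.5, -0.2), contrast_2_estimate = c(1.0, 0.0),
  contrast_1_var = c(0.04, 0.01),     contrast_2_var = c(0.25, 0.09)
)
toy_ci <- add_normal_ci(toy)
print(toy_ci)
```

Applied to the objects above, add_normal_ci(bind_cols(tau_hat_est$data, exp_df[test, ])) should produce the same column names, using qnorm(0.975) rather than the rounded 1.96, and it keeps working unchanged if more arms are added.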

Creating HTE Quartile Groups

The tau_hat_est_df data frame contains two HTE estimates: $\hat{\tau}_{b-a}$, comparing treatment “B” with “A”, and $\hat{\tau}_{c-a}$, comparing treatment “C” with “A”. We first create quartile groups based on $\hat{\tau}_{b-a}$.

num.groups = 4

quartile = cut(
  tau_hat_est_df$contrast_1_estimate,
  quantile(tau_hat_est_df$contrast_1_estimate, seq(0, 1, by = 1 / num.groups)),
  labels = 1:num.groups,
  include.lowest = TRUE
)

samples.by.quartile = split(seq_along(quartile), quartile)

eval.forest = multi_arm_causal_forest(X[test, ], Y[test], W[test], seed = 1345)

ate.by.quartile = lapply(samples.by.quartile, function(samples) {
  average_treatment_effect(eval.forest, subset = samples)
})

df.plot.ate = bind_rows(ate.by.quartile, .id = "group") %>% 
  mutate(
    group = paste0("Q", group)
  ) %>% 
  select(group, contrast, estimate, std.err)
  
rownames(df.plot.ate) <- NULL

head(df.plot.ate, 10)
  group contrast   estimate   std.err
1    Q1    B - A -1.1571475 0.1582629
2    Q1    C - A  2.0199825 0.1652687
3    Q2    B - A -0.3894375 0.1472359
4    Q2    C - A  0.5131891 0.1584275
5    Q3    B - A  0.3887451 0.1407779
6    Q3    C - A -0.4980314 0.1409159
7    Q4    B - A  1.2435839 0.1512048
8    Q4    C - A -1.9012597 0.1586695
tau_BA_ate <- df.plot.ate %>% 
  filter(contrast == "B - A")

tau_BA_ate %>% 
ggplot(aes(x = group, y = estimate)) +
  geom_hline(yintercept = 0, linetype = 2, linewidth = 0.5) +
  geom_errorbar(
    aes(
      ymin = estimate - 1.96 * std.err, 
      ymax = estimate + 1.96 * std.err
    ),
    width = 0.09, color = "#4E79A7", linewidth = 0.7
  ) +
  geom_point(color = "#E15759", size = 3) +
  xlab("Estimated CATE Quartile") +
  ylab("Average treatment effect") + 
  theme_minimal() +
  theme(
    plot.title = element_text(size = 12, face = "bold", lineheight = 1.1),
    axis.text = element_text(size = 11),
    axis.title.x = element_text(margin = margin(t = 10))
  ) 

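[Figure: average treatment effect for the “B - A” contrast, with 95% confidence intervals, by estimated CATE quartile (Q1 to Q4)]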

Note that, since I have created the quartile groups based on $\hat{\tau}_{b-a}$, I have only used (and plotted) the ATE estimates and their SEs for the “B - A” contrast, ignoring the values for the “C - A” contrast. Likewise, when $\hat{\tau}_{c-a}$ is used to create the quartile groups, only the ATE estimates for the “C - A” contrast would be shown. At least, that is what I am thinking. So my question is: am I on the right track? Are there any better ways?
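One way to handle both contrasts without duplicating code is to loop over them: rank test units by each contrast's own CATE prediction, then keep only that contrast's ATE within each quartile. The sketch below assumes, as in the output above, that average_treatment_effect on a multi_arm_causal_forest returns a data frame with contrast, estimate, and std.err columns, and that the contrast names match the column names of predict(..., drop = TRUE)$predictions. ate_by_cate_quantile is a hypothetical helper name, and the smaller simulation is only for illustration.

```r
library(grf)

# Hypothetical helper: for each contrast, bin units by quantiles of that
# contrast's CATE prediction, then estimate the AIPW ATE of the same contrast
# within each bin using a forest fit on the evaluation data.
ate_by_cate_quantile <- function(tau_hat, eval_forest, num_groups = 4) {
  out <- lapply(colnames(tau_hat), function(ctr) {
    breaks <- quantile(tau_hat[, ctr], seq(0, 1, length.out = num_groups + 1))
    group <- cut(tau_hat[, ctr], breaks, labels = seq_len(num_groups),
                 include.lowest = TRUE)
    rows <- lapply(split(seq_along(group), group), function(samples) {
      ate <- average_treatment_effect(eval_forest, subset = samples)
      ate[ate$contrast == ctr, c("contrast", "estimate", "std.err")]
    })
    res <- do.call(rbind, rows)
    res$group <- paste0("Q", names(rows))
    res$ranked_by <- ctr
    res
  })
  res <- do.call(rbind, out)
  rownames(res) <- NULL
  res[, c("ranked_by", "group", "contrast", "estimate", "std.err")]
}

# Usage on data simulated as in the original post (smaller n for speed)
set.seed(1344)
n <- 2000; p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)
train <- sample(n, 0.6 * n)
test <- -train

mc_forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train], num.trees = 500)
tau_hat <- predict(mc_forest, X[test, ], drop = TRUE)$predictions
eval_forest <- multi_arm_causal_forest(X[test, ], Y[test], W[test], num.trees = 500)

ate_tbl <- ate_by_cate_quantile(tau_hat, eval_forest)
print(ate_tbl)
```

The resulting long table can be faceted by ranked_by in ggplot2, giving one quartile plot per contrast with the same code used for the “B - A” plot.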

Covariate Profiles for Quartile Groups

top_2_vars <- get_top_n_vars(
  mc.forest, 
  exp_df %>% select(starts_with("X")), 
  n = 2
)

top_2_vars
[1] "X2" "X5"
tau_hat_est_df %>% 
  mutate(
    Q = quartile,
    group = paste0("Q", Q)
  ) %>% 
  group_by(group) %>% 
  summarise(
    across(.cols = all_of(top_2_vars), .fns = mean, .names = "mean_{.col}")
  ) %>% 
  left_join(tau_BA_ate, by = "group")
# A tibble: 4 × 6
  group mean_X2 mean_X5 contrast estimate std.err
  <chr>   <dbl>   <dbl> <chr>       <dbl>   <dbl>
1 Q1     -1.28   0.124  B - A      -1.16    0.158
2 Q2     -0.301 -0.0852 B - A      -0.389   0.147
3 Q3      0.366 -0.138  B - A       0.389   0.141
4 Q4      1.30  -0.0303 B - A       1.24    0.151

Is the above summary representation valid? Are there any better ways?
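One possible refinement of the table above is to report each covariate mean together with its standard error, so that differences across quartiles can be judged against sampling noise. This is a minimal base-R sketch: covariate_profile is a hypothetical helper, and its standard error is the plain sd/sqrt(n) of the group mean, which ignores the fact that the groups were formed from estimated CATEs.

```r
# Hypothetical helper: mean and standard error of selected covariates
# within each group (e.g. estimated CATE quartiles).
covariate_profile <- function(X, group, vars) {
  rows <- lapply(vars, function(v) {
    x <- X[[v]]
    m <- tapply(x, group, mean)
    se <- tapply(x, group, function(z) sd(z) / sqrt(length(z)))
    data.frame(variable = v, group = names(m),
               mean = as.numeric(m), std.err = as.numeric(se))
  })
  do.call(rbind, rows)
}

# Toy usage: groups formed from quartiles of X2
set.seed(1)
toy_X <- data.frame(X2 = rnorm(400), X5 = rnorm(400))
toy_group <- paste0("Q", cut(toy_X$X2, quantile(toy_X$X2, 0:4 / 4),
                             labels = 1:4, include.lowest = TRUE))
prof <- covariate_profile(toy_X, toy_group, c("X2", "X5"))
print(prof)
```

With the objects from the post, the analogous call would be covariate_profile(tau_hat_est_df, paste0("Q", quartile), top_2_vars).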

Additional Questions

  1. Is it incorrect to average $\hat{\tau}_{b-a}$ within each quartile, rather than using eval.forest on each quartile group separately to get the ATE estimates?
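A minimal sketch of the two options in this question, on a simpler illustrative subgroup (X2 > 0) rather than a CATE quartile: (a) a plug-in average of out-of-bag CATE predictions, and (b) average_treatment_effect with subset, which averages doubly robust (AIPW) scores and also returns a standard error. The subgroup, tree count, and sample size are assumptions made for illustration only.

```r
library(grf)

set.seed(1344)
n <- 2000; p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

forest <- multi_arm_causal_forest(X, Y, W, num.trees = 500)
samples <- which(X[, 2] > 0)  # illustrative subgroup

# (a) Plug-in: average the out-of-bag CATE predictions within the subgroup.
#     This gives a point estimate but no valid standard error.
tau_oob <- predict(forest, drop = TRUE)$predictions
plugin <- colMeans(tau_oob[samples, , drop = FALSE])

# (b) AIPW: average_treatment_effect averages doubly robust scores over the
#     subgroup and reports a standard error for each contrast.
aipw <- average_treatment_effect(forest, subset = samples)

comparison <- data.frame(
  contrast = names(plugin),
  plugin   = unname(plugin),
  aipw     = aipw$estimate[match(names(plugin), aipw$contrast)],
  aipw_se  = aipw$std.err[match(names(plugin), aipw$contrast)]
)
print(comparison)
```

The point estimates are typically close, but only (b) comes with a standard error that supports the confidence intervals drawn in the quartile plot.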
Member

erikcs commented Jan 19, 2025

Hi @shafayetShafee,

> After going through this paper, I can think of some approaches. But I am a bit confused, since these resources discuss only binary treatments, whereas my use case involves a multi-arm treatment.

The approaches described in that paper apply to any given treatment, and if you have several arms available you could perform the same kind of analysis separately for each arm. In principle, you could fit a causal forest for each treatment arm and do this kind of analysis. You can think of the multi-arm causal forest as a way to estimate all these CATEs jointly instead of separately, which may increase estimation power if some HTE signal is shared across arms.
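A sketch of the per-arm alternative described here: one binary causal_forest per treated arm, each fit only on units in that arm and in the control arm “A”. The helper fit_arm_vs_control is hypothetical, and the simulation mirrors the original post with fewer covariates for speed.

```r
library(grf)

set.seed(1344)
n <- 3000; p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

# Hypothetical helper: binary causal forest for `arm` vs. `control`,
# using only the units assigned to one of those two arms.
fit_arm_vs_control <- function(X, Y, W, arm, control = "A", ...) {
  keep <- W %in% c(arm, control)
  causal_forest(X[keep, ], Y[keep], as.numeric(W[keep] == arm), ...)
}

cf_BA <- fit_arm_vs_control(X, Y, W, "B", num.trees = 500)
cf_CA <- fit_arm_vs_control(X, Y, W, "C", num.trees = 500)

# Each binary forest then goes through the usual single-treatment workflow,
# e.g. an overall AIPW ATE (a named vector: estimate, std.err).
ate_BA <- average_treatment_effect(cf_BA)
ate_CA <- average_treatment_effect(cf_CA)
print(rbind(B_vs_A = ate_BA, C_vs_A = ate_CA))
```

Each forest here sees only about two thirds of the data, whereas multi_arm_causal_forest uses all units when estimating both contrasts jointly.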

For an analysis tailored to multiple arms where deploying each treatment has a different cost, this paper/package https://github.com/grf-labs/maq extends the Qini curve (Figure 5) to that setting. It essentially translates predictions from multiple treatments into a treatment allocation policy that satisfies a budget constraint, and then plots the value of that policy.
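A sketch of how the maq package might be combined with the forests above, loosely following the pattern in the maq README: CATE predictions from a training forest, doubly robust scores from an evaluation forest via grf::get_scores, and a cost matrix. The per-arm costs (1 for “B”, 2 for “C”) and the budget of 2 per unit are made-up assumptions for illustration, and the exact maq arguments should be checked against the installed version's documentation.

```r
library(grf)
library(maq)

set.seed(1344)
n <- 3000; p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)
train <- sample(n, n / 2)
test <- -train

# CATE predictions (one column per contrast vs. "A") on held-out data
tau_forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train])
tau_hat <- predict(tau_forest, X[test, ], drop = TRUE)$predictions

# Doubly robust scores from a separate evaluation forest on the held-out data
eval_forest <- multi_arm_causal_forest(X[test, ], Y[test], W[test])
DR_scores <- get_scores(eval_forest, drop = TRUE)

# Illustrative (made-up) costs: arm "B" costs 1 and arm "C" costs 2 per unit
cost <- matrix(c(1, 2), nrow = nrow(tau_hat), ncol = 2, byrow = TRUE)

# Multi-armed Qini curve up to an assumed budget of 2 per unit,
# with 200 bootstrap replicates for standard errors
ma_qini <- maq(tau_hat, cost, DR_scores, budget = 2, R = 200)
plot(ma_qini)
```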

1: You should use the average_treatment_effect function to compute an ATE.
