Skip to content

Commit

Permalink
add mundlak event study
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvalal committed Aug 21, 2024
1 parent 4bd5a56 commit 66bf4dd
Show file tree
Hide file tree
Showing 2 changed files with 528 additions and 864 deletions.
261 changes: 253 additions & 8 deletions duckreg/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
self._parse_formula()

def _parse_formula(self):

lhs, rhs = self.formula.split("~")
rhs_deparsed = rhs.split("|")
covars, fevars = rhs.split("|") if len(rhs_deparsed) > 1 else (rhs, None)
Expand Down Expand Up @@ -79,7 +78,6 @@ def compress_data(self):
self.df_compressed.eval(create_means, inplace=True)

def collect_data(self, data: pd.DataFrame) -> pd.DataFrame:

y = data.filter(
regex=f"mean_{'(' + '|'.join(self.outcome_vars) + ')'}", axis=1
).values
Expand Down Expand Up @@ -127,7 +125,6 @@ def fit_vcov(self):
self.vcov = n_nk * (bread @ meat @ bread)

def estimate_feols(self):

if self.fevars:
fml = f"{'+'.join([f'mean_{x}' for x in self.outcome_vars])} ~ {' + '.join(self.covars)} | {' + '.join(self.fevars)}"
else:
Expand Down Expand Up @@ -310,7 +307,6 @@ def compress_data(self):
)

def collect_data(self, data: pd.DataFrame):

rhs = (
self.covariates
+ [f"avg_{cov}_unit" for cov in self.covariates]
Expand All @@ -332,7 +328,6 @@ def collect_data(self, data: pd.DataFrame):
return y, X, n

def estimate(self):

y, X, n = self.collect_data(data=self.df_compressed)
return wls(X, y, n)

Expand Down Expand Up @@ -408,8 +403,259 @@ def bootstrap(self):


################################################################################
class DuckMundlakEventStudy(DuckReg):
def __init__(
self,
db_name: str,
table_name: str,
outcome_var: str,
treatment_col: str,
unit_col: str,
time_col: str,
cluster_col: str,
n_bootstraps: int = 100,
**kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
n_bootstraps=n_bootstraps,
**kwargs,
)
self.table_name = table_name
self.outcome_var = outcome_var
self.treatment_col = treatment_col
self.unit_col = unit_col
self.time_col = time_col
self.num_periods = None
self.cohorts = None
self.time_dummies = None
self.post_treatment_dummies = None
self.transformed_query = None
self.compression_query = None
self.cluster_col = cluster_col

def prepare_data(self):
# create_cohort_and_ever_treated_columns
self.cohort_query = f"""
ALTER TABLE {self.table_name} ADD COLUMN cohort INTEGER;
UPDATE {self.table_name} SET cohort = (
SELECT MIN({self.time_col})
FROM {self.table_name} AS p2
WHERE p2.{self.unit_col} = {self.table_name}.{self.unit_col} AND p2.{self.treatment_col} = 1
);
"""
self.conn.execute(self.cohort_query)
self.ever_treated_query = f"""
UPDATE {self.table_name} SET cohort = NULL WHERE cohort = 2147483647; -- Set to NULL if never treated
ALTER TABLE {self.table_name} ADD COLUMN ever_treated INTEGER;
UPDATE {self.table_name} SET ever_treated = CASE WHEN cohort IS NOT NULL THEN 1 ELSE 0 END;
"""
self.conn.execute(self.ever_treated_query)
# retrieve_num_periods_and_cohorts
self.num_periods = self.conn.execute(
f"SELECT MAX({self.time_col}) FROM {self.table_name}"
).fetchone()[0]
cohorts = self.conn.execute(
f"SELECT DISTINCT cohort FROM {self.table_name} WHERE cohort IS NOT NULL"
).fetchall()
self.cohorts = [row[0] for row in cohorts]
# generate_time_dummies
self.time_dummies = ",\n".join(
[
f"CASE WHEN {self.time_col} = {i} THEN 1 ELSE 0 END AS time_{i}"
for i in range(self.num_periods + 1)
]
)
# generate cohort dummies
cohort_intercepts = []
for cohort in self.cohorts:
cohort_intercepts.append(
f"CASE WHEN cohort = {cohort} THEN 1 ELSE 0 END AS cohort_{cohort}"
)
self.cohort_intercepts = ",\n".join(cohort_intercepts)

# generate_treatment_dummies
treatment_dummies = []
for cohort in self.cohorts:
for i in range(self.num_periods + 1):
treatment_dummies.append(
f"CASE WHEN cohort = {cohort} AND {self.time_col} = {i} THEN 1 ELSE 0 END AS treatment_time_{cohort}_{i}"
)
self.treatment_dummies = ",\n".join(treatment_dummies)

# create_transformed_query
self.design_matrix_query = f"""
CREATE TEMP TABLE transformed_panel_data AS
SELECT
p.{self.unit_col},
p.{self.time_col},
p.{self.treatment_col},
p.{self.outcome_var},
-- Intercept (constant term)
1 AS intercept,
-- cohort intercepts
{self.cohort_intercepts},
-- Time dummies for each period
{self.time_dummies},
-- Treated group interacted with treatment time dummies
{self.treatment_dummies}
FROM
{self.table_name} p;
"""
self.conn.execute(self.design_matrix_query)

def compress_data(self):
self.rhs = f"""
intercept,
{", ".join([f"cohort_{cohort}" for cohort in self.cohorts])},
{", ".join([f"time_{i}" for i in range(self.num_periods + 1)])},
{", ".join([f"treatment_time_{cohort}_{i}" for cohort in self.cohorts for i in range(self.num_periods + 1)])};
"""
self.compression_query = f"""
CREATE TEMP TABLE compressed_panel_data AS
SELECT
{self.rhs.replace(";", "")},
COUNT(*) AS count,
SUM({self.outcome_var}) AS sum_{self.outcome_var}
FROM
transformed_panel_data
GROUP BY
{self.rhs}
"""
self.conn.execute(self.compression_query)
self.df_compressed = self.conn.execute(
"SELECT * FROM compressed_panel_data"
).fetchdf()
self.df_compressed[f"mean_{self.outcome_var}"] = (
self.df_compressed[f"sum_{self.outcome_var}"] / self.df_compressed["count"]
)

def collect_data(self, data):
self._rhs_list = [x.strip().replace(";", "") for x in self.rhs.split(",")]
X = data[self._rhs_list].values
y = data[f"mean_{self.outcome_var}"].values
n = data["count"].values

y = y.reshape(-1, 1) if y.ndim == 1 else y
X = X.reshape(-1, 1) if X.ndim == 1 else X
return y, X, n

def estimate(self):
y, X, n = self.collect_data(data=self.df_compressed)
coef = wls(X, y, n)
res = pd.DataFrame(
{
"est": coef.squeeze(),
},
index=self._rhs_list,
)
cohort_names = [x.split("_")[1] for x in self._rhs_list if "cohort_" in x]
event_study_coefs = {}
for c in cohort_names:
offset = res.filter(regex=f"^cohort_{c}", axis=0).values
event_study_coefs[c] = (
res.filter(regex=f"treatment_time_{c}_", axis=0) + offset
)

return event_study_coefs

def bootstrap(self):
# list all clusters
total_clusters = self.conn.execute(
f"SELECT COUNT(DISTINCT {self.cluster_col}) FROM transformed_panel_data"
).fetchone()[0]
boot_coefs = {str(cohort): [] for cohort in self.cohorts}
# bootstrap loop
for _ in tqdm(range(self.n_bootstraps)):
resampled_clusters = (
self.conn.execute(
f"SELECT UNNEST(ARRAY(SELECT {self.cluster_col} FROM transformed_panel_data ORDER BY RANDOM() LIMIT {total_clusters}))"
)
.fetchdf()
.values.flatten()
.tolist()
)

self.conn.execute(
f"""
CREATE TEMP TABLE resampled_transformed_panel_data AS
SELECT * FROM transformed_panel_data
WHERE {self.cluster_col} IN ({', '.join(map(str, resampled_clusters))})
"""
)

self.conn.execute(
f"""
CREATE TEMP TABLE resampled_compressed_panel_data AS
SELECT
{self.rhs.replace(";", "")},
COUNT(*) AS count,
SUM({self.outcome_var}) AS sum_{self.outcome_var}
FROM
resampled_transformed_panel_data
GROUP BY
{self.rhs.replace(";", "")}
"""
)

df_boot = self.conn.execute(
"SELECT * FROM resampled_compressed_panel_data"
).fetchdf()
df_boot[f"mean_{self.outcome_var}"] = (
df_boot[f"sum_{self.outcome_var}"] / df_boot["count"]
)

y, X, n = self.collect_data(data=df_boot)
coef = wls(X, y, n)
res = pd.DataFrame(
{
"est": coef.squeeze(),
},
index=self._rhs_list,
)
cohort_names = [x.split("_")[1] for x in self._rhs_list if "cohort_" in x]
for c in cohort_names:
offset = res.filter(regex=f"^cohort_{c}", axis=0).values
event_study_coefs = (
res.filter(regex=f"treatment_time_{c}_", axis=0) + offset
)
boot_coefs[c].append(event_study_coefs.values.flatten())

self.conn.execute("DROP TABLE resampled_transformed_panel_data")
self.conn.execute("DROP TABLE resampled_compressed_panel_data")
# Calculate the covariance matrix for each cohort
bootstrap_cov_matrix = {
cohort: np.cov(np.array(coefs).T) for cohort, coefs in boot_coefs.items()
}
return bootstrap_cov_matrix

def estimate_feols(self):
raise NotImplementedError(
"feols solver not implemented for Mundlak event study estimator"
)

def summary(self) -> dict:
"""Summary of event study regression (overrides the parent class method)
Returns:
dict of event study coefficients and their standard errors
"""
if self.n_bootstraps > 0:
summary_tables = {}
for c in self.point_estimate.keys():
point_estimate = self.point_estimate[c]
se = np.sqrt(np.diag(self.vcov[c]))
summary_tables[c] = pd.DataFrame(
np.c_[point_estimate, se],
columns=["point_estimate", "se"],
index=point_estimate.index,
)
return summary_tables
return {"point_estimate": self.point_estimate}


################################################################################
class DuckDoubleDemeaning(DuckReg):
def __init__(
self,
Expand All @@ -429,7 +675,7 @@ def __init__(
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
**kwargs
**kwargs,
)
self.outcome_var = outcome_var
self.treatment_var = treatment_var
Expand Down Expand Up @@ -494,7 +740,6 @@ def compress_data(self):
)

def collect_data(self, data: pd.DataFrame):

X = data[f"ddot_{self.treatment_var}"].values
X = np.c_[np.ones(X.shape[0]), X]
y = data[f"mean_{self.outcome_var}"].values
Expand Down Expand Up @@ -566,4 +811,4 @@ def bootstrap(self):
return np.cov(boot_coefs.T)


######################################################################
################################################################################
Loading

0 comments on commit 66bf4dd

Please sign in to comment.