Creating Causal Identification module #1166
Changes from 51 commits
The optional dependencies are added to the environment file:

```diff
@@ -35,6 +35,8 @@ dependencies:
   - sphinx-design
   - watermark
   - typing
+  - networkx
+  - dowhy
   # lint
   - mypy
   - pandas-stubs
```
A new module (`@@ -0,0 +1,158 @@`), imported below as `pymc_marketing.mmm.causal`:

```python
# Copyright 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal identification class."""

import warnings

import pandas as pd

try:
    from dowhy import CausalModel
except ImportError:

    class LazyCausalModel:
        """Lazy import of dowhy's CausalModel."""

        def __init__(self, *args, **kwargs):
            msg = (
                "To use Causal Graph functionality, please install the optional "
                "dependencies with: pip install pymc-marketing[dag]"
            )
            raise ImportError(msg)

    CausalModel = LazyCausalModel
```
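The try/except fallback used here is a common optional-dependency pattern: the `ImportError` is deferred until the class is actually instantiated, so importing the package never fails. A minimal stand-alone sketch of the same pattern (the `heavylib`/`HeavyModel` names and the install hint are hypothetical):

```python
# Optional-dependency fallback: defer the ImportError until first use.
try:
    from heavylib import HeavyModel  # hypothetical optional dependency
except ImportError:

    class _LazyHeavyModel:
        """Placeholder that fails loudly only when instantiated."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "HeavyModel requires an optional dependency; "
                "install it with: pip install mypkg[extra]"
            )

    HeavyModel = _LazyHeavyModel
```

This keeps module import cheap for users who never touch the optional feature, while giving an actionable error message to those who do.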
```python
class CausalGraphModel:
    """Represent a causal model based on a Directed Acyclic Graph (DAG).

    Provides methods to analyze causal relationships and determine the minimal
    adjustment set for backdoor adjustment between treatment and outcome variables.

    Parameters
    ----------
    causal_model : CausalModel
        An instance of dowhy's CausalModel, representing the causal graph and its
        relationships.
    treatment : list[str]
        A list of treatment variable names.
    outcome : str
        The outcome variable name.

    References
    ----------
    .. [1] https://github.com/microsoft/dowhy
    """

    def __init__(
        self, causal_model: CausalModel, treatment: list[str] | tuple[str], outcome: str
    ) -> None:
        self.causal_model = causal_model
        self.treatment = treatment
        self.outcome = outcome

    @classmethod
    def build_graphical_model(
        cls, graph: str, treatment: list[str] | tuple[str], outcome: str
    ) -> "CausalGraphModel":
        """Create a CausalGraphModel from a string representation of a graph.

        Parameters
        ----------
        graph : str
            A string representation of the graph (e.g., a string in DOT format).
        treatment : list[str]
            A list of treatment variable names.
        outcome : str
            The outcome variable name.

        Returns
        -------
        CausalGraphModel
            An instance of CausalGraphModel constructed from the given graph string.
        """
        causal_model = CausalModel(
            data=pd.DataFrame(), graph=graph, treatment=treatment, outcome=outcome
        )
        return cls(causal_model, treatment, outcome)
```
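`build_graphical_model` takes the graph as a DOT-format string. To illustrate what such a string looks like, here is a hedged sketch that pulls the edge list out of a simple DOT digraph with a regular expression (the variable names are made up, and the regex assumes plain `a -> b;` edges with no attributes or subgraphs — real DOT parsing is handled by dowhy/networkx):

```python
import re

# A DAG in DOT format, as it would be passed to build_graphical_model.
dag = """
digraph {
    x1 -> y;
    x2 -> y;
    holiday -> x1;
    holiday -> y;
}
"""

def dot_edges(graph: str) -> list[tuple[str, str]]:
    """Extract (parent, child) pairs from a minimal DOT digraph string."""
    return re.findall(r"(\w+)\s*->\s*(\w+)", graph)

edges = dot_edges(dag)
```

In this toy DAG, `holiday` is a confounder: it affects both the channel `x1` and the outcome `y`, which is exactly the situation the backdoor machinery below is designed to detect.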
```python
    def get_backdoor_paths(self) -> list[list[str]]:
        """Find all backdoor paths between the combined treatment and outcome variables.

        Returns
        -------
        list[list[str]]
            A list of backdoor paths, where each path is represented as a list of
            variable names.

        References
        ----------
        .. [1] Pearl, J., Glymour, M., & Jewell, N. P. (2016). Causal Inference in
           Statistics: A Primer.
        """
        # Use DoWhy's internal method to get backdoor paths for all treatments combined
        return self.causal_model._graph.get_backdoor_paths(
            nodes1=self.treatment, nodes2=[self.outcome]
        )

    def get_unique_adjustment_nodes(self) -> list[str]:
        """Compute the minimal adjustment set required for backdoor adjustment across all treatments.

        Returns
        -------
        list[str]
            A list of unique adjustment variables needed to block all backdoor paths.
        """
        paths = self.get_backdoor_paths()
        # Flatten paths and exclude treatments and the outcome from the adjustment set
        adjustment_nodes = {
            node
            for path in paths
            for node in path
            if node not in self.treatment and node != self.outcome
        }
        return list(adjustment_nodes)
```
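`get_unique_adjustment_nodes` is a flatten-and-filter over the backdoor paths returned by dowhy. The core set logic can be reproduced stand-alone (the paths below are illustrative, not actual dowhy output):

```python
# Backdoor paths as dowhy would return them: each path starts at a
# treatment, ends at the outcome, and traverses possible confounders.
paths = [
    ["x1", "holiday", "y"],
    ["x2", "competitor", "market", "y"],
]
treatment = ["x1", "x2"]
outcome = "y"

# Flatten all paths and drop treatments and the outcome: what remains
# is the set of candidate adjustment variables.
adjustment_nodes = {
    node
    for path in paths
    for node in path
    if node not in treatment and node != outcome
}
```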
```python
    def compute_adjustment_sets(
        self,
        channel_columns: list[str] | tuple[str],
        control_columns: list[str] | None = None,
    ) -> list[str] | None:
        """Compute minimal adjustment sets and handle warnings."""
        channel_columns = list(channel_columns)
        if control_columns is None:
            return control_columns

        self.adjustment_set = self.get_unique_adjustment_nodes()

        common_controls = set(control_columns).intersection(self.adjustment_set)
        unique_controls = set(control_columns) - set(self.adjustment_set)

        if unique_controls:
            warnings.warn(
                f"Columns {unique_controls} are not in the adjustment set. "
                "Controls are being modified.",
                stacklevel=2,
            )

        control_columns = list(common_controls - set(channel_columns))

        self.minimal_adjustment_set = control_columns + list(channel_columns)

        for column in self.adjustment_set:
            if column not in control_columns and column not in channel_columns:
                warnings.warn(
                    f"Column {column} in adjustment set not found in data. "
                    "Not controlling for this may induce bias in treatment effect "
                    "estimates.",
                    stacklevel=2,
                )

        return control_columns
```

Review comments on lines +140 to +148:

> I am hesitant on this step because of my comment on variance reduction above. Maybe we can have an additional parameter, something like …

> Regarding variance reduction, I replied. On the other hand, we could indeed implement something like that; it should not be complicated, and I'm happy to work on it. Do we want this for the initial release? It sounds like something we could add later. I think it is well understood that "minimal" means the minimal adjustment set: only what is necessary. "Maximal" could come in a following PR, allowing the largest possible set of variables that can be adjusted for without introducing bias.

> Thanks @carlosagostini. I think we can provide a minimal set as you suggest, but add a user warning that this feature is experimental and that the minimal set does not always lead to the best model, as we could be missing opportunities to reduce the variance of the estimate. We can release this initial experimental feature and collect feedback from the users, WDYT?
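The control-column reconciliation in `compute_adjustment_sets` boils down to set intersections and differences plus warnings. A stand-alone sketch of that logic with illustrative column names:

```python
import warnings

adjustment_set = ["holiday", "price"]
control_columns = ["holiday", "weather"]
channel_columns = ["x1", "x2"]

# Controls outside the adjustment set are dropped (with a warning).
common_controls = set(control_columns).intersection(adjustment_set)
unique_controls = set(control_columns) - set(adjustment_set)
if unique_controls:
    warnings.warn(
        f"Columns {unique_controls} are not in the adjustment set.", stacklevel=2
    )

# Keep only controls that are in the adjustment set and are not channels.
kept = list(common_controls - set(channel_columns))

# Adjustment variables missing from both controls and channels also warrant
# a warning: omitting them may bias the treatment effect estimates.
missing = [c for c in adjustment_set if c not in kept and c not in channel_columns]
```

Here `weather` is dropped (not needed to block any backdoor path), `holiday` is kept, and `price` is flagged as a confounder absent from the data.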
The MMM class is wired up to the new module:

```diff
@@ -33,6 +33,7 @@
 from pymc_marketing.hsgp_kwargs import HSGPKwargs
 from pymc_marketing.mmm.base import BaseValidateMMM
+from pymc_marketing.mmm.causal import CausalGraphModel
 from pymc_marketing.mmm.components.adstock import (
     AdstockTransformation,
     adstock_from_dict,
```
```diff
@@ -115,6 +116,17 @@ def __init__(
         adstock_first: bool = Field(
             True, description="Whether to apply adstock first."
         ),
+        dag: str | None = Field(
+            None,
+            description="Optional DAG provided as a string in DOT format for causal identification.",
+        ),
+        treatment_nodes: list[str] | tuple[str] | None = Field(
+            None,
+            description="Column names of the variables of interest to identify causal effects on outcome.",
+        ),
+        outcome_node: str | None = Field(
+            None, description="Name of the outcome variable."
+        ),
     ) -> None:
         """Define the constructor method.
```

Review comments on lines +119 to +129:

> This is where I would like to discuss the API. Our MMM class is already a huge monolith of many components, and I would like us to start modularizing more, or even making this a subclass. For instance, we can keep … Thoughts @wd60622?

> I like this idea. I can work quickly on it; I will wait for William's comments 🙌🏻 I will probably have a meeting with him on Tuesday!

> If the dependencies are the issue, then I think we can get away with requiring dowhy and networkx only when the dag is specified. That would keep backward compatibility: existing models would not need the new dependencies. Would checking for these dependencies only when this functionality is used solve your concerns? @juanitorduz I think going the route of subclassing could just add more code to manage 😢

> True ... what would be your suggestion @wd60622?

> I think less code to manage is better, and users still import the same MMM class. The amount of new code is only about 20 lines; I don't see it as something crazy. What's your opinion?
```diff
@@ -151,6 +163,12 @@ def __init__(
             Number of Fourier modes to model yearly seasonality, by default None.
         adstock_first : bool, optional
             Whether to apply adstock first, by default True.
+        dag : Optional[str], optional
+            Optional DAG provided as a string in DOT format for causal modeling, by default None.
+        treatment_nodes : Optional[list[str]], optional
+            Column names of the variables of interest to identify causal effects on outcome.
+        outcome_node : Optional[str], optional
+            Name of the outcome variable, by default None.
         """
         self.control_columns = control_columns
         self.time_varying_intercept = time_varying_intercept
```
```diff
@@ -180,6 +198,37 @@ def __init__(
         )

         self.yearly_seasonality = yearly_seasonality

+        self.dag = dag
+        self.treatment_nodes = treatment_nodes
+        self.outcome_node = outcome_node
+
+        # Initialize causal graph if provided
+        if self.dag is not None and self.outcome_node is not None:
+            if self.treatment_nodes is None:
+                self.treatment_nodes = self.channel_columns
+                warnings.warn(
+                    "No treatment nodes provided, using channel columns as treatment nodes.",
+                    stacklevel=2,
+                )
+            self.causal_graphical_model = CausalGraphModel.build_graphical_model(
+                graph=self.dag,
+                treatment=self.treatment_nodes,
+                outcome=self.outcome_node,
+            )
+
+            self.control_columns = self.causal_graphical_model.compute_adjustment_sets(
+                control_columns=self.control_columns,
+                channel_columns=self.channel_columns,
+            )
+
+            if "yearly_seasonality" not in self.causal_graphical_model.adjustment_set:
+                warnings.warn(
+                    "Yearly seasonality excluded as it's not required for adjustment.",
+                    stacklevel=2,
+                )
+                self.yearly_seasonality = None

         if self.yearly_seasonality is not None:
             self.yearly_fourier = YearlyFourier(
                 n_order=self.yearly_seasonality,
```
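The constructor warns rather than errors when it falls back to the channel columns as treatment nodes. That fallback-and-warn behavior can be sketched and verified with the standard `warnings` machinery (the helper function is illustrative, not the PR's code):

```python
import warnings

def resolve_treatment_nodes(treatment_nodes, channel_columns):
    """Default to channel columns when no treatment nodes are given."""
    if treatment_nodes is None:
        warnings.warn(
            "No treatment nodes provided, using channel columns as treatment nodes.",
            stacklevel=2,
        )
        return list(channel_columns)
    return list(treatment_nodes)

# catch_warnings(record=True) collects emitted warnings instead of printing them.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    nodes = resolve_treatment_nodes(None, ["x1", "x2"])
```

`stacklevel=2` attributes the warning to the caller's line rather than to the helper itself, which is the convention the PR follows throughout.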
```diff
@@ -305,6 +354,9 @@ def create_idata_attrs(self) -> dict[str, str]:
         attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
         attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
         attrs["time_varying_media"] = json.dumps(self.time_varying_media)
+        attrs["dag"] = json.dumps(self.dag)
+        attrs["treatment_nodes"] = json.dumps(self.treatment_nodes)
+        attrs["outcome_node"] = json.dumps(self.outcome_node)

         return attrs
```
```diff
@@ -680,6 +732,9 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
             "time_varying_media": json.loads(attrs.get("time_varying_media", "false")),
             "validate_data": json.loads(attrs["validate_data"]),
             "sampler_config": json.loads(attrs["sampler_config"]),
+            "dag": json.loads(attrs.get("dag", "null")),
+            "treatment_nodes": json.loads(attrs.get("treatment_nodes", "null")),
+            "outcome_node": json.loads(attrs.get("outcome_node", "null")),
         }

     def _data_setter(
```
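The attrs round-trip relies on `json.dumps(None)` producing the JSON literal `"null"`, so unset causal fields survive save/load, and on `.get(..., "null")` to stay backward compatible with idata saved before this PR (where the keys do not exist at all). In miniature:

```python
import json

# Serialize: None becomes the JSON literal "null".
attrs = {"dag": json.dumps(None), "treatment_nodes": json.dumps(["x1"])}

# Deserialize: missing keys (older saved models) default to "null" -> None.
dag = json.loads(attrs.get("dag", "null"))
outcome_node = json.loads(attrs.get("outcome_node", "null"))
treatment_nodes = json.loads(attrs.get("treatment_nodes", "null"))
```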
Review comments:

> Sometimes, external regressors are not in the minimal set but still help decrease variance; see https://matheusfacure.github.io/python-causality-handbook/07-Beyond-Confounders.html#good-controls
>
> Concretely: …
>
> So I am hesitant to remove, for example, seasonality if it is not in the minimal set. WDYT?

> Based on my current understanding, this is correct under the assumption that the variable is conditionally independent of the treatment given the adjustment set and does not introduce bias. If that is the case, then including strong predictors of the outcome can reduce the residual variance in the outcome, leading to more precise estimates of the treatment effect on other nodes.
>
> The example shown with email and credit limit works because the other nodes do not create a problem, but often the DAG is more complicated, and if you do not remove such a variable you introduce bias. I think assuming that whoever uses the module will always face an easy situation where they can leave everything in would not be optimal.
>
> As you say, it reduces variance; whether or not seasonality is included, if it is not part of the minimal adjustment set the mean estimate for the nodes of interest would not change. However, keeping it on the assumption that it is independent is a strong assumption that could affect the estimation of the node of interest if that is not the case for a specific user. @juanitorduz
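The variance-reduction point in this thread can be checked with a small simulation: a covariate `x` that predicts the outcome but is independent of the treatment `t` does not change the expected treatment estimate, but including it shrinks the estimator's sampling variance. A stdlib-only sketch (all coefficients and sample sizes are made up for illustration):

```python
import random
import statistics

random.seed(0)

def estimate(n=100, include_control=True):
    """One simulated dataset; return the OLS estimate of the effect of t on y."""
    t = [random.gauss(0, 1) for _ in range(n)]
    x = [random.gauss(0, 1) for _ in range(n)]  # good control: predicts y, independent of t
    y = [2.0 * ti + 5.0 * xi + random.gauss(0, 1) for ti, xi in zip(t, x)]
    if not include_control:
        # Univariate OLS of y on t (no intercept; all variables are zero-mean).
        return sum(ti * yi for ti, yi in zip(t, y)) / sum(ti * ti for ti in t)
    # Two-regressor OLS via the 2x2 normal equations.
    stt = sum(ti * ti for ti in t)
    sxx = sum(xi * xi for xi in x)
    stx = sum(ti * xi for ti, xi in zip(t, x))
    sty = sum(ti * yi for ti, yi in zip(t, y))
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    det = stt * sxx - stx * stx
    return (sxx * sty - stx * sxy) / det

with_ctrl = [estimate(include_control=True) for _ in range(300)]
without_ctrl = [estimate(include_control=False) for _ in range(300)]

# Both estimators are centered near the true effect (2.0), but the one
# that includes the good control has a much smaller sampling variance.
```

This supports both sides of the thread: the minimal set is sufficient for an unbiased mean estimate, while extra outcome predictors can still tighten it — provided they really are independent of the treatment.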