
Creating Causal Identification module #1166

Merged
merged 53 commits into main from causal_identification
Jan 21, 2025
18ccc61
Creating Causal Identification module
cetagostini Nov 4, 2024
6a08373
Pre-commit
cetagostini Nov 4, 2024
9f4af46
Merge branch 'main' into causal_identification
wd60622 Nov 6, 2024
a53b7d7
adding missing libraries
cetagostini Nov 6, 2024
c45e5c1
Merge branch 'main' into causal_identification
cetagostini Nov 6, 2024
8c26976
Merge branch 'main' into causal_identification
cetagostini Nov 12, 2024
7b09ef6
Pushing for push
cetagostini Nov 13, 2024
8d51555
Another random push
cetagostini Nov 14, 2024
171bd10
Final v1 push
cetagostini Nov 16, 2024
4f281a7
Merge branch 'main' into causal_identification
cetagostini Nov 16, 2024
a77d871
Adding pre-commit
cetagostini Nov 16, 2024
4299b95
Adding to index
cetagostini Nov 16, 2024
d5effba
Functions in the notebook
cetagostini Nov 16, 2024
a81aee6
More adjustment in notebook functions
cetagostini Nov 16, 2024
8ce8d56
Error on description
cetagostini Nov 17, 2024
e4e09a9
Merge branch 'main' into causal_identification
cetagostini Nov 21, 2024
3c4f5c7
Requested changes
cetagostini Nov 21, 2024
abd01d3
Trying to solve dependency error test.
cetagostini Nov 25, 2024
49216d8
Solving errors
cetagostini Nov 25, 2024
dcb52d6
Pydantic
cetagostini Nov 25, 2024
dab6784
Merge branch 'main' into causal_identification
cetagostini Nov 27, 2024
3fb3cc1
add support for save and load
wd60622 Nov 28, 2024
fb886b0
support for backwards compat
wd60622 Nov 28, 2024
f044ca6
Merge branch 'main' into causal_identification
cetagostini Nov 28, 2024
3856be0
A fancy commit
cetagostini Nov 28, 2024
b31a86f
Modify
cetagostini Nov 28, 2024
ca9d8fb
Merge branch 'main' into causal_identification
cetagostini Dec 2, 2024
3040b27
Merge branch 'main' into causal_identification
wd60622 Dec 5, 2024
330c64f
Merge branch 'main' into causal_identification
cetagostini Dec 16, 2024
bc6d2ba
Merge branch 'main' into causal_identification
cetagostini Dec 16, 2024
f308243
Notebook adjustments
cetagostini Dec 16, 2024
6858494
Merge branch 'causal_identification' of https://github.com/pymc-labs/…
cetagostini Dec 16, 2024
f6dc7cf
Remove model builder needs
cetagostini Dec 16, 2024
891d402
Merge branch 'main' into causal_identification
juanitorduz Dec 19, 2024
95eac12
Creating test for causal module
cetagostini Dec 24, 2024
8ca5777
Merge branch 'main' into causal_identification
cetagostini Dec 24, 2024
9f997e2
Updating notebook.
cetagostini Dec 24, 2024
2ccd75c
Merge branch 'causal_identification' of https://github.com/pymc-labs/…
cetagostini Dec 24, 2024
4b8c122
Merge branch 'main' into causal_identification
wd60622 Dec 25, 2024
f1d12f7
Merge branch 'main' into causal_identification
juanitorduz Dec 28, 2024
a70d340
Merge branch 'main' into causal_identification
cetagostini Dec 30, 2024
fb9993e
Merge branch 'main' into causal_identification
juanitorduz Jan 2, 2025
51b9bba
Merge branch 'main' into causal_identification
cetagostini Jan 17, 2025
9529fb6
Feedback from Juan WIP
cetagostini Jan 17, 2025
cfa1bb2
pre-commit
cetagostini Jan 17, 2025
57404d6
Changes requested changes
cetagostini Jan 19, 2025
3d7ccb8
Merge branch 'main' into causal_identification
cetagostini Jan 19, 2025
eecdfe7
Merge branch 'main' into causal_identification
cetagostini Jan 20, 2025
c5f2b08
Merge branch 'main' into causal_identification
juanitorduz Jan 20, 2025
2fbe334
Merge branch 'main' into causal_identification
juanitorduz Jan 20, 2025
9eba9c5
Merge branch 'main' into causal_identification
cetagostini Jan 20, 2025
d60fde5
header
juanitorduz Jan 21, 2025
db1d666
Merge branch 'main' into causal_identification
juanitorduz Jan 21, 2025
2 changes: 1 addition & 1 deletion .github/workflows/test_notebook.yml
@@ -31,6 +31,6 @@ jobs:
       run: |
         sudo apt-get install graphviz
         pip install -e .[docs]
-        pip install -e .[test]
+        pip install -e .[test,dag]
     - name: Run notebooks
       run: make run_notebooks
3 changes: 3 additions & 0 deletions docs/source/notebooks/index.md
@@ -7,9 +7,12 @@ Here you will find a collection of examples and how-to guides for using PyMC-Mar
:caption: MMMs
:maxdepth: 1

mmm/mmm_example
mmm/mmm_budget_allocation_example
mmm/mmm_allocation_assessment
mmm/mmm_budget_allocation_example
mmm/mmm_case_study
mmm/mmm_causal_identification
mmm/mmm_components
mmm/mmm_counterfactuals
mmm/mmm_evaluation
Binary file added docs/source/notebooks/mmm/causal_model.png
3,940 changes: 3,940 additions & 0 deletions docs/source/notebooks/mmm/mmm_causal_identification.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions environment.yml
@@ -35,6 +35,8 @@ dependencies:
   - sphinx-design
   - watermark
   - typing
+  - networkx
+  - dowhy
   # lint
   - mypy
   - pandas-stubs
158 changes: 158 additions & 0 deletions pymc_marketing/mmm/causal.py
@@ -0,0 +1,158 @@
# Copyright 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal identification class."""

import warnings

import pandas as pd

try:
    from dowhy import CausalModel
except ImportError:

    class LazyCausalModel:
        """Lazy import of dowhy's CausalModel."""

        def __init__(self, *args, **kwargs):
            msg = (
                "To use Causal Graph functionality, please install the optional dependencies with: "
                "pip install pymc-marketing[dag]"
            )
            raise ImportError(msg)

    CausalModel = LazyCausalModel
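The guarded import above defers the failure from import time to first use: the module imports cleanly without dowhy, and the error only surfaces when someone actually tries to build a causal graph. A standalone sketch of the same pattern, using a deliberately nonexistent package name so the fallback branch runs:

```python
# Sketch of the lazy-import fallback, with a fake package name so the
# except-branch is exercised.
try:
    from not_a_real_package import CausalModel  # hypothetical missing dependency
except ImportError:

    class LazyCausalModel:
        """Stand-in that fails loudly only when actually instantiated."""

        def __init__(self, *args, **kwargs):
            msg = (
                "To use Causal Graph functionality, please install the optional "
                "dependencies with: pip install pymc-marketing[dag]"
            )
            raise ImportError(msg)

    CausalModel = LazyCausalModel

# Importing succeeds; the error is raised only on use.
try:
    CausalModel()
except ImportError as err:
    print("deferred:", err)
```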


class CausalGraphModel:
    """Represent a causal model based on a Directed Acyclic Graph (DAG).

    Provides methods to analyze causal relationships and determine the minimal adjustment set
    for backdoor adjustment between treatment and outcome variables.
[Comment on lines +40 to +41 — Collaborator]
Sometimes, external regressors are not in the minimal set but help decrease variance; see https://matheusfacure.github.io/python-causality-handbook/07-Beyond-Confounders.html#good-controls

Concretely: anytime we have a control that is a good predictor of the outcome, even if it is not a confounder, adding it to our model is a good idea. So I am hesitant to remove, for example, seasonality if it is not in the minimal set. WDYT?

[Reply — cetagostini (Contributor, Author), Jan 17, 2025]
Based on my current understanding, this is correct under the assumption that the variable is conditionally independent of the treatment given the adjustment set and does not introduce bias. If that's the case, then including strong predictors of the outcome can reduce the residual variance in the outcome, leading to more precise estimates of the treatment effect on other nodes.

The example with email and credit limit works because the other nodes do not create a problem, but we often face situations where the DAG is more complicated, and not removing such a variable introduces bias. Assuming that whoever uses the module will always have an easy situation where they can leave everything in would not be optimal.

As you say, it reduces the variance; but whether or not seasonality is added, if it is not part of the minimal adjustment set, the mean estimates of the nodes of interest would not change. However, keeping it on the assumption that it is independent is a strong assumption that could affect the estimation of the node of interest if that is not the case for a specific user.

@juanitorduz

    Parameters
    ----------
    causal_model : CausalModel
        An instance of dowhy's CausalModel, representing the causal graph and its relationships.
    treatment : list[str]
        A list of treatment variable names.
    outcome : str
        The outcome variable name.

    References
    ----------
    .. [1] https://github.com/microsoft/dowhy
    """

    def __init__(
        self, causal_model: CausalModel, treatment: list[str] | tuple[str], outcome: str
    ) -> None:
        self.causal_model = causal_model
        self.treatment = treatment
        self.outcome = outcome

    @classmethod
    def build_graphical_model(
        cls, graph: str, treatment: list[str] | tuple[str], outcome: str
    ) -> "CausalGraphModel":
        """Create a CausalGraphModel from a string representation of a graph.

        Parameters
        ----------
        graph : str
            A string representation of the graph (e.g., a string in DOT format).
        treatment : list[str]
            A list of treatment variable names.
        outcome : str
            The outcome variable name.

        Returns
        -------
        CausalGraphModel
            An instance of CausalGraphModel constructed from the given graph string.
        """
        causal_model = CausalModel(
            data=pd.DataFrame(), graph=graph, treatment=treatment, outcome=outcome
        )
        return cls(causal_model, treatment, outcome)

    def get_backdoor_paths(self) -> list[list[str]]:
        """Find all backdoor paths between the combined treatment and outcome variables.

        Returns
        -------
        list[list[str]]
            A list of backdoor paths, where each path is represented as a list of variable names.

        References
        ----------
        .. [1] Causal Inference in Statistics: A Primer.
               Judea Pearl, Madelyn Glymour, Nicholas P. Jewell, 2016.
        """
        # Use DoWhy's internal method to get backdoor paths for all treatments combined
        return self.causal_model._graph.get_backdoor_paths(
            nodes1=self.treatment, nodes2=[self.outcome]
        )

    def get_unique_adjustment_nodes(self) -> list[str]:
        """Compute the minimal adjustment set required for backdoor adjustment across all treatments.

        Returns
        -------
        list[str]
            A list of unique adjustment variables needed to block all backdoor paths.
        """
        paths = self.get_backdoor_paths()
        # Flatten paths and exclude treatments and outcome from adjustment set
        adjustment_nodes = set(
            node
            for path in paths
            for node in path
            if node not in self.treatment and node != self.outcome
        )
        return list(adjustment_nodes)
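As a dependency-free illustration of what these two methods compute: a backdoor path runs from a treatment to the outcome but starts with an arrow pointing *into* the treatment, and the intermediate nodes on such paths are the adjustment candidates. The sketch below is simplified (DoWhy's implementation also reasons about colliders when deciding which paths are open), and the variable names are hypothetical:

```python
# Toy DAG: market_growth is a confounder of channel spend (x1) and sales (y).
edges = [("market_growth", "x1"), ("market_growth", "y"), ("x1", "y")]

def backdoor_paths(edges, treatment, outcome):
    """Enumerate undirected paths from treatment to outcome whose first
    edge points into the treatment (textbook backdoor paths).
    Simplified: a full implementation must also handle colliders."""
    nbrs = {}
    for a, b in edges:
        nbrs.setdefault(a, []).append((b, "out"))
        nbrs.setdefault(b, []).append((a, "in"))

    found = []

    def walk(node, path, first):
        for nxt, direction in nbrs.get(node, []):
            if nxt in path:
                continue
            if first and direction != "in":  # must start with an arrow into T
                continue
            if nxt == outcome:
                found.append(path + [nxt])
            else:
                walk(nxt, path + [nxt], False)

    walk(treatment, [treatment], True)
    return found

paths = backdoor_paths(edges, "x1", "y")
# Intermediate nodes on backdoor paths are the candidate adjustment nodes.
adjustment = {n for p in paths for n in p} - {"x1", "y"}
print(paths)       # [['x1', 'market_growth', 'y']]
print(adjustment)  # {'market_growth'}
```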

    def compute_adjustment_sets(
        self,
        channel_columns: list[str] | tuple[str],
        control_columns: list[str] | None = None,
    ) -> list[str] | None:
        """Compute minimal adjustment sets and handle warnings."""
        channel_columns = list(channel_columns)
        if control_columns is None:
            return control_columns

        self.adjustment_set = self.get_unique_adjustment_nodes()

        common_controls = set(control_columns).intersection(self.adjustment_set)
        unique_controls = set(control_columns) - set(self.adjustment_set)

        if unique_controls:
            warnings.warn(
                f"Columns {unique_controls} are not in the adjustment set. Controls are being modified.",
                stacklevel=2,
            )

        control_columns = list(common_controls - set(channel_columns))

        self.minimal_adjustment_set = control_columns + list(channel_columns)
[Comment on lines +140 to +148 — Collaborator]
I am hesitant about this step because of my comment on variance reduction above. Maybe we can have an additional parameter, something like a minimal or maximal set. WDYT?

[Reply — cetagostini (Contributor, Author)]
Regarding variance reduction, I replied above. On the other hand, we could indeed implement something like that; it should not be complicated, and I'm happy to work on it.

Do we want this for the initial release? It sounds like something we could add later. I think it is well understood that "minimal" means the minimal adjustment set: only what is necessary. "Maximal" could come in a follow-up PR to allow the largest possible set of variables that can be adjusted for without introducing bias.

@juanitorduz

[Reply — Collaborator]
Thanks @carlosagostini. I think we can provide a minimal set as you suggest, but add a user warning that this feature is experimental and that the minimal set does not always lead to the best model, as we could be missing opportunities to reduce the variance of the estimate. We can release this initial experimental feature and collect feedback from users. WDYT?

        for column in self.adjustment_set:
            if column not in control_columns and column not in channel_columns:
                warnings.warn(
                    f"""Column {column} in adjustment set not found in data.
                    Not controlling for this may induce bias in treatment effect estimates.""",
                    stacklevel=2,
                )

        return control_columns
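The set arithmetic in `compute_adjustment_sets` can be sketched in isolation: controls outside the adjustment set are dropped with a warning, and only controls inside it (that are not channels) survive. A hypothetical standalone helper mirroring that logic, with illustrative column names:

```python
import warnings

def compute_controls(adjustment_set, channel_columns, control_columns):
    """Keep only user controls that appear in the adjustment set;
    warn about controls outside it (mirrors the method above)."""
    if control_columns is None:
        return None
    common = set(control_columns) & set(adjustment_set)
    unique = set(control_columns) - set(adjustment_set)
    if unique:
        warnings.warn(
            f"Columns {unique} are not in the adjustment set. Controls are being modified.",
            stacklevel=2,
        )
    # Channels are handled separately, so exclude them from the controls.
    return sorted(common - set(channel_columns))

controls = compute_controls(
    adjustment_set=["market_growth", "holiday"],
    channel_columns=["x1", "x2"],
    control_columns=["market_growth", "gdp"],  # gdp is dropped with a warning
)
print(controls)  # ['market_growth']
```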
55 changes: 55 additions & 0 deletions pymc_marketing/mmm/mmm.py
@@ -33,6 +33,7 @@

from pymc_marketing.hsgp_kwargs import HSGPKwargs
from pymc_marketing.mmm.base import BaseValidateMMM
from pymc_marketing.mmm.causal import CausalGraphModel
from pymc_marketing.mmm.components.adstock import (
    AdstockTransformation,
    adstock_from_dict,
@@ -115,6 +116,17 @@ def __init__(
        adstock_first: bool = Field(
            True, description="Whether to apply adstock first."
        ),
        dag: str | None = Field(
            None,
            description="Optional DAG provided as a string Dot format for causal identification.",
        ),
        treatment_nodes: list[str] | tuple[str] | None = Field(
            None,
            description="Column names of the variables of interest to identify causal effects on outcome.",
        ),
        outcome_node: str | None = Field(
            None, description="Name of the outcome variable."
        ),
[Comment on lines +119 to +129 — Collaborator]
This is where I would like to discuss the API. Our MMM class is already a huge monolith of many components, and I would like us to start modularizing more, or even making this a subclass.

For instance, we can keep BaseMMM as it is and add a CausalMMM(BaseMMM); if people want to use this class, they need to install DoWhy. I am personally against adding DoWhy as a required dependency because, in my experience, they sometimes hard-pin some packages and can make it harder to resolve dependencies. WDYT?

Thoughts @wd60622 ?

[Reply — cetagostini (Contributor, Author)]
I like this idea. I can work quickly on it; will wait for William's comments 🙌🏻 I will probably have a meeting with him on Tuesday!

[Reply — wd60622 (Contributor)]
If the dependencies are the issue, then I think we can get away with requiring dowhy and networkx only when a dag is specified. That way, models relying on backward compatibility would not need the new dependencies for the same model. Would checking for these dependencies only when this functionality is used address your concerns? @juanitorduz

I think going the route of subclassing could just add more code to manage 😢

[Reply — Collaborator]
> I think going the route of subclassing could just add more code to manage 😢

True... what would be your suggestion @wd60622 ?

[Reply — cetagostini (Contributor, Author)]
I think less code to manage is better, and users still import the same MMM class. The addition is only about 20 lines of code; I don't see it as anything crazy. What's your opinion?
    ) -> None:
        """Define the constructor method.

@@ -151,6 +163,12 @@ def __init__(
            Number of Fourier modes to model yearly seasonality, by default None.
        adstock_first : bool, optional
            Whether to apply adstock first, by default True.
        dag : Optional[str], optional
            Optional DAG provided as a string Dot format for causal modeling, by default None.
        treatment_nodes : Optional[list[str]], optional
            Column names of the variables of interest to identify causal effects on outcome.
        outcome_node : Optional[str], optional
            Name of the outcome variable, by default None.
        """
        self.control_columns = control_columns
        self.time_varying_intercept = time_varying_intercept
@@ -180,6 +198,37 @@
        )

        self.yearly_seasonality = yearly_seasonality

        self.dag = dag
        self.treatment_nodes = treatment_nodes
        self.outcome_node = outcome_node

        # Initialize causal graph if provided
        if self.dag is not None and self.outcome_node is not None:
            if self.treatment_nodes is None:
                self.treatment_nodes = self.channel_columns
                warnings.warn(
                    "No treatment nodes provided, using channel columns as treatment nodes.",
                    stacklevel=2,
                )
            self.causal_graphical_model = CausalGraphModel.build_graphical_model(
                graph=self.dag,
                treatment=self.treatment_nodes,
                outcome=self.outcome_node,
            )

            self.control_columns = self.causal_graphical_model.compute_adjustment_sets(
                control_columns=self.control_columns,
                channel_columns=self.channel_columns,
            )

            if "yearly_seasonality" not in self.causal_graphical_model.adjustment_set:
                warnings.warn(
                    "Yearly seasonality excluded as it's not required for adjustment.",
                    stacklevel=2,
                )
                self.yearly_seasonality = None

        if self.yearly_seasonality is not None:
            self.yearly_fourier = YearlyFourier(
                n_order=self.yearly_seasonality,
@@ -305,6 +354,9 @@ def create_idata_attrs(self) -> dict[str, str]:
        attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
        attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
        attrs["time_varying_media"] = json.dumps(self.time_varying_media)
        attrs["dag"] = json.dumps(self.dag)
        attrs["treatment_nodes"] = json.dumps(self.treatment_nodes)
        attrs["outcome_node"] = json.dumps(self.outcome_node)

        return attrs
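The serialization step above is what makes save/load (and backward compatibility) work: every new attribute round-trips through JSON, and an unset attribute becomes the string "null". A minimal sketch of both directions, with illustrative node names:

```python
import json

# New attributes serialize even when unset ...
attrs = {
    "dag": json.dumps(None),
    "treatment_nodes": json.dumps(["x1", "x2"]),  # illustrative names
    "outcome_node": json.dumps("y"),
}
restored = {k: json.loads(v) for k, v in attrs.items()}
print(restored)  # {'dag': None, 'treatment_nodes': ['x1', 'x2'], 'outcome_node': 'y'}

# ... and a model saved *before* this PR has no "dag" key at all,
# hence the attrs.get("dag", "null") fallback used on load.
old_attrs = {}
print(json.loads(old_attrs.get("dag", "null")))  # None
```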

@@ -680,6 +732,9 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
            "time_varying_media": json.loads(attrs.get("time_varying_media", "false")),
            "validate_data": json.loads(attrs["validate_data"]),
            "sampler_config": json.loads(attrs["sampler_config"]),
            "dag": json.loads(attrs.get("dag", "null")),
            "treatment_nodes": json.loads(attrs.get("treatment_nodes", "null")),
            "outcome_node": json.loads(attrs.get("outcome_node", "null")),
        }

    def _data_setter(
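Putting the pieces together, a hypothetical end-to-end call might look like the sketch below. The node names are illustrative, and constructing the model requires `pip install pymc-marketing[dag]`, so that part is left commented:

```python
# A DAG in DOT format: two channels, a sales outcome, one confounder.
dag = """
digraph {
    x1 -> y;
    x2 -> y;
    market_growth -> x1;
    market_growth -> y;
}
"""

# Hypothetical usage -- uncomment with pymc-marketing[dag] installed:
# from pymc_marketing.mmm import MMM
# mmm = MMM(
#     ...,  # adstock, saturation, date/channel columns, etc.
#     dag=dag,
#     treatment_nodes=["x1", "x2"],
#     outcome_node="y",
# )
# market_growth would then need to be among control_columns to block the
# backdoor path x1 <- market_growth -> y.
```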
8 changes: 8 additions & 0 deletions pyproject.toml
@@ -40,6 +40,10 @@ dependencies = [
 ]

 [project.optional-dependencies]
+dag = [
+    "dowhy",
+    "networkx",
+]
 docs = [
     "blackjax",
     "fastprogress",
@@ -63,6 +67,8 @@ docs = [
     "sphinxext-opengraph",
     "watermark",
     "mlflow>=2.0.0",
+    "networkx",
+    "dowhy",
 ]
 lint = ["mypy", "pandas-stubs", "pre-commit>=2.19.0", "ruff>=0.1.4"]
 test = [
@@ -78,6 +84,8 @@ test = [
     "pytest-mock>=3.14.0",
     "pytest>=7.0.1",
     "mlflow>=2.0.0",
+    "networkx",
+    "dowhy",
 ]

 [tool.hatch.build.targets.sdist]