From 8cfdcfafe9f3b961c2bc66c2d30980baa78db5cc Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Tue, 24 Sep 2024 17:54:47 +0100
Subject: [PATCH] Rmspe test stat (#24)

* RMSPE WIP

* Rmspe test stat (#22)

* Minor change in README to fix guidance for developers (#18)

* Noise transform (#19)

* Add noise transformation that applies perturbations to the treatment

* Formatting

* Add docstring

* Fix linting

* Add tests to check perturbation impact and randomness over timepoints

* bump version (#20)

* Initial implementation of RMSPE

* Add TestResultFrame parent class for test results

* Add test for RMSPE

* Add doc string

* Fix linting

* Update src/causal_validation/validation/rmspe.py

Co-authored-by: Thomas Pinder

* Fix typo

---------

Co-authored-by: Thomas Pinder

---------

Co-authored-by: Thomas Pinder
Co-authored-by: Semih Akbayrak
---
 docs/examples/placebo_test.ipynb             |  16 ++
 src/causal_validation/data.py                |  13 +-
 src/causal_validation/models.py              |  42 ++++-
 src/causal_validation/transforms/base.py     |   8 +-
 src/causal_validation/types.py               |   2 +
 src/causal_validation/validation/placebo.py  |  38 ++--
 src/causal_validation/validation/rmspe.py    | 133 ++++++++++++++
 src/causal_validation/validation/testing.py  | 107 +++++++++++
 tests/test_causal_validation/test_models.py  |  16 +-
 .../test_validation/test_placebo.py          |   2 +
 .../test_validation/test_rmspe.py            | 169 ++++++++++++++++++
 11 files changed, 512 insertions(+), 34 deletions(-)
 create mode 100644 src/causal_validation/validation/rmspe.py
 create mode 100644 src/causal_validation/validation/testing.py
 create mode 100644 tests/test_causal_validation/test_validation/test_rmspe.py

diff --git a/docs/examples/placebo_test.ipynb b/docs/examples/placebo_test.ipynb
index 511008e..e784193 100644
--- a/docs/examples/placebo_test.ipynb
+++ b/docs/examples/placebo_test.ipynb
@@ -207,6 +207,22 @@
    "datasets = DatasetContainer([data, complex_data], names=[\"Simple\", \"Complex\"])\n",
    "PlaceboTest([model, did_model], datasets).execute().summary()"
   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "15",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py
index a25a94c..057acfa 100644
--- a/src/causal_validation/data.py
+++ b/src/causal_validation/data.py
@@ -27,6 +27,7 @@ class Dataset:
     yte: Float[np.ndarray, "M 1"]
     _start_date: dt.date
     counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
+    synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
     _name: str = None

     def to_df(
@@ -151,7 +152,13 @@ def drop_unit(self, idx: int) -> Dataset:
         Xtr = np.delete(self.Xtr, [idx], axis=1)
         Xte = np.delete(self.Xte, [idx], axis=1)
         return Dataset(
-            Xtr, Xte, self.ytr, self.yte, self._start_date, self.counterfactual
+            Xtr,
+            Xte,
+            self.ytr,
+            self.yte,
+            self._start_date,
+            self.counterfactual,
+            self.synthetic,
         )

     def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -204,4 +211,6 @@ def reassign_treatment(
 ) -> Dataset:
     Xtr = data.Xtr
     Xte = data.Xte
-    return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+    return Dataset(
+        Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+    )
diff --git a/src/causal_validation/models.py b/src/causal_validation/models.py
index d8b5196..499ffe1 100644
--- a/src/causal_validation/models.py
+++ b/src/causal_validation/models.py
@@ -1,17 +1,29 @@
 from dataclasses import dataclass
 import typing as tp

+from azcausal.core.effect import Effect
 from azcausal.core.error import Error
 from azcausal.core.estimator import Estimator
-from azcausal.core.result import Result
+from azcausal.core.result import Result as _Result
+from jaxtyping import Float

 from causal_validation.data import Dataset
+from causal_validation.types import NPArray
+
+
+@dataclass
+class Result:
+    effect: Effect
+    counterfactual: Float[NPArray, "N 1"]
+    synthetic: Float[NPArray, "N 1"]
+    observed: Float[NPArray, "N 1"]


 @dataclass
 class AZCausalWrapper:
     model: Estimator
     error_estimator: tp.Optional[Error] = None
+    _az_result: tp.Optional[_Result] = None

     def __post_init__(self):
         self._model_name = self.model.__class__.__name__
@@ -21,4 +33,30 @@ def __call__(self, data: Dataset, **kwargs) -> Result:
         result = self.model.fit(panel, **kwargs)
         if self.error_estimator:
             self.model.error(result, self.error_estimator)
-        return result
+        self._az_result = result
+
+        res = Result(
+            effect=result.effect,
+            counterfactual=self.counterfactual,
+            synthetic=self.synthetic,
+            observed=self.observed,
+        )
+        return res
+
+    @property
+    def counterfactual(self) -> Float[NPArray, "N 1"]:
+        df = self._az_result.effect.by_time
+        c_factual = df.loc[:, "CF"].values.reshape(-1, 1)
+        return c_factual
+
+    @property
+    def synthetic(self) -> Float[NPArray, "N 1"]:
+        df = self._az_result.effect.by_time
+        synth_control = df.loc[:, "C"].values.reshape(-1, 1)
+        return synth_control
+
+    @property
+    def observed(self) -> Float[NPArray, "N 1"]:
+        df = self._az_result.effect.by_time
+        treated = df.loc[:, "T"].values.reshape(-1, 1)
+        return treated
diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py
index ea15109..6ef7a97 100644
--- a/src/causal_validation/transforms/base.py
+++ b/src/causal_validation/transforms/base.py
@@ -74,7 +74,9 @@ def apply_values(
         ytr = ytr + pre_intervention_vals[:, :1]
         Xte = Xte + post_intervention_vals[:, 1:]
         yte = yte + post_intervention_vals[:, :1]
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )


 @dataclass(kw_only=True)
@@ -91,4 +93,6 @@ def apply_values(
         ytr = ytr * pre_intervention_vals
         Xte = Xte * post_intervention_vals
         yte = yte * post_intervention_vals
-        return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
+        return Dataset(
+            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+        )
diff --git a/src/causal_validation/types.py b/src/causal_validation/types.py
index dd7e124..5ea3505 100644
--- a/src/causal_validation/types.py
+++ b/src/causal_validation/types.py
@@ -1,5 +1,6 @@
 import typing as tp

+import numpy as np
 from scipy.stats._distn_infrastructure import (
     rv_continuous,
     rv_discrete,
@@ -10,3 +11,4 @@
 InterventionTypes = tp.Literal["pre-intervention", "post-intervention", "both"]
 RandomVariable = tp.Union[rv_continuous, rv_discrete]
 Number = tp.Union[float, int]
+NPArray = np.ndarray
diff --git a/src/causal_validation/validation/placebo.py b/src/causal_validation/validation/placebo.py
index dd70909..b8f7c36 100644
--- a/src/causal_validation/validation/placebo.py
+++ b/src/causal_validation/validation/placebo.py
@@ -9,20 +9,26 @@
     Column,
     DataFrameSchema,
 )
-from rich import box
 from rich.progress import (
     Progress,
     ProgressBar,
     track,
 )
-from rich.table import Table
 from scipy.stats import ttest_1samp
+from tqdm import (
+    tqdm,
+    trange,
+)

 from causal_validation.data import (
     Dataset,
     DatasetContainer,
 )
-from causal_validation.models import AZCausalWrapper
+from causal_validation.models import (
+    AZCausalWrapper,
+    Result,
+)
+from causal_validation.validation.testing import TestResultFrame

 PlaceboSchema = DataFrameSchema(
     {
@@ -39,13 +45,13 @@


 @dataclass
-class PlaceboTestResult:
-    effects: tp.Dict[tp.Tuple[str, str], tp.List[Effect]]
+class PlaceboTestResult(TestResultFrame):
+    effects: tp.Dict[tp.Tuple[str, str], tp.List[Result]]

     def _model_to_df(
-        self, model_name: str, dataset_name: str, effects: tp.List[Effect]
+        self, model_name: str, dataset_name: str, effects: tp.List[Result]
     ) -> pd.DataFrame:
-        _effects = [effect.value for effect in effects]
+        _effects = [e.effect.percentage().value for e in effects]
         _n_effects = len(_effects)
         expected_effect = np.mean(_effects)
         stddev_effect = np.std(_effects)
@@ -71,21 +77,6 @@ def to_df(self) -> pd.DataFrame:
         PlaceboSchema.validate(df)
         return df

-    def summary(self, precision: int = 4) -> Table:
-        table = Table(show_header=True, box=box.MARKDOWN)
-        df = self.to_df()
-        numeric_cols = df.select_dtypes(include=[np.number])
-        df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
-
-        for column in df.columns:
-            table.add_column(str(column), style="magenta")
-
-        for _, value_list in enumerate(df.values.tolist()):
-            row = [str(x) for x in value_list]
-            table.add_row(*row)
-
-        return table
-

 @dataclass
 class PlaceboTest:
@@ -109,7 +100,7 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult:
         datasets = self.dataset_dict
         n_datasets = len(datasets)
         n_control = sum([d.n_units for d in datasets.values()])
-        with Progress() as progress:
+        with Progress(disable=not verbose) as progress:
             model_task = progress.add_task(
                 "[red]Models", total=len(self.models), visible=verbose
             )
@@ -130,7 +121,6 @@
                     progress.update(unit_task, advance=1)
                     placebo_data = dataset.to_placebo_data(i)
                     result = model(placebo_data)
-                    result = result.effect.percentage()
                     model_result.append(result)
             results[(model._model_name, data_name)] = model_result
         return PlaceboTestResult(effects=results)
diff --git a/src/causal_validation/validation/rmspe.py b/src/causal_validation/validation/rmspe.py
new file mode 100644
index 0000000..6b541ff
--- /dev/null
+++ b/src/causal_validation/validation/rmspe.py
@@ -0,0 +1,133 @@
+from dataclasses import dataclass
+import typing as tp
+
+from jaxtyping import Float
+import numpy as np
+import pandas as pd
+from pandera import (
+    Check,
+    Column,
+    DataFrameSchema,
+)
+from rich import box
+from rich.progress import (
+    Progress,
+    ProgressBar,
+    track,
+)
+
+from causal_validation.validation.placebo import PlaceboTest
+from causal_validation.validation.testing import (
+    RMSPETestStatistic,
+    TestResult,
+    TestResultFrame,
+)
+
+RMSPESchema = DataFrameSchema(
+    {
+        "Model": Column(str),
+        "Dataset": Column(str),
+        "Test statistic": Column(float, coerce=True),
+        "p-value": Column(
+            float,
+            checks=[
+                Check.greater_than_or_equal_to(0.0),
+                Check.less_than_or_equal_to(1.0),
+            ],
+            coerce=True,
+        ),
+    }
+)
+
+
+@dataclass
+class RMSPETestResult(TestResultFrame):
+    """
+    A subclass of TestResultFrame. RMSPETestResult stores the test statistic and
+    p-value for the treated unit, and the test statistics for the pseudo-treated units.
+    """
+
+    treatment_test_results: tp.Dict[tp.Tuple[str, str], TestResult]
+    pseudo_treatment_test_statistics: tp.Dict[tp.Tuple[str, str], tp.List[Float]]
+
+    def to_df(self) -> pd.DataFrame:
+        dfs = []
+        for (model, dataset), test_results in self.treatment_test_results.items():
+            result = {
+                "Model": model,
+                "Dataset": dataset,
+                "Test statistic": test_results.test_statistic,
+                "p-value": test_results.p_value,
+            }
+            df = pd.DataFrame([result])
+            dfs.append(df)
+        df = pd.concat(dfs)
+        RMSPESchema.validate(df)
+        return df
+
+
+@dataclass
+class RMSPETest(PlaceboTest):
+    """
+    A subclass of PlaceboTest that computes the RMSPE test statistic for every unit.
+    Given these statistics, the p-value for the actual treatment is calculated.
+    """
+
+    def execute(self, verbose: bool = True) -> RMSPETestResult:
+        treatment_results, pseudo_treatment_results = {}, {}
+        datasets = self.dataset_dict
+        n_datasets = len(datasets)
+        n_control = sum([d.n_units for d in datasets.values()])
+        rmspe = RMSPETestStatistic()
+        with Progress(disable=not verbose) as progress:
+            model_task = progress.add_task(
+                "[red]Models", total=len(self.models), visible=verbose
+            )
+            data_task = progress.add_task(
+                "[blue]Datasets", total=n_datasets, visible=verbose
+            )
+            unit_task = progress.add_task(
+                "[green]Treatment and Control Units",
+                total=n_control + 1,
+                visible=verbose,
+            )
+            for data_name, dataset in datasets.items():
+                progress.update(data_task, advance=1)
+                for model in self.models:
+                    progress.update(unit_task, advance=1)
+                    treatment_result = model(dataset)
+                    treatment_idx = dataset.ytr.shape[0]
+                    treatment_test_stat = rmspe(
+                        dataset,
+                        treatment_result.counterfactual,
+                        treatment_result.synthetic,
+                        treatment_idx,
+                    )
+                    progress.update(model_task, advance=1)
+                    placebo_test_stats = []
+                    for i in range(dataset.n_units):
+                        progress.update(unit_task, advance=1)
+                        placebo_data = dataset.to_placebo_data(i)
+                        result = model(placebo_data)
+                        placebo_test_stats.append(
+                            rmspe(
+                                placebo_data,
+                                result.counterfactual,
+                                result.synthetic,
+                                treatment_idx,
+                            )
+                        )
+                    pval_idx = 1
+                    for p_stat in placebo_test_stats:
+                        pval_idx += 1 if treatment_test_stat < p_stat else 0
+                    pval = pval_idx / (dataset.n_units + 1)
+                    treatment_results[(model._model_name, data_name)] = TestResult(
+                        p_value=pval, test_statistic=treatment_test_stat
+                    )
+                    pseudo_treatment_results[(model._model_name, data_name)] = (
+                        placebo_test_stats
+                    )
+        return RMSPETestResult(
+            treatment_test_results=treatment_results,
+            pseudo_treatment_test_statistics=pseudo_treatment_results,
+        )
diff --git a/src/causal_validation/validation/testing.py b/src/causal_validation/validation/testing.py
new file mode 100644
index 0000000..e0144a2
--- /dev/null
+++ b/src/causal_validation/validation/testing.py
@@ -0,0 +1,107 @@
+import abc
+from dataclasses import dataclass
+import typing as tp
+
+from jaxtyping import Float
+import numpy as np
+import pandas as pd
+from rich import box
+from rich.table import Table
+
+from causal_validation.data import Dataset
+
+
+@dataclass
+class TestResultFrame:
+    """A parent class for test results."""
+
+    @abc.abstractmethod
+    def to_df(self) -> pd.DataFrame:
+        raise NotImplementedError
+
+    def summary(self, precision: int = 4) -> Table:
+        table = Table(show_header=True, box=box.MARKDOWN)
+        df = self.to_df()
+        numeric_cols = df.select_dtypes(include=[np.number])
+        df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
+
+        for column in df.columns:
+            table.add_column(str(column), style="magenta")
+
+        for _, value_list in enumerate(df.values.tolist()):
+            row = [str(x) for x in value_list]
+            table.add_row(*row)
+
+        return table
+
+
+@dataclass
+class TestResult:
+    p_value: float
+    test_statistic: float
+
+
+@dataclass
+class AbstractTestStatistic:
+    @abc.abstractmethod
+    def _compute(
+        self,
+        dataset: Dataset,
+        counterfactual: Float[np.ndarray, "N 1"],
+        synthetic: tp.Optional[Float[np.ndarray, "M 1"]],
+        treatment_index: int,
+    ) -> Float:
+        raise NotImplementedError
+
+    def __call__(
+        self,
+        dataset: Dataset,
+        counterfactual: Float[np.ndarray, "N 1"],
+        synthetic: tp.Optional[Float[np.ndarray, "M 1"]],
+        treatment_index: int,
+    ) -> Float:
+        return self._compute(dataset, counterfactual, synthetic, treatment_index)
+
+
+@dataclass
+class RMSPETestStatistic(AbstractTestStatistic):
+    """
+    Given a dataset and treatment index, together with the counterfactual and
+    synthetic control for the unit assigned to treatment, the RMSPE test
+    statistic is calculated.
+    """
+
+    @staticmethod
+    def _compute(
+        dataset: Dataset,
+        counterfactual: Float[np.ndarray, "N 1"],
+        synthetic: Float[np.ndarray, "N 1"],
+        treatment_index: int,
+    ) -> Float:
+        _, pre_observed = dataset.pre_intervention_obs
+        _, post_observed = dataset.post_intervention_obs
+        _, post_counterfactual = RMSPETestStatistic._split_array(
+            counterfactual, treatment_index
+        )
+        pre_synthetic, _ = RMSPETestStatistic._split_array(synthetic, treatment_index)
+        pre_rmspe = RMSPETestStatistic._rmspe(pre_observed, pre_synthetic)
+        post_rmspe = RMSPETestStatistic._rmspe(post_observed, post_counterfactual)
+        if pre_rmspe == 0:
+            raise ZeroDivisionError("Error: pre intervention period MSPE is 0!")
+        else:
+            test_statistic = post_rmspe / pre_rmspe
+        return test_statistic
+
+    @staticmethod
+    def _rmspe(
+        observed: Float[np.ndarray, "N 1"], generated: Float[np.ndarray, "N 1"]
+    ) -> float:
+        return np.sqrt(np.mean(np.square(observed - generated)))
+
+    @staticmethod
+    def _split_array(
+        array: Float[np.ndarray, "N 1"], index: int
+    ) -> tp.Tuple[Float[np.ndarray, "Nx 1"], Float[np.ndarray, "Ny 1"]]:
+        left_split = array[:index, :]
+        right_split = array[index:, :]
+        return left_split, right_split
diff --git a/tests/test_causal_validation/test_models.py b/tests/test_causal_validation/test_models.py
index feeb1ce..ca479a4 100644
--- a/tests/test_causal_validation/test_models.py
+++ b/tests/test_causal_validation/test_models.py
@@ -7,7 +7,7 @@
     JackKnife,
 )
 from azcausal.core.estimator import Estimator
-from azcausal.core.result import Result
+from azcausal.core.result import Result as _Result
 from azcausal.estimators.panel import (
     did,
     sdid,
@@ -19,7 +19,10 @@
 )
 import numpy as np

-from causal_validation.models import AZCausalWrapper
+from causal_validation.models import (
+    AZCausalWrapper,
+    Result,
+)
 from causal_validation.testing import (
     TestConstants,
     simulate_data,
@@ -49,15 +52,20 @@ def test_call(
     n_post_treatment: int,
     seed: int,
 ):
-    constancts = TestConstants(
+    constants = TestConstants(
         N_CONTROL=n_control,
         N_PRE_TREATMENT=n_pre_treatment,
         N_POST_TREATMENT=n_post_treatment,
     )
-    data = simulate_data(global_mean=10.0, seed=seed, constants=constancts)
+    data = simulate_data(global_mean=10.0, seed=seed, constants=constants)
     model = AZCausalWrapper(*model_error)
     result = model(data)

     assert isinstance(result, Result)
     assert isinstance(result.effect, Effect)
     assert not np.isnan(result.effect.value)
+    assert isinstance(model._az_result, _Result)
+    assert np.all(np.concatenate((data.ytr, data.yte), axis=0) == result.observed)
+    assert (
+        result.observed.shape == result.counterfactual.shape == result.synthetic.shape
+    )
diff --git a/tests/test_causal_validation/test_validation/test_placebo.py b/tests/test_causal_validation/test_validation/test_placebo.py
index b48d058..858c5f3 100644
--- a/tests/test_causal_validation/test_validation/test_placebo.py
+++ b/tests/test_causal_validation/test_validation/test_placebo.py
@@ -23,6 +23,7 @@
     PlaceboTest,
     PlaceboTestResult,
 )
+from causal_validation.validation.testing import TestResultFrame


 def test_schema_coerce():
@@ -56,6 +57,7 @@ def test_placebo_test(

     # Check that the structure of result
     assert isinstance(result, PlaceboTestResult)
+    assert isinstance(result, TestResultFrame)
     for _, v in result.effects.items():
         assert len(v) == n_control
diff --git a/tests/test_causal_validation/test_validation/test_rmspe.py b/tests/test_causal_validation/test_validation/test_rmspe.py
new file mode 100644
index 0000000..1bc6b37
--- /dev/null
+++ b/tests/test_causal_validation/test_validation/test_rmspe.py
@@ -0,0 +1,169 @@
+import typing as tp
+
+from azcausal.estimators.panel.did import DID
+from azcausal.estimators.panel.sdid import SDID
+from hypothesis import (
+    given,
+    settings,
+    strategies as st,
+)
+import numpy as np
+import pandas as pd
+import pytest
+from rich.table import Table
+
+from causal_validation.effects import StaticEffect
+from causal_validation.models import AZCausalWrapper
+from causal_validation.testing import (
+    TestConstants,
+    simulate_data,
+)
+from causal_validation.transforms import Trend
+from causal_validation.validation.rmspe import (
+    RMSPESchema,
+    RMSPETest,
+    RMSPETestResult,
+)
+from causal_validation.validation.testing import (
+    RMSPETestStatistic,
+    TestResult,
+    TestResultFrame,
+)
+
+
+def test_schema_coerce():
+    df = RMSPESchema.example()
+    cols = df.columns
+    for col in cols:
+        if col not in ["Model", "Dataset"]:
+            df[col] = np.ceil(df[col])
+    RMSPESchema.validate(df)
+
+
+@given(
+    global_mean=st.floats(min_value=0.0, max_value=10.0),
+    seed=st.integers(min_value=0, max_value=1000000),
+    n_control=st.integers(min_value=10, max_value=20),
+    cf_inflate=st.one_of(
+        st.floats(min_value=1e-10, max_value=2.0),
+        st.floats(min_value=-2.0, max_value=-1e-10),
+    ),
+    s_inflate=st.one_of(
+        st.floats(min_value=1e-10, max_value=2.0),
+        st.floats(min_value=-2.0, max_value=-1e-10),
+    ),
+)
+@settings(max_examples=10)
+def test_rmspe_test_stat(
+    global_mean: float, seed: int, n_control: int, cf_inflate: float, s_inflate: float
+):
+    # Simulate data
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=global_mean, seed=seed, constants=constants)
+    rmspe = RMSPETestStatistic()
+    counterfactual = np.concatenate((data.ytr, data.yte), axis=0) + cf_inflate
+    synthetic = counterfactual
+    assert rmspe(
+        data, counterfactual, synthetic, constants.N_PRE_TREATMENT
+    ) == pytest.approx(1.0)
+
+    synthetic = np.concatenate((data.ytr, data.yte), axis=0) + s_inflate
+    assert rmspe(
+        data, counterfactual, synthetic, constants.N_PRE_TREATMENT
+    ) == pytest.approx(abs(cf_inflate) / abs(s_inflate))
+
+    synthetic = np.concatenate((data.ytr, data.yte), axis=0)
+    with pytest.raises(
+        ZeroDivisionError, match="Error: pre intervention period MSPE is 0!"
+    ):
+        rmspe(data, counterfactual, synthetic, constants.N_PRE_TREATMENT)
+
+
+@given(
+    global_mean=st.floats(min_value=0.0, max_value=10.0),
+    effect=st.one_of(
+        st.floats(min_value=1.0, max_value=5.0),
+        st.floats(min_value=-5.0, max_value=-1.0),
+    ),
+    seed=st.integers(min_value=0, max_value=1000000),
+    n_control=st.integers(min_value=10, max_value=20),
+    model=st.sampled_from([DID(), SDID()]),
+)
+@settings(max_examples=10)
+def test_rmspe_test(
+    global_mean: float,
+    effect: float,
+    seed: int,
+    n_control: int,
+    model: tp.Union[DID, SDID],
+):
+    # Simulate data with a trend and effect
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=global_mean, seed=seed, constants=constants)
+    trend_term = Trend(degree=1, coefficient=0.1)
+    static_effect = StaticEffect(effect=effect)
+    data = static_effect(trend_term(data))
+
+    model = AZCausalWrapper(model)
+    result = RMSPETest(model, data).execute()
+
+    assert isinstance(result, RMSPETestResult)
+    assert isinstance(result, TestResultFrame)
+    assert set(result.treatment_test_results.keys()) == set(
+        result.pseudo_treatment_test_statistics.keys()
+    )
+
+    for k, v in result.treatment_test_results.items():
+        assert isinstance(v, TestResult)
+        assert len(result.pseudo_treatment_test_statistics[k]) == n_control
+
+    summary = result.to_df()
+    RMSPESchema.validate(summary)
+    assert isinstance(summary, pd.DataFrame)
+    assert summary.shape == (1, 4)
+    assert summary["p-value"].iloc[0] == pytest.approx(1.0 / (n_control + 1))
+
+    rich_summary = result.summary()
+    assert isinstance(rich_summary, Table)
+    n_rows = result.summary().row_count
+    assert n_rows == summary.shape[0]
+
+
+@pytest.mark.parametrize("n_control", [9, 10])
+def test_multiple_models(n_control: int):
+    constants = TestConstants(N_CONTROL=n_control, GLOBAL_SCALE=0.001)
+    data = simulate_data(global_mean=20.0, seed=123, constants=constants)
+    trend_term = Trend(degree=1, coefficient=0.1)
+    data = trend_term(data)
+
+    model1 = AZCausalWrapper(DID())
+    model2 = AZCausalWrapper(SDID())
+    result = RMSPETest([model1, model2], data).execute()
+
+    result_df = result.to_df()
+    result_rich = result.summary()
+    assert result_df.shape == (2, 4)
+    assert result_df.shape[0] == result_rich.row_count
+    assert result_df["Model"].tolist() == ["DID", "SDID"]
+    for k, v in result.treatment_test_results.items():
+        assert isinstance(v, TestResult)
+        assert len(result.pseudo_treatment_test_statistics[k]) == n_control
+
+
+@given(
+    seeds=st.lists(
+        elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=5
+    )
+)
+@settings(max_examples=5)
+def test_multiple_datasets(seeds: tp.List[int]):
+    data = [simulate_data(global_mean=20.0, seed=s) for s in seeds]
+    n_data = len(data)
+
+    model = AZCausalWrapper(DID())
+    result = RMSPETest(model, data).execute()
+
+    result_df = result.to_df()
+    result_rich = result.summary()
+    assert result_df.shape == (n_data, 4)
+    assert result_df.shape[0] == result_rich.row_count
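
A minimal end-to-end sketch of the new RMSPE test, assembled from the tests in
this patch; the simulated panel, trend, static effect, and DID estimator are
illustrative stand-ins rather than a prescribed workflow:

    from azcausal.estimators.panel.did import DID

    from causal_validation.effects import StaticEffect
    from causal_validation.models import AZCausalWrapper
    from causal_validation.testing import TestConstants, simulate_data
    from causal_validation.transforms import Trend
    from causal_validation.validation.rmspe import RMSPETest

    # Simulate a panel with 10 control units, then apply a linear trend and a
    # static treatment effect, mirroring test_rmspe_test above.
    constants = TestConstants(N_CONTROL=10, GLOBAL_SCALE=0.001)
    data = simulate_data(global_mean=20.0, seed=123, constants=constants)
    data = StaticEffect(effect=2.0)(Trend(degree=1, coefficient=0.1)(data))

    # Wrap an azcausal estimator and run the RMSPE-based placebo test.
    model = AZCausalWrapper(DID())
    result = RMSPETest(model, data).execute()
    result.summary()      # markdown table with test statistic and p-value
    df = result.to_df()   # same content as a schema-validated DataFrame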
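
For intuition: RMSPETestStatistic computes the ratio of post-intervention RMSPE
(observed vs. counterfactual) to pre-intervention RMSPE (observed vs. synthetic
control), and RMSPETest converts it into a rank p-value, i.e. 1 plus the number
of placebo statistics exceeding the treated statistic, divided by the number of
placebo units plus 1. A self-contained numpy sketch of that arithmetic, with
toy numbers that are purely illustrative:

    import numpy as np

    def rmspe(observed: np.ndarray, predicted: np.ndarray) -> float:
        # Root mean squared prediction error between two series.
        return float(np.sqrt(np.mean(np.square(observed - predicted))))

    t0 = 8  # treatment index: first 8 observations are pre-intervention
    rng = np.random.default_rng(0)
    observed = np.arange(12.0) + rng.normal(scale=0.1, size=12)
    synthetic = np.arange(12.0)       # synthetic control fit over the pre-period
    counterfactual = np.arange(12.0)  # predicted no-treatment outcome

    # Test statistic: post-period error over pre-period fit error.
    test_stat = rmspe(observed[t0:], counterfactual[t0:]) / rmspe(
        observed[:t0], synthetic[:t0]
    )

    # Rank-based p-value against hypothetical placebo statistics.
    placebo_stats = np.array([1.1, 0.9, 1.3])
    rank = 1 + int(np.sum(test_stat < placebo_stats))
    p_value = rank / (len(placebo_stats) + 1)
    print(test_stat, p_value)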