diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 5f3c983b5..0f47325fe 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -22,6 +22,7 @@ Any, TypeAlias, cast, + overload, ) import numpy as np @@ -360,6 +361,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None): ] +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: bool = True, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> InferenceData: ... +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: bool = False, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> dict[str, np.ndarray]: ... def sample_prior_predictive( draws: int = 500, model: Model | None = None, @@ -449,7 +472,7 @@ def sample_prior_predictive( ) # All model variables have a name, but mypy does not know this - _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] + _log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] values = zip(*(sampler_fn() for i in range(draws))) data = {k: np.stack(v) for k, v in zip(names, values)}