Skip to content

Commit

Permalink
correct sampled variables in sample_prior_predictive log call & add r…
Browse files Browse the repository at this point in the history
…eturn type overloads
  • Loading branch information
Goose committed Mar 3, 2025
1 parent d1aff0b commit b512717
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Any,
TypeAlias,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -360,6 +361,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None):
]


@overload
def sample_prior_predictive(
draws: int = 500,
model: Model | None = None,
var_names: Iterable[str] | None = None,
random_seed: RandomState = None,
return_inferencedata: bool = True,
idata_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
samples: int | None = None,
) -> InferenceData: ...
@overload
def sample_prior_predictive(
draws: int = 500,
model: Model | None = None,
var_names: Iterable[str] | None = None,
random_seed: RandomState = None,
return_inferencedata: bool = False,
idata_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
samples: int | None = None,
) -> dict[str, np.ndarray]: ...
def sample_prior_predictive(
draws: int = 500,
model: Model | None = None,
Expand Down Expand Up @@ -449,7 +472,7 @@ def sample_prior_predictive(
)

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
_log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
values = zip(*(sampler_fn() for i in range(draws)))

data = {k: np.stack(v) for k, v in zip(names, values)}
Expand Down

0 comments on commit b512717

Please sign in to comment.