Skip to content

Commit

Permalink
Merge pull request #749 from stan-dev/fix/2.35-fixes
Browse files Browse the repository at this point in the history
Fix tests for new cmdstan
  • Loading branch information
WardBrian authored May 13, 2024
2 parents cbea79f + 1af8596 commit fad7a69
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
6 changes: 4 additions & 2 deletions cmdstanpy/stanfit/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ def is_resampled(self) -> bool:
"""
return ( # type: ignore
self._metadata.cmdstan_config.get("num_paths", 4) > 1
and self._metadata.cmdstan_config.get('psis_resample', 1) == 1
and self._metadata.cmdstan_config.get('calculate_lp', 1) == 1
and self._metadata.cmdstan_config.get('psis_resample', 1)
in (1, 'true')
and self._metadata.cmdstan_config.get('calculate_lp', 1)
in (1, 'true')
)

def save_csvfiles(self, dir: Optional[str] = None) -> None:
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def check_sampler_csv(
)
)
if save_warmup:
if not ('save_warmup' in meta and meta['save_warmup'] == 1):
if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')):
raise ValueError(
'bad Stan CSV file {}, '
'config error, expected save_warmup = 1'.format(path)
Expand Down
9 changes: 5 additions & 4 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,13 @@ def test_variables_3d() -> None:
)
vars_iters = multidim_mle_iters.stan_variables(inc_iterations=True)
assert len(vars_iters) == len(multidim_mle_iters.metadata.stan_vars)
assert 'frac_60' in vars_iters
n_iter = vars_iters['frac_60'].shape[0]
assert n_iter > 1
assert 'y_rep' in vars_iters
assert vars_iters['y_rep'].shape == (8, 5, 4, 3)
assert vars_iters['y_rep'].shape == (n_iter, 5, 4, 3)
assert 'beta' in vars_iters
assert vars_iters['beta'].shape == (8, 2)
assert 'frac_60' in vars_iters
assert vars_iters['frac_60'].shape == (8,)
assert vars_iters['beta'].shape == (n_iter, 2)


def test_optimize_good() -> None:
Expand Down

0 comments on commit fad7a69

Please sign in to comment.