Skip to content

Commit

Permalink
Merge pull request #32 from lincc-frameworks/jkubica/test_dynesty
Browse files Browse the repository at this point in the history
Add to the dynesty sampler test
  • Loading branch information
jeremykubica authored Jul 20, 2023
2 parents 77b8543 + 9f9b9a0 commit 5d6344d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions src/superphot_plus/ztf_transient_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,9 @@ def dynesty_single_file(test_fn, output_dir, skip_if_exists=True, rstate=None):
Returns
-------
None
Returns None if the fitting is skipped or encounters an error.
sample_mean: numpy array
Return the mean of the MCMC samples or None if the fitting is
skipped or encounters an error.
"""
# try:

Expand All @@ -405,6 +406,8 @@ def dynesty_single_file(test_fn, output_dir, skip_if_exists=True, rstate=None):
eq_samples = run_mcmc(test_fn, plot=False, rstate=rstate)
if eq_samples is None:
return None
print(np.mean(eq_samples, axis=0))
sample_mean = np.mean(eq_samples, axis=0)
print(sample_mean)

np.savez_compressed(os.path.join(output_dir, f"{prefix}_eqwt_dynesty.npz"), eq_samples)
return sample_mean
4 changes: 2 additions & 2 deletions tests/superphot_plus/test_ztf_transient_fit.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import os

import numpy as np

from superphot_plus.ztf_transient_fit import dynesty_single_file


def test_dynesty_single_file(tmp_path, single_ztf_lightcurve_compressed):
"""Just test that we generated a new file with fits"""
dynesty_single_file(
sample_mean = dynesty_single_file(
single_ztf_lightcurve_compressed,
tmp_path,
skip_if_exists=False,
rstate=np.random.default_rng(9876),
)
assert len(sample_mean) == 14

output_file = os.path.join(tmp_path, "ZTF22abvdwik_eqwt_dynesty.npz")
assert os.path.exists(output_file)
Expand Down

0 comments on commit 5d6344d

Please sign in to comment.