Skip to content

Commit

Permalink
Merge pull request #35 from lincc-frameworks/gen_clean_models
Browse files Browse the repository at this point in the history
Improve testing support
  • Loading branch information
kdesoto-astro authored Jul 20, 2023
2 parents 5d6344d + 2e4f0fb commit 18298db
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/superphot_plus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import data_generation
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,41 @@ def create_prior(cube, tdata):
return cube


def create_clean_models(nmodels, num_times=100):
    """Generate 'clean' (noiseless) models from the prior.

    Parameters
    ----------
    nmodels : int
        The number of models you want to generate.
    num_times : int, optional
        The number of timesteps to use. Defaults to 100.

    Returns
    -------
    params : array-like of numpy arrays
        The array of parameters used to generate each model.
    lcs : array-like of numpy arrays
        The array of individual light curves for each model generated.
    """
    params = []
    lcs = []

    # Shared time grid, band labels, and near-zero errors reused by every model.
    tdata = np.linspace(-100, 100, num_times)
    bdata = np.asarray(["g"] * num_times, dtype=str)
    edata = np.asarray([1e-6] * num_times, dtype=float)

    # Loop on the count of accepted models, not a fixed range: invalid prior
    # draws are discarded, so we keep sampling until nmodels are collected.
    while len(lcs) < nmodels:
        cube = np.random.uniform(0, 1, 14)
        cube = create_prior(cube, tdata)
        A, beta, gamma, t0, tau_rise, tau_fall, es = cube[:7]  # pylint: disable=unused-variable

        # Try again if we picked invalid priors.
        if not params_valid(A, beta, gamma, t0, tau_rise, tau_fall):
            continue
        params.append(cube)

        f_model = flux_model(cube, tdata, bdata)
        lcs.append(np.array([tdata, f_model, edata, bdata]))

    return params, lcs
33 changes: 33 additions & 0 deletions src/superphot_plus/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Methods for reading and writing input, intermediate, and output files."""

import os

import numpy as np


Expand Down Expand Up @@ -41,3 +43,34 @@ def read_single_lightcurve(filename, time_ceiling=None):
t -= max_flux_loc # make relative

return t, f, ferr, b


def save_single_lightcurve(filename, times, fluxes, errors, bands, compressed=True, overwrite=False):
    """Write a single lightcurve data file.

    Parameters
    ----------
    filename : str
        Name of the data file including path.
    times : array-like
        The light curve time data.
    fluxes : array-like
        The light curve flux data.
    errors : array-like
        The light curve error data.
    bands : array-like
        The light curve band data.
    compressed : bool, optional
        Whether to save in compressed format.
    overwrite : bool, optional
        Whether to overwrite existing data.

    Raises
    ------
    FileExistsError
        If ``filename`` already exists and ``overwrite`` is False.
    """
    if os.path.exists(filename) and not overwrite:
        # Include the offending path so the caller can see which file blocked the save.
        raise FileExistsError(f"ERROR: File already exists {filename}")

    # Stack the four columns into one (4, N) array; np.array promotes the mixed
    # numeric/string columns to a common (string) dtype.
    lcs = np.array([times, fluxes, errors, bands])
    if compressed:
        np.savez_compressed(filename, lcs)
    else:
        np.savez(filename, lcs)
2 changes: 0 additions & 2 deletions src/superphot_plus/ztf_transient_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,6 @@ def dynesty_single_file(test_fn, output_dir, skip_if_exists=True, rstate=None):
Return the mean of the MCMC samples or None if the fitting is
skipped or encounters an error.
"""
# try:

os.makedirs(output_dir, exist_ok=True)
prefix = test_fn.split("/")[-1][:-4]
if skip_if_exists and os.path.exists(os.path.join(output_dir, f"{prefix}_eqwt.npz")):
Expand Down
15 changes: 15 additions & 0 deletions tests/superphot_plus/data_generation/test_make_fake_spp_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os

import numpy as np

from superphot_plus.data_generation.make_fake_spp_data import create_clean_models


def test_generate_clean_data():
    """Check that create_clean_models yields the requested number of models,
    each with 14 parameters and a (4, num_times) light curve."""
    num_models, num_times = 10, 50
    params, lcs = create_clean_models(num_models, num_times)

    assert len(params) == num_models
    assert len(lcs) == num_models
    for cube, curve in zip(params, lcs):
        assert curve.shape == (4, num_times)
        assert len(cube) == 14
48 changes: 47 additions & 1 deletion tests/superphot_plus/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from superphot_plus.file_utils import read_single_lightcurve
import numpy as np
import os
import pytest

from superphot_plus.file_utils import read_single_lightcurve, save_single_lightcurve


def test_read_single_lightcurve(single_ztf_lightcurve_compressed):
Expand All @@ -12,6 +16,48 @@ def test_read_single_lightcurve(single_ztf_lightcurve_compressed):
assert len(band) == 19


def test_write_and_read_single_lightcurve(tmp_path):
    """Round-trip a compressed lightcurve through save/read and check overwrite guard."""
    # Fake data. The brightest flux must be first and the first timestamp must
    # be zero, because read_single_lightcurve shifts times to zero at the peak.
    n_points = 10
    times = np.arange(n_points)
    fluxes = np.array([100.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1])
    bands = np.full(n_points, "r")
    errors = np.full(n_points, 0.1)

    filename = os.path.join(tmp_path, "my_test_file.npz")
    save_single_lightcurve(filename, times, fluxes, errors, bands, overwrite=True)

    # Round-trip and compare each numeric column.
    t2, f2, e2, b2 = read_single_lightcurve(filename)
    for written, loaded in ((times, t2), (fluxes, f2), (errors, e2)):
        assert np.allclose(loaded, written)

    # Saving again without overwrite=True must raise.
    with pytest.raises(FileExistsError):
        save_single_lightcurve(filename, times, fluxes, errors, bands, overwrite=False)


def test_write_and_read_uncompressed_lightcurve(tmp_path):
    """Round-trip an uncompressed (compressed=False) lightcurve through save/read."""
    # Fake data. The brightest flux must be first and the first timestamp must
    # be zero, because read_single_lightcurve shifts times to zero at the peak.
    n_points = 10
    times = np.arange(n_points)
    fluxes = np.array([100.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1])
    bands = np.full(n_points, "r")
    errors = np.full(n_points, 0.1)

    filename = os.path.join(tmp_path, "my_test_file.npz")
    save_single_lightcurve(filename, times, fluxes, errors, bands, compressed=False, overwrite=True)

    # Round-trip and compare each numeric column.
    t2, f2, e2, b2 = read_single_lightcurve(filename)
    for written, loaded in ((times, t2), (fluxes, f2), (errors, e2)):
        assert np.allclose(loaded, written)


def test_read_single_lightcurve_with_time_celing(single_ztf_lightcurve_compressed):
"""Test that we can load a single light curve from pickled file,
restricting the time window for events."""
Expand Down

0 comments on commit 18298db

Please sign in to comment.