diff --git a/src/superphot_plus/__init__.py b/src/superphot_plus/__init__.py index e69de29b..55768ed8 100644 --- a/src/superphot_plus/__init__.py +++ b/src/superphot_plus/__init__.py @@ -0,0 +1 @@ +from . import data_generation diff --git a/src/superphot_plus/data_generation/__init__.py b/src/superphot_plus/data_generation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data/generation/make_fake_spp_data.py b/src/superphot_plus/data_generation/make_fake_spp_data.py similarity index 84% rename from data/generation/make_fake_spp_data.py rename to src/superphot_plus/data_generation/make_fake_spp_data.py index c88b1254..651dd9a6 100644 --- a/data/generation/make_fake_spp_data.py +++ b/src/superphot_plus/data_generation/make_fake_spp_data.py @@ -117,31 +117,41 @@ def create_prior(cube, tdata): return cube -def create_clean_models(nmodels): +def create_clean_models(nmodels, num_times=100): + """Generate 'clean' (noiseless) models from the prior + + Parameters + ---------- + nmodels : int + The number of models you want to generate + num_times : int, optional + The number of timesteps to use. 
+ + Returns + ------- + params : array-like of numpy arrays + The array of parameters used to generate each model + lcs : array-like of numpy arrays + The array of individual light curves for each model generated """ - Generate "clean" (noiseless) models from the prior + params = [] + lcs = [] - inputs: - nmodels : number of models you want to generate + tdata = np.linspace(-100, 100, num_times) + bdata = np.asarray(["g"] * num_times, dtype=str) + edata = np.asarray([1e-6] * num_times, dtype=float) - outputs: - params : set of parameters used to generate each model - lcs : light curves for each model generated - """ - for i in np.arange(nmodels): + while len(lcs) < nmodels: cube = np.random.uniform(0, 1, 14) - tdata = np.linspace(-100, 100, 100) cube = create_prior(cube, tdata) A, beta, gamma, t0, tau_rise, tau_fall, es = cube[:7] # pylint: disable=unused-variable + # Try again if we picked invalid priors. if not params_valid(A, beta, gamma, t0, tau_rise, tau_fall): continue - print(cube) - bdata = np.asarray(["g"] * 100, dtype=str) - f_model = flux_model(cube, tdata, bdata) - plt.plot(tdata, f_model, ".") + params.append(cube) + f_model = flux_model(cube, tdata, bdata) + lcs.append(np.array([tdata, f_model, edata, bdata])) -# ASHLEY is still working on this... -create_clean_models(100) -plt.show() + return params, lcs diff --git a/src/superphot_plus/file_utils.py b/src/superphot_plus/file_utils.py index 5b5ac019..7f0b6693 100644 --- a/src/superphot_plus/file_utils.py +++ b/src/superphot_plus/file_utils.py @@ -1,5 +1,7 @@ """Methods for reading and writing input, intermediate, and output files.""" +import os + import numpy as np @@ -41,3 +43,34 @@ def read_single_lightcurve(filename, time_ceiling=None): t -= max_flux_loc # make relative return t, f, ferr, b + + +def save_single_lightcurve(filename, times, fluxes, errors, bands, compressed=True, overwrite=False): + """ + Write a single lightcurve data file. 
+ + Parameters + ---------- + filename : str + Name of the data file including path. + times : array-like + The light curve time data. + fluxes : array-like + The light curve flux data. + errors : array-like + The light curve error data. + bands : array-like + The light curve band data. + compressed : bool, optional + Whether to save in compressed format. + overwrite : bool, optional + Whether to overwrite existing data. + """ + if os.path.exists(filename) and not overwrite: + raise FileExistsError(f"ERROR: File already exists {filename}") + + lcs = np.array([times, fluxes, errors, bands]) + if compressed: + np.savez_compressed(filename, lcs) + else: + np.savez(filename, lcs) diff --git a/src/superphot_plus/ztf_transient_fit.py b/src/superphot_plus/ztf_transient_fit.py index c3417e11..83d02e6b 100644 --- a/src/superphot_plus/ztf_transient_fit.py +++ b/src/superphot_plus/ztf_transient_fit.py @@ -396,8 +396,6 @@ def dynesty_single_file(test_fn, output_dir, skip_if_exists=True, rstate=None): Return the mean of the MCMC samples or None if the fitting is skipped or encounters an error. """ - # try: - os.makedirs(output_dir, exist_ok=True) prefix = test_fn.split("/")[-1][:-4] if skip_if_exists and os.path.exists(os.path.join(output_dir, f"{prefix}_eqwt.npz")): diff --git a/tests/superphot_plus/data_generation/test_make_fake_spp_data.py b/tests/superphot_plus/data_generation/test_make_fake_spp_data.py new file mode 100644 index 00000000..5b94be02 --- /dev/null +++ b/tests/superphot_plus/data_generation/test_make_fake_spp_data.py @@ -0,0 +1,15 @@ +import os + +import numpy as np + +from superphot_plus.data_generation.make_fake_spp_data import create_clean_models + + +def test_generate_clean_data(): + # Generate 10 light curves with 50 time steps each.
+ params, lcs = create_clean_models(10, 50) + assert len(params) == 10 + assert len(lcs) == 10 + for i in range(10): + assert lcs[i].shape == (4, 50) + assert len(params[i]) == 14 diff --git a/tests/superphot_plus/test_file_utils.py b/tests/superphot_plus/test_file_utils.py index f54c122d..9157e40b 100644 --- a/tests/superphot_plus/test_file_utils.py +++ b/tests/superphot_plus/test_file_utils.py @@ -1,4 +1,8 @@ -from superphot_plus.file_utils import read_single_lightcurve +import numpy as np +import os +import pytest + +from superphot_plus.file_utils import read_single_lightcurve, save_single_lightcurve def test_read_single_lightcurve(single_ztf_lightcurve_compressed): @@ -12,6 +16,48 @@ def test_read_single_lightcurve(single_ztf_lightcurve_compressed): assert len(band) == 19 +def test_write_and_read_single_lightcurve(tmp_path): + # Create fake data. Note that the first point in the fluxes must be the brightest + # and the first time stamp must be zero, because of how read_single_lightcurve + # shifts the times to be zero at the peak. + times = np.array(range(10)) + fluxes = np.array([100.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1]) + bands = np.array(["r"] * 10) + errors = np.array([0.1] * 10) + + filename = os.path.join(tmp_path, "my_test_file.npz") + save_single_lightcurve(filename, times, fluxes, errors, bands, overwrite=True) + + # Re-read and check data. + t2, f2, e2, b2 = read_single_lightcurve(filename) + assert np.allclose(t2, times) + assert np.allclose(f2, fluxes) + assert np.allclose(e2, errors) + + # If we try to save again (without overwrite=True) we should get an error. + with pytest.raises(FileExistsError): + save_single_lightcurve(filename, times, fluxes, errors, bands, overwrite=False) + + +def test_write_and_read_uncompressed_lightcurve(tmp_path): + # Create fake data. 
Note that the first point in the fluxes must be the brightest + # and the first time stamp must be zero, because of how read_single_lightcurve + # shifts the times to be zero at the peak. + times = np.array(range(10)) + fluxes = np.array([100.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1]) + bands = np.array(["r"] * 10) + errors = np.array([0.1] * 10) + + filename = os.path.join(tmp_path, "my_test_file.npz") + save_single_lightcurve(filename, times, fluxes, errors, bands, compressed=False, overwrite=True) + + # Re-read and check data. + t2, f2, e2, b2 = read_single_lightcurve(filename) + assert np.allclose(t2, times) + assert np.allclose(f2, fluxes) + assert np.allclose(e2, errors) + + def test_read_single_lightcurve_with_time_celing(single_ztf_lightcurve_compressed): """Test that we can load a single light curve from pickled file, restricting the time window for events."""