Skip to content

Commit

Permalink
Merge pull request #161 from VTDA-Group/additional_small_changes
Browse files Browse the repository at this point in the history
hotfix to allow old LCs to be read
  • Loading branch information
kdesoto-astro authored Sep 25, 2023
2 parents e3e188d + cc2d102 commit 6dc9f47
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
13 changes: 6 additions & 7 deletions src/superphot_plus/data_generation/make_fake_spp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,20 @@ def create_ztf_model(plot=False):
tdata = np.random.uniform(-100, 100, num_observations)
filter_data = np.random.choice(["g", "r"], size=num_observations)

cube = create_prior(cube)
A, beta, gamma, t0, tau_rise, tau_fall, es = cube[:7] # pylint: disable=unused-variable

found_valid = False
num_tried = 0

# Now re-attempts to regenerate until it gets a "good" model
while not found_valid and num_tried < 100:
cube = create_prior(cube)
A, beta, gamma, t0, tau_rise, tau_fall, es = cube[:7] # pylint: disable=unused-variable
params = create_prior(np.copy(cube))
A, beta, gamma, t0, tau_rise, tau_fall, es = params[:7] # pylint: disable=unused-variable
found_valid = params_valid(beta, gamma, tau_rise, tau_fall)
num_tried += 1

if not found_valid:
return "Failure"

f_model = flux_model(cube, tdata, filter_data, ["g", "r"], "r")
f_model = flux_model(params, tdata, filter_data, ["r", "g"], "r")
snr = ztf_noise_model(f_model, filter_data)

gind = np.where(snr > 3) # any points with SNR < 3 are ignored
Expand All @@ -238,7 +235,9 @@ def create_ztf_model(plot=False):
plt.xlabel("Time (days)")
plt.ylabel("Flux (arbitrary units)")
plt.show()
return cube[:7], tdata, filter_data, dirty_model, sigmas
return params[:7], tdata, filter_data, dirty_model, sigmas


# Can run this with create_model(plot=True)
if __name__ == "__main__":
create_ztf_model(plot=True)
7 changes: 6 additions & 1 deletion src/superphot_plus/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def from_file(cls, filename, ref_band="r", t0_lim=None, shift_time=True):
arr = None
property_dict = {}
for k in npy_array.files:
if k == "lcs":
if k == "lcs" or k == "arr_0": #hotfix to handle old LC format
arr = npy_array[k]
else:
property_dict[k] = npy_array[k]
Expand All @@ -340,6 +340,11 @@ def from_file(cls, filename, ref_band="r", t0_lim=None, shift_time=True):
fdata = arr[1][good_rows].astype(float)
edata = arr[2][good_rows].astype(float)
bdata = arr[3][good_rows]

if 'name' not in property_dict:
file_prefix = filename.split("/")[-1].split(".")[0]
property_dict['name'] = file_prefix

lc = Lightcurve(
tdata, fdata, edata, bdata,
**property_dict
Expand Down
5 changes: 5 additions & 0 deletions src/superphot_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def params_valid(beta, gamma, tau_rise, tau_fall):
bool
True if parameters are valid, False otherwise.
"""
if np.any(np.isnan(
[beta, gamma, tau_rise, tau_fall]
)):
return False

if tau_fall > 1.0 / beta:
return False

Expand Down

0 comments on commit 6dc9f47

Please sign in to comment.