-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLSST_Source.py
211 lines (150 loc) · 8.65 KB
/
LSST_Source.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
from astropy.table import Table
from astropy import units as u
from astropy.coordinates import SkyCoord
from collections import OrderedDict
from taxonomy import get_classification_labels, get_astrophysical_class, plot_colored_tree
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
class LSST_Source:
    """Photometric time series and host-galaxy features for a single Elasticc (simulated LSST) object.

    Instances are built from one row of the polars data frame produced by fits_to_parquet.py.
    On construction, saturated epochs are removed and a few positional flags are computed.
    """

    # List of time series features actually stored in the instance of the class.
    time_series_features = ['MJD', 'BAND', 'PHOTFLAG', 'FLUXCAL', 'FLUXCALERR']

    # List of other (static) features actually stored in the instance of the class.
    other_features = [
        'RA', 'DEC', 'MWEBV', 'MWEBV_ERR', 'REDSHIFT_HELIO', 'REDSHIFT_HELIO_ERR',
        'VPEC', 'VPEC_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ',
        'HOSTGAL_SPECZ_ERR', 'HOSTGAL_RA', 'HOSTGAL_DEC', 'HOSTGAL_SNSEP', 'HOSTGAL_DDLR',
        'HOSTGAL_LOGMASS', 'HOSTGAL_LOGMASS_ERR', 'HOSTGAL_LOGSFR', 'HOSTGAL_LOGSFR_ERR',
        'HOSTGAL_LOGsSFR', 'HOSTGAL_LOGsSFR_ERR', 'HOSTGAL_COLOR', 'HOSTGAL_COLOR_ERR',
        'HOSTGAL_ELLIPTICITY', 'HOSTGAL_MAG_u', 'HOSTGAL_MAG_g', 'HOSTGAL_MAG_r',
        'HOSTGAL_MAG_i', 'HOSTGAL_MAG_z', 'HOSTGAL_MAG_Y', 'HOSTGAL_MAGERR_u',
        'HOSTGAL_MAGERR_g', 'HOSTGAL_MAGERR_r', 'HOSTGAL_MAGERR_i', 'HOSTGAL_MAGERR_z',
        'HOSTGAL_MAGERR_Y',
    ]

    # Additional features computed (in compute_custom_features) from the stored features above.
    custom_engineered_features = [
        'MW_plane_flag', 'ELAIS_S1_flag', 'XMM-LSS_flag',
        'Extended_Chandra_Deep_Field-South_flag', 'COSMOS_flag',
    ]

    # Mean of each filter's approximate wavelength bounds, divided by 1000 to convert to micrometres.
    pb_wavelengths = {
        'u': (320 + 400) / (2 * 1000),
        'g': (400 + 552) / (2 * 1000),
        'r': (552 + 691) / (2 * 1000),
        'i': (691 + 818) / (2 * 1000),
        'z': (818 + 922) / (2 * 1000),
        'Y': (950 + 1080) / (2 * 1000),
    }

    # Passband -> plot color.
    colors = OrderedDict({
        'u': 'blue',
        'g': 'green',
        'r': 'red',
        'i': 'teal',
        'z': 'orange',
        'Y': 'purple',
    })

    # 6 broadband filters used in LSST.
    LSST_bands = list(colors.keys())

    # Coordinates for LSST's 4 selected deep drilling fields. (Reference: https://www.lsst.org/scientists/survey-design/ddf)
    LSST_DDF = {
        'ELAIS_S1': SkyCoord(l=311.30 * u.deg, b=-72.90 * u.deg, frame='galactic'),
        'XMM-LSS': SkyCoord(l=171.20 * u.deg, b=-58.77 * u.deg, frame='galactic'),
        'Extended_Chandra_Deep_Field-South': SkyCoord(l=224.07 * u.deg, b=-54.47 * u.deg, frame='galactic'),
        'COSMOS': SkyCoord(l=236.83 * u.deg, b=42.09 * u.deg, frame='galactic'),
    }

    # PHOTFLAG bits used below: epochs with the saturation bit are dropped in process_lightcurve,
    # and the detection bit marks detections in plot_flux_curve / get_event_table.
    PHOTFLAG_SATURATION_BIT = 1024
    PHOTFLAG_DETECTION_BIT = 4096

    # Threshold values
    # MW_plane_flag is set to 1 if |b| < b_threshold (degrees), i.e. whether the object lies close to the galactic plane.
    b_threshold = 15
    # FLUXCAL and FLUXCALERR are divided by this constant in get_event_table.
    flux_scaling_const = 1000
    # Radius of the deep drilling field for LSST, in degrees.
    ddf_separation_radius_threshold = 3.5 / 2

    def __init__(self, parquet_row) -> None:
        """Create an LSST_Source object storing photometric and host-galaxy data from the Elasticc simulations.

        Args:
            parquet_row: A single row of the polars data frame generated from the Elasticc FITS
                files by fits_to_parquet.py. It must contain the ``ELASTICC_class`` and ``SNID``
                columns, every column in ``time_series_features``, and ``RA``/``DEC``. Scalar
                columns are read from element 0. Time-series columns are read from the list
                stored in the first row.
        """
        self.ELASTICC_class = parquet_row['ELASTICC_class'].to_numpy()[0]
        self.SNID = parquet_row['SNID'].to_numpy()[0]
        self.astrophysical_class = get_astrophysical_class(self.ELASTICC_class)

        for key in parquet_row.columns:
            if key in self.other_features:
                setattr(self, key, parquet_row[key].to_numpy()[0])
            elif key in self.time_series_features:
                # Each time-series cell holds a whole light curve; unwrap the first (only) row.
                setattr(self, key, parquet_row[key][0].to_numpy())

        # Run processing code on the light curves.
        self.process_lightcurve()
        # Compute additional features.
        self.compute_custom_features()

    def process_lightcurve(self) -> None:
        """Remove saturated epochs from every time-series attribute.

        An epoch is treated as saturated when ``PHOTFLAG_SATURATION_BIT`` is set in its PHOTFLAG.
        Every array named in ``time_series_features`` is filtered with the same mask so that
        they stay aligned.
        """
        # Mask is computed once, before PHOTFLAG itself is filtered in the loop below.
        unsaturated = (self.PHOTFLAG & self.PHOTFLAG_SATURATION_BIT) == 0
        for feature in self.time_series_features:
            setattr(self, feature, getattr(self, feature)[unsaturated])

    def compute_custom_features(self) -> None:
        """Compute the galactic-plane flag and the deep-drilling-field membership flags.

        Sets ``MW_plane_flag`` to 1 when the absolute galactic latitude of (RA, DEC) is strictly
        below ``b_threshold`` degrees, otherwise 0. For each field in ``LSST_DDF``, sets
        ``<field>_flag`` to 1 when the angular separation from the field centre is strictly below
        ``ddf_separation_radius_threshold`` degrees, otherwise 0.
        """
        source_coord = SkyCoord(ra=self.RA * u.deg, dec=self.DEC * u.deg)

        # Check if the object is close to the galactic plane of the Milky Way.
        galactic_latitude = abs(source_coord.galactic.b.degree)
        self.MW_plane_flag = 1 if galactic_latitude < self.b_threshold else 0

        # Check if the object is in one of the 4 LSST DDFs and set flags appropriately.
        for field_name, field_center in self.LSST_DDF.items():
            separation = source_coord.separation(field_center).degree
            in_field = separation < self.ddf_separation_radius_threshold
            setattr(self, f'{field_name}_flag', 1 if in_field else 0)

    def plot_flux_curve(self) -> None:
        """Plot SNANA calibrated flux vs. time for all data in the processed time series.

        Detections (PHOTFLAG detection bit set) are drawn as stars and non-detections as dots.
        Observations are color coded by passband. This is a quick visualization tool, not
        intended for publication-quality plots.
        """
        detected = (self.PHOTFLAG & self.PHOTFLAG_DETECTION_BIT) != 0
        legend_handles = [
            mpatches.Patch(color=color, label=band, linewidth=1)
            for band, color in self.colors.items()
        ]

        # Plot flux time series one point at a time so each gets its own band color and marker.
        for mjd, flux, flux_err, band, is_detection in zip(
            self.MJD, self.FLUXCAL, self.FLUXCALERR, self.BAND, detected
        ):
            plt.errorbar(
                x=mjd,
                y=flux,
                yerr=flux_err,
                color=self.colors[band],
                fmt='*' if is_detection else '.',
                markersize=10,
            )

        # Labels
        plt.title(f"SNID: {self.SNID} | CLASS: {self.ELASTICC_class}")
        plt.xlabel('Time (MJD)')
        plt.ylabel('Calibrated Flux')
        plt.legend(handles=legend_handles)
        plt.show()

    def get_classification_labels(self):
        """Get the classification labels (hierarchical) for this LSST Source object in the Taxonomy tree.

        Returns:
            (tree_nodes, numerical_labels): A tuple containing two list-like objects. The first
            contains the ordering of the nodes. The second contains the labels themselves (0 when
            the object does not belong to the class and 1 when it does). The labels in the second
            object correspond to the nodes in the first object.
        """
        # Resolves to the module-level taxonomy function, not this method (methods are not in
        # scope inside a method body).
        return get_classification_labels(self.astrophysical_class)

    def plot_classification_tree(self):
        """Plot the classification tree (based on our taxonomy) for this LSST Source object."""
        _, labels = self.get_classification_labels()
        plot_colored_tree(labels)

    def get_event_table(self) -> Table:
        """Build an astropy Table of model-ready time-series features for this source.

        Columns (one row per stored observation):
            scaled_time_since_first_obs: (MJD - MJD[0]) / 100.
            detection_flag: 1 where PHOTFLAG has the detection bit set, else 0.
            scaled_FLUXCAL, scaled_FLUXCALERR: FLUXCAL / FLUXCALERR divided by ``flux_scaling_const``.
            band_label: mean wavelength (micrometres) of the observation's passband.

        The table's ``meta`` is an OrderedDict of static features, in the order of
        ``other_features`` followed by ``custom_engineered_features``.

        If the light curve is empty (e.g. every epoch was saturated), an empty table is returned
        instead of raising IndexError.

        Returns:
            Table: the time-series table, with the static features in ``meta``.
        """
        table = Table()

        # Time since the first stored observation. NOTE(review): assumes MJD is in chronological
        # order so that MJD[0] is the earliest epoch — TODO confirm upstream.
        first_mjd = self.MJD[0] if len(self.MJD) > 0 else 0.0
        table['scaled_time_since_first_obs'] = (self.MJD - first_mjd) / 100

        # 1 if it was a detection, zero otherwise.
        table['detection_flag'] = np.where((self.PHOTFLAG & self.PHOTFLAG_DETECTION_BIT) != 0, 1, 0)

        # Rescale flux cal and flux cal err to more manageable values (more consistent order of magnitude).
        table['scaled_FLUXCAL'] = self.FLUXCAL / self.flux_scaling_const
        table['scaled_FLUXCALERR'] = self.FLUXCALERR / self.flux_scaling_const

        # Encode the passband by its mean wavelength (not a one-hot vector).
        table['band_label'] = [self.pb_wavelengths[pb] for pb in self.BAND]

        # Consistency check
        assert len(table) == len(self.MJD), "Length of time series tensor does not match the number of mjd values."

        # Static (per-object) features, stored as table metadata.
        feature_static = OrderedDict()
        for feature in self.other_features + self.custom_engineered_features:
            feature_static[feature] = getattr(self, feature)
        table.meta = feature_static

        return table

    def __str__(self) -> str:
        """Return the string form of this instance's attribute dictionary."""
        return str(vars(self))
if __name__ == "__main__":
    # Nothing to run: this module only provides the LSST_Source class for import.
    pass