Minor changes from reviewing
whohensee committed Jun 27, 2024
1 parent 66857ca commit 83f13ad
Showing 10 changed files with 22 additions and 64 deletions.
7 changes: 3 additions & 4 deletions models/cutouts.py
@@ -38,16 +38,15 @@ def __init__(self, *args, **kwargs):
 
     def __getitem__(self, key):
         if key not in self.keys():
             # check if the key exists on disk
-            if self.cutouts.filepath is not None:
-                self.cutouts.load_one_co_dict(key)
+            self.cutouts.load_one_co_dict(key)  # no change if not found
         return super().__getitem__(key)
 
 class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, HasBitFlagBadness):
 
     __tablename__ = 'cutouts'
 
-    # a unique constraint on the provenance and the source list, but also on the index in the list
+    # a unique constraint on the provenance and the source list
     __table_args__ = (
         UniqueConstraint(
             'sources_id', 'provenance_id', name='_cutouts_sources_provenance_uc'
@@ -190,7 +189,7 @@ def __repr__(self):
         )
 
     @staticmethod
-    def get_data_dict_attributes(include_optional=True):  # WHPR could rename get_data_attributes
+    def get_data_dict_attributes(include_optional=True):
         names = []
         for im in ['sub', 'ref', 'new']:
             for att in ['data', 'weight', 'flags']:
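The `__getitem__` change works because `load_one_co_dict` can be called unconditionally: it loads the entry from the backing file if it exists and silently does nothing otherwise, so the caller no longer needs the `filepath` guard. A minimal sketch of that load-on-access pattern, with a hypothetical `load_from_store` standing in for the real file loader (names here are illustrative, not the project's API):

```python
_store = {"source_index_0": "stamp data"}  # stand-in for the HDF5 file contents


def load_from_store(d, key):
    # Hypothetical stand-in for Cutouts.load_one_co_dict: populate d[key] if
    # the backing store has it; make no change otherwise.
    if key in _store:
        d[key] = _store[key]


class LazyDict(dict):
    def __getitem__(self, key):
        if key not in self.keys():
            load_from_store(self, key)  # no change if not found
        return super().__getitem__(key)  # still raises KeyError on a true miss


d = LazyDict()
assert d["source_index_0"] == "stamp data"  # pulled in on first access
```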
3 changes: 0 additions & 3 deletions models/measurements.py
@@ -405,9 +405,6 @@ def __setattr__(self, key, value):
 
         super().__setattr__(key, value)
 
-    # figure out if we need to include optional (probably yes)
-    # revisit after deciding below question, as I think optional
-    # are never used ATM
     def get_data_from_cutouts(self):
         """Populates this object with the cutout data arrays used in
         calculations. This allows us to use, for example, self.sub_data
3 changes: 1 addition & 2 deletions pipeline/cutting.py
@@ -77,7 +77,7 @@ def run(self, *args, **kwargs):
         if cutouts is not None:
             cutouts.load_all_co_data()
 
-        if cutouts is None or cutouts.co_dict == {}:
+        if cutouts is None or len(cutouts.co_dict) == 0:
 
             self.has_recalculated = True
             # use the latest source list in the data store,
@@ -119,7 +119,6 @@ def run(self, *args, **kwargs):
             cutouts._upstream_bitflag = 0
             cutouts._upstream_bitflag |= detections.bitflag
 
-            cutouts.co_dict = {}
             for i, source in enumerate(detections.data):
                 data_dict = {}
                 data_dict["sub_data"] = sub_stamps_data[i]
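Both occurrences of `co_dict == {}` in this commit become `len(co_dict) == 0`. For a plain dict the two spellings (and plain truthiness) agree, as this quick reference check shows; `len()` has the advantage of behaving the same for any dict-like mapping:

```python
co_dict = {}

# Equivalent emptiness checks for a plain dict; the pipeline now uses len().
assert co_dict == {}
assert len(co_dict) == 0
assert not co_dict  # the usual Python idiom, relying on dict truthiness
```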
15 changes: 6 additions & 9 deletions pipeline/data_store.py
@@ -1285,7 +1285,7 @@ def get_cutouts(self, provenance=None, session=None):
 
         if self.cutouts is not None:
             self.cutouts.load_all_co_data()
-            if self.cutouts.co_dict == {}:
+            if len(self.cutouts.co_dict) == 0:
                 self.cutouts = None  # TODO: what about images that actually don't have any detections?
 
         # make sure the cutouts have the correct provenance
@@ -1315,9 +1315,6 @@ def get_cutouts(self, provenance=None, session=None):
                         Cutouts.provenance_id == provenance.id,
                     )
                 ).first()
-                # cutouts has a unique constraint with sources_id and provenance_id
-                # so a check I wrote when using all() that there was only 1
-                # is totally redundant, i think
 
         return self.cutouts

@@ -1361,11 +1358,10 @@ def get_measurements(self, provenance=None, session=None):
         if self.measurements is None:
             with SmartSession(session, self.session) as session:
                 cutouts = self.get_cutouts(session=session)
-                cutout_ids = [cutouts.id]  # WHPR this is inelegant. It works, but should fix
 
                 self.measurements = session.scalars(
                     sa.select(Measurements).where(
-                        Measurements.cutouts_id.in_(cutout_ids),
+                        Measurements.cutouts_id == cutouts.id,
                         Measurements.provenance_id == provenance.id,
                     )
                 ).all()
@@ -1486,7 +1482,8 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False,
             if obj is None:
                 continue
 
-            # TODO need to change this as cutouts changes to just be a cutoutsfile
+            # QUESTION: This whole block can never be reached, as cutouts is not a list and measurements
+            # don't save to disk. I want to kill it all. Objections?
             if isinstance(obj, list) and len(obj) > 0:  # handle cutouts and measurements
                 if hasattr(obj[0], 'save_list'):
                     raise ValueError("AFTER CUTOUTS IS NO LONGER A LIST, SHOULD NEVER GET HERE!")
@@ -1572,14 +1569,14 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False,
             if self.detections is not None:
                 more_products = 'detections'
                 if self.cutouts is not None:
-                    self.cutouts.sources = self.detections  # DOUBLE CHECK - WILL THERE ONLY EVER BE ONE CUTOUTS?
+                    self.cutouts.sources = self.detections
                     self.cutouts = session.merge(self.cutouts)
                     more_products += ', cutouts'
 
                 if self.measurements is not None:
                     for i, m in enumerate(self.measurements):
                         # use the new, merged cutouts
-                        self.measurements[i].cutouts = self.cutouts  # only one now
+                        self.measurements[i].cutouts = self.cutouts
                         self.measurements[i].associate_object(session)
                         self.measurements[i] = session.merge(self.measurements[i])
                         self.measurements[i].object.measurements.append(self.measurements[i])
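The deleted comment in `get_cutouts` and the `in_` → `==` change in `get_measurements` both follow from the same fact: the `(sources_id, provenance_id)` unique constraint means a query on that pair can match at most one `Cutouts` row, so `.first()` is exact and a one-element `in_()` list is just a roundabout equality test. A self-contained sketch with a simplified stand-in model (column names mirror the diff, but the schema here is illustrative, not the project's actual one):

```python
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, Session

Base = declarative_base()


class Cutouts(Base):  # simplified stand-in for the real model
    __tablename__ = 'cutouts'
    id = sa.Column(sa.Integer, primary_key=True)
    sources_id = sa.Column(sa.Integer, nullable=False)
    provenance_id = sa.Column(sa.String, nullable=False)
    __table_args__ = (
        sa.UniqueConstraint('sources_id', 'provenance_id',
                            name='_cutouts_sources_provenance_uc'),
    )


engine = sa.create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Cutouts(sources_id=1, provenance_id='abc123'))
    session.commit()

    # The unique constraint guarantees zero or one match, so .first() is exact;
    # checking that .all() returned a single row would be redundant.
    cutouts = session.scalars(
        sa.select(Cutouts).where(
            Cutouts.sources_id == 1,
            Cutouts.provenance_id == 'abc123',
        )
    ).first()
    assert cutouts is not None

    # Likewise, filtering with Measurements.cutouts_id.in_([cutouts.id]) was a
    # one-element membership test; plain equality says the same thing directly:
    # Measurements.cutouts_id == cutouts.id
```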
18 changes: 4 additions & 14 deletions pipeline/measuring.py
@@ -5,11 +5,11 @@
 
 from scipy import signal
 
-from astropy.table import Table
-
 from improc.photometry import iterative_cutouts_photometry
 from improc.tools import make_gaussian
 
+from astropy.table import Table
+
 from models.cutouts import Cutouts
 from models.measurements import Measurements
 from models.enums_and_bitflags import BitFlagConverter, BadnessConverter
@@ -214,17 +214,8 @@ def run(self, *args, **kwargs):
                 m.index_in_sources = int(key[13:])  # grab just the number from "source_index_xxx"
                 m.best_aperture = cutouts.sources.best_aper_num
 
-                # get all the information that used to be populated in cutting
-                # QUESTION: as far as I can tell, this was never rounded before but somehow caused
-                # no errors in sqlalchemy, despite being an INT column in the schema??
                 m.x = cutouts.sources.x[m.index_in_sources]  # These will be rounded by Measurements.__setattr__
                 m.y = cutouts.sources.y[m.index_in_sources]
-                m.source_row = dict(Table(detections.data)[m.index_in_sources])  # move to measurements probably
-                for key, value in m.source_row.items():
-                    if isinstance(value, np.number):
-                        m.source_row[key] = value.item()  # convert numpy number to python primitive
-                # m.ra = m.source_row['ra']  # done in one line below
-                # m.dec = m.source_row['dec']
 
                 m.aper_radii = cutouts.sources.image.new_image.zp.aper_cor_radii  # zero point corrected aperture radii

@@ -267,9 +258,8 @@ def run(self, *args, **kwargs):
                     x = m.x + m.offset_x
                     y = m.y + m.offset_y
                     ra, dec = m.cutouts.sources.image.new_image.wcs.wcs.pixel_to_world_values(x, y)
-                    # STRONGLY review this in diff to figure out why I did this source row thing
-                    m.ra = float(ra)  # + m.source_row['ra']  # I think this was just wrong
-                    m.dec = float(dec)  # + m.source_row['dec']  # I think this was just wrong
+                    m.ra = float(ra)
+                    m.dec = float(dec)
                     m.calculate_coordinates()
 
                     # PSF photometry:
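The deleted `source_row` loop existed to coerce numpy scalars into Python primitives before they were stored on the ORM object, and the surviving `float(ra)` / `float(dec)` casts do the same job for the WCS output (which is typically a numpy scalar). A small illustration of the type distinction that motivates the cast:

```python
import numpy as np

# A value pulled from an astropy Table row arrives as a numpy scalar.
x = np.int64(7)
assert not isinstance(x, int)    # np.int64 is not a Python int on Python 3
assert type(x.item()) is int     # .item() yields the builtin primitive

# pixel_to_world_values typically hands back numpy float64 values.
ra = np.float64(210.123456)
assert isinstance(ra, float)     # float64 does subclass Python float...
assert type(float(ra)) is float  # ...but float() pins down the exact builtin type
```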
3 changes: 1 addition & 2 deletions tests/fixtures/pipeline_objects.py
@@ -900,8 +900,7 @@ def make_datastore(
            cache_name = os.path.join(cache_dir, cache_sub_name + f'.cutouts_{prov.id[:6]}.h5')
            if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ):
                SCLogger.debug('loading cutouts from cache. ')
-               ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name)  # this grabs the whole co_list
-               # even before load()...that ok?
+               ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name)
                ds.cutouts.load()
                setattr(ds.cutouts, 'provenance', prov)
                setattr(ds.cutouts, 'sources', ds.detections)
1 change: 0 additions & 1 deletion tests/fixtures/simulated.py
@@ -620,7 +620,6 @@ def sim_sub_image_list(
 
     with SmartSession() as session:
         for sub in sub_images:
-            # breakpoint()
             sub.delete_from_disk_and_database(session=session, commit=False, remove_downstreams=True)
         session.commit()

9 changes: 4 additions & 5 deletions tests/models/test_cutouts.py
@@ -24,8 +24,8 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
         co_subdict = ds.cutouts.co_dict[subdict_key]
 
         assert ds.cutouts.sub_image == decam_detection_list.image
-        assert ds.cutouts.sub_image == decam_detection_list.image
-        assert ds.cutouts.sub_image == decam_detection_list.image
+        assert ds.cutouts.ref_image == decam_detection_list.image.ref_aligned_image
+        assert ds.cutouts.new_image == decam_detection_list.image.new_aligned_image
 
         assert isinstance(co_subdict["sub_data"], np.ndarray)
         assert isinstance(co_subdict["sub_weight"], np.ndarray)
@@ -66,7 +66,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
                     assert np.array_equal(co_subdict.get(f'{im}_{att}'),
                                           co_subdict2.get(f'{im}_{att}'))
 
-        assert c2.bitflag == 0  # should not load all columns from file
+        assert c2.bitflag == 0  # should not load all column data from file
 
         # change the value of one of the arrays
         ds.cutouts.co_dict[subdict_key]['sub_data'][0, 0] = 100
@@ -107,5 +107,4 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
 
     finally:
         if 'ds' in locals() and ds.cutouts is not None:
-            ds.cutouts.remove_data_from_disk()
-            ds.cutouts.delete_from_archive()
+            ds.cutouts.delete_from_disk_and_database()
4 changes: 2 additions & 2 deletions tests/models/test_measurements.py
@@ -23,7 +23,7 @@ def test_measurements_attributes(measurer, ptf_datastore, test_config):
     aper_radii = test_config.value('extraction.sources.apertures')
     ds = measurer.run(ptf_datastore.cutouts)
     # check that the measurer actually loaded the measurements from db, and not recalculated
-    # assert len(ds.measurements) <= len(ds.cutouts)  # not all cutouts have saved measurements
+    assert len(ds.measurements) <= len(ds.cutouts.co_dict)  # not all cutouts have saved measurements
     assert len(ds.measurements) == len(ptf_datastore.measurements)
     assert ds.measurements[0].from_db
     assert not measurer.has_recalculated
@@ -293,7 +293,7 @@ def test_deletion_thresh_is_non_critical(ptf_datastore, measurer):
 
 
 def test_measurements_forced_photometry(ptf_datastore):
-    offset_max = 2
+    offset_max = 2.0
     for m in ptf_datastore.measurements:
         if abs(m.offset_x) < offset_max and abs(m.offset_y) < offset_max:
             break
23 changes: 1 addition & 22 deletions tests/pipeline/test_measuring.py
@@ -11,7 +11,7 @@
 
 
 @pytest.mark.flaky(max_runs=3)
-def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators):
+def test_measuring(measurer, decam_cutouts, decam_default_calibrators):
     measurer.pars.test_parameter = uuid.uuid4().hex
     measurer.pars.bad_pixel_exclude = ['saturated']  # ignore saturated pixels
     measurer.pars.bad_flag_exclude = ['satellite']  # ignore satellite cutouts
@@ -80,13 +80,6 @@ def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators):
     decam_cutouts.co_dict[f"source_index_11"]["sub_data"] *= 1000
     decam_cutouts.co_dict[f"source_index_11"]["sub_data"] += np.random.normal(0, 1, size=sz)
 
-    # PROBLEM: individual cutouts do not track badness now that they are in this list
-    # # a regular cutout but we'll put some bad flag on the cutout
-    # decam_cutouts[12].badness = 'cosmic ray'
-
-    # # a regular cutout with a bad flag that we are ignoring:
-    # decam_cutouts[13].badness = 'satellite'
-
     # run the measurer
    ds = measurer.run(decam_cutouts)

@@ -225,20 +218,6 @@ def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators):
         assert m.bkg_std < 3.0
 
 
-def test_propagate_badness(decam_datastore):
-    ds = decam_datastore
-    with SmartSession() as session:
-        ds.measurements[0].badness = 'cosmic ray'
-        # find the index of the cutout that corresponds to the measurement
-        # idx = [i for i, c in enumerate(ds.cutouts) if c.id == ds.measurements[0].cutouts_id][0]
-        # idx = ds.measurements[0].index_in_sources
-        # ds.cutouts.co_dict[f"source_index_{idx}"].badness = 'cosmic ray'
-        # ds.cutouts[idx].update_downstream_badness(session=session)
-        m = session.merge(ds.measurements[0])
-
-        assert m.badness == 'cosmic ray'  # note that this does not change disqualifier_scores!


def test_warnings_and_exceptions(decam_datastore, measurer):
measurer.pars.inject_warnings = 1

