From 83f13ad0ba6703a6a64ec4831363df40ba31830d Mon Sep 17 00:00:00 2001 From: whohensee <106775295+whohensee@users.noreply.github.com> Date: Thu, 27 Jun 2024 12:52:13 -0700 Subject: [PATCH] Minor changes from reviewing --- models/cutouts.py | 7 +++---- models/measurements.py | 3 --- pipeline/cutting.py | 3 +-- pipeline/data_store.py | 15 ++++++--------- pipeline/measuring.py | 18 ++++-------------- tests/fixtures/pipeline_objects.py | 3 +-- tests/fixtures/simulated.py | 1 - tests/models/test_cutouts.py | 9 ++++----- tests/models/test_measurements.py | 4 ++-- tests/pipeline/test_measuring.py | 23 +---------------------- 10 files changed, 22 insertions(+), 64 deletions(-) diff --git a/models/cutouts.py b/models/cutouts.py index c692d401..5210918c 100644 --- a/models/cutouts.py +++ b/models/cutouts.py @@ -38,16 +38,15 @@ def __init__(self, *args, **kwargs): def __getitem__(self, key): if key not in self.keys(): - # check if the key exists on disk if self.cutouts.filepath is not None: - self.cutouts.load_one_co_dict(key) + self.cutouts.load_one_co_dict(key) # no change if not found return super().__getitem__(key) class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, HasBitFlagBadness): __tablename__ = 'cutouts' - # a unique constraint on the provenance and the source list, but also on the index in the list + # a unique constraint on the provenance and the source list __table_args__ = ( UniqueConstraint( 'sources_id', 'provenance_id', name='_cutouts_sources_provenance_uc' @@ -190,7 +189,7 @@ def __repr__(self): ) @staticmethod - def get_data_dict_attributes(include_optional=True): # WHPR could rename get_data_attributes + def get_data_dict_attributes(include_optional=True): names = [] for im in ['sub', 'ref', 'new']: for att in ['data', 'weight', 'flags']: diff --git a/models/measurements.py b/models/measurements.py index 5020c48c..6625e3dc 100644 --- a/models/measurements.py +++ b/models/measurements.py @@ -405,9 +405,6 @@ def __setattr__(self, key, value): super().__setattr__(key, value) - # figure out if we need to include optional (probably yes) - # revisit after deciding below question, as I think optional - # are never used ATM def get_data_from_cutouts(self): """Populates this object with the cutout data arrays used in calculations. This allows us to use, for example, self.sub_data diff --git a/pipeline/cutting.py b/pipeline/cutting.py index 90d1e31c..b98affeb 100644 --- a/pipeline/cutting.py +++ b/pipeline/cutting.py @@ -77,7 +77,7 @@ def run(self, *args, **kwargs): if cutouts is not None: cutouts.load_all_co_data() - if cutouts is None or cutouts.co_dict == {}: + if cutouts is None or len(cutouts.co_dict) == 0: self.has_recalculated = True # use the latest source list in the data store, @@ -119,7 +119,6 @@ def run(self, *args, **kwargs): cutouts._upstream_bitflag = 0 cutouts._upstream_bitflag |= detections.bitflag - cutouts.co_dict = {} for i, source in enumerate(detections.data): data_dict = {} data_dict["sub_data"] = sub_stamps_data[i] diff --git a/pipeline/data_store.py b/pipeline/data_store.py index 769cf62c..f5425300 100644 --- a/pipeline/data_store.py +++ b/pipeline/data_store.py @@ -1285,7 +1285,7 @@ def get_cutouts(self, provenance=None, session=None): if self.cutouts is not None: self.cutouts.load_all_co_data() - if self.cutouts.co_dict == {}: + if len(self.cutouts.co_dict) == 0: self.cutouts = None # TODO: what about images that actually don't have any detections? # make sure the cutouts have the correct provenance @@ -1315,9 +1315,6 @@ def get_cutouts(self, provenance=None, session=None): Cutouts.provenance_id == provenance.id, ) ).first() - # cutouts has a unique constraint with sources_id and provenance_id - # so a check I wrote when using all() that there was only 1 - # is totally redundant, i think return self.cutouts @@ -1361,11 +1358,10 @@ def get_measurements(self, provenance=None, session=None): if self.measurements is None: with SmartSession(session, self.session) as session: cutouts = self.get_cutouts(session=session) - cutout_ids = [cutouts.id] # WHPR this is inelegant. It works, but should fix self.measurements = session.scalars( sa.select(Measurements).where( - Measurements.cutouts_id.in_(cutout_ids), + Measurements.cutouts_id == cutouts.id, Measurements.provenance_id == provenance.id, ) ).all() @@ -1486,7 +1482,8 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False, if obj is None: continue - # TODO need to change this as cutouts changes to just be a cutoutsfile + # QUESTION: This whole block can never be reached, as cutouts is not a list and measurements + # don't save to disk. I want to kill it all. Objections? if isinstance(obj, list) and len(obj) > 0: # handle cutouts and measurements if hasattr(obj[0], 'save_list'): raise ValueError("AFTER CUTOUTS IS NO LONGER A LIST, SHOULD NEVER GET HERE!") @@ -1572,14 +1569,14 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False, if self.detections is not None: more_products = 'detections' if self.cutouts is not None: - self.cutouts.sources = self.detections # DOUBLE CHECK - WILL THERE ONLY EVER BE ONE CUTOUTS? + self.cutouts.sources = self.detections self.cutouts = session.merge(self.cutouts) more_products += ', cutouts' if self.measurements is not None: for i, m in enumerate(self.measurements): # use the new, merged cutouts - self.measurements[i].cutouts = self.cutouts # only one now + self.measurements[i].cutouts = self.cutouts self.measurements[i].associate_object(session) self.measurements[i] = session.merge(self.measurements[i]) self.measurements[i].object.measurements.append(self.measurements[i]) diff --git a/pipeline/measuring.py b/pipeline/measuring.py index 53b5bafa..7e66ba8e 100644 --- a/pipeline/measuring.py +++ b/pipeline/measuring.py @@ -5,11 +5,11 @@ from scipy import signal +from astropy.table import Table + from improc.photometry import iterative_cutouts_photometry from improc.tools import make_gaussian -from astropy.table import Table - from models.cutouts import Cutouts from models.measurements import Measurements from models.enums_and_bitflags import BitFlagConverter, BadnessConverter @@ -214,17 +214,8 @@ def run(self, *args, **kwargs): m.index_in_sources = int(key[13:]) # grab just the number from "source_index_xxx" m.best_aperture = cutouts.sources.best_aper_num - # get all the information that used to be populated in cutting - # QUESTION: as far as I can tell, this was never rounded before but somehow caused - # no errors in sqlalchemy, despite being an INT column in the schema?? m.x = cutouts.sources.x[m.index_in_sources] # These will be rounded by Measurements.__setattr__ m.y = cutouts.sources.y[m.index_in_sources] - m.source_row = dict(Table(detections.data)[m.index_in_sources]) # move to measurements probably - for key, value in m.source_row.items(): - if isinstance(value, np.number): - m.source_row[key] = value.item() # convert numpy number to python primitive - # m.ra = m.source_row['ra'] # done in one line below - # m.dec = m.source_row['dec'] m.aper_radii = cutouts.sources.image.new_image.zp.aper_cor_radii # zero point corrected aperture radii @@ -267,9 +258,8 @@ def run(self, *args, **kwargs): x = m.x + m.offset_x y = m.y + m.offset_y ra, dec = m.cutouts.sources.image.new_image.wcs.wcs.pixel_to_world_values(x, y) - # STRONGLY review this in diff to figure out why I did this source row thing - m.ra = float(ra) # + m.source_row['ra'] # I think this was just wrong - m.dec = float(dec) # + m.source_row['dec'] # I think this was just wrong + m.ra = float(ra) + m.dec = float(dec) m.calculate_coordinates() # PSF photometry: diff --git a/tests/fixtures/pipeline_objects.py b/tests/fixtures/pipeline_objects.py index 5af3aad5..684ca28a 100644 --- a/tests/fixtures/pipeline_objects.py +++ b/tests/fixtures/pipeline_objects.py @@ -900,8 +900,7 @@ def make_datastore( cache_name = os.path.join(cache_dir, cache_sub_name + f'.cutouts_{prov.id[:6]}.h5') if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ): SCLogger.debug('loading cutouts from cache. ') - ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name) # this grabs the whole co_list - # even before load()...that ok? + ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name) ds.cutouts.load() setattr(ds.cutouts, 'provenance', prov) setattr(ds.cutouts, 'sources', ds.detections) diff --git a/tests/fixtures/simulated.py b/tests/fixtures/simulated.py index 2408808f..228a1b24 100644 --- a/tests/fixtures/simulated.py +++ b/tests/fixtures/simulated.py @@ -620,7 +620,6 @@ def sim_sub_image_list( with SmartSession() as session: for sub in sub_images: - # breakpoint() sub.delete_from_disk_and_database(session=session, commit=False, remove_downstreams=True) session.commit() diff --git a/tests/models/test_cutouts.py b/tests/models/test_cutouts.py index 78ce28e2..39964eee 100644 --- a/tests/models/test_cutouts.py +++ b/tests/models/test_cutouts.py @@ -24,8 +24,8 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): co_subdict = ds.cutouts.co_dict[subdict_key] assert ds.cutouts.sub_image == decam_detection_list.image - assert ds.cutouts.sub_image == decam_detection_list.image - assert ds.cutouts.sub_image == decam_detection_list.image + assert ds.cutouts.ref_image == decam_detection_list.image.ref_aligned_image + assert ds.cutouts.new_image == decam_detection_list.image.new_aligned_image assert isinstance(co_subdict["sub_data"], np.ndarray) assert isinstance(co_subdict["sub_weight"], np.ndarray) @@ -66,7 +66,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): assert np.array_equal(co_subdict.get(f'{im}_{att}'), co_subdict2.get(f'{im}_{att}')) - assert c2.bitflag == 0 # should not load all columns from file + assert c2.bitflag == 0 # should not load all column data from file # change the value of one of the arrays ds.cutouts.co_dict[subdict_key]['sub_data'][0, 0] = 100 @@ -107,5 +107,4 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): finally: if 'ds' in locals() and ds.cutouts is not None: - ds.cutouts.remove_data_from_disk() - ds.cutouts.delete_from_archive() + ds.cutouts.delete_from_disk_and_database() diff --git a/tests/models/test_measurements.py b/tests/models/test_measurements.py index fbf2b2e6..8eb1db76 100644 --- a/tests/models/test_measurements.py +++ b/tests/models/test_measurements.py @@ -23,7 +23,7 @@ def test_measurements_attributes(measurer, ptf_datastore, test_config): aper_radii = test_config.value('extraction.sources.apertures') ds = measurer.run(ptf_datastore.cutouts) # check that the measurer actually loaded the measurements from db, and not recalculated - # assert len(ds.measurements) <= len(ds.cutouts) # not all cutouts have saved measurements + assert len(ds.measurements) <= len(ds.cutouts.co_dict) # not all cutouts have saved measurements assert len(ds.measurements) == len(ptf_datastore.measurements) assert ds.measurements[0].from_db assert not measurer.has_recalculated @@ -293,7 +293,7 @@ def test_deletion_thresh_is_non_critical(ptf_datastore, measurer): def test_measurements_forced_photometry(ptf_datastore): - offset_max = 2 + offset_max = 2.0 for m in ptf_datastore.measurements: if abs(m.offset_x) < offset_max and abs(m.offset_y) < offset_max: break diff --git a/tests/pipeline/test_measuring.py b/tests/pipeline/test_measuring.py index 8e22a9b5..0b1a6487 100644 --- a/tests/pipeline/test_measuring.py +++ b/tests/pipeline/test_measuring.py @@ -11,7 +11,7 @@ @pytest.mark.flaky(max_runs=3) -def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators): +def test_measuring(measurer, decam_cutouts, decam_default_calibrators): measurer.pars.test_parameter = uuid.uuid4().hex measurer.pars.bad_pixel_exclude = ['saturated'] # ignore saturated pixels measurer.pars.bad_flag_exclude = ['satellite'] # ignore satellite cutouts @@ -80,13 +80,6 @@ def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators): decam_cutouts.co_dict[f"source_index_11"]["sub_data"] *= 1000 decam_cutouts.co_dict[f"source_index_11"]["sub_data"] += np.random.normal(0, 1, size=sz) - # PROBLEM: individual cutouts do not track badness now that they are in this list - # # a regular cutout but we'll put some bad flag on the cutout - # decam_cutouts[12].badness = 'cosmic ray' - - # # a regular cutout with a bad flag that we are ignoring: - # decam_cutouts[13].badness = 'satellite' - # run the measurer ds = measurer.run(decam_cutouts) @@ -225,20 +218,6 @@ def test_measuring_xyz(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_std < 3.0 -def test_propagate_badness(decam_datastore): - ds = decam_datastore - with SmartSession() as session: - ds.measurements[0].badness = 'cosmic ray' - # find the index of the cutout that corresponds to the measurement - # idx = [i for i, c in enumerate(ds.cutouts) if c.id == ds.measurements[0].cutouts_id][0] - # idx = ds.measurements[0].index_in_sources - # ds.cutouts.co_dict[f"source_index_{idx}"].badness = 'cosmic ray' - # ds.cutouts[idx].update_downstream_badness(session=session) - m = session.merge(ds.measurements[0]) - - assert m.badness == 'cosmic ray' # note that this does not change disqualifier_scores! - - def test_warnings_and_exceptions(decam_datastore, measurer): measurer.pars.inject_warnings = 1