From 0634f269c400dcb8f6e63bae3c9ddd440b65459d Mon Sep 17 00:00:00 2001 From: Rob Knop Date: Wed, 26 Jun 2024 06:53:43 -0700 Subject: [PATCH 1/3] Don't hold a session open throughout all of subtraction (#321) --- .github/workflows/run-improc-tests.yml | 6 ++ .github/workflows/run-model-tests-1.yml | 5 ++ .github/workflows/run-model-tests-2.yml | 5 ++ .github/workflows/run-pipeline-tests-1.yml | 5 ++ .github/workflows/run-pipeline-tests-2.yml | 5 ++ .github/workflows/run-util-tests.yml | 6 ++ pipeline/subtraction.py | 76 ++++++++++--------- tests/fixtures/pipeline_objects.py | 14 +++- tests/models/test_image.py | 3 +- tests/models/test_psf.py | 3 +- tests/models/test_source_list.py | 3 +- tests/webap_secrets/seechange_webap_config.py | 1 + 12 files changed, 90 insertions(+), 42 deletions(-) diff --git a/.github/workflows/run-improc-tests.yml b/.github/workflows/run-improc-tests.yml index 66ab002c..ba39f935 100644 --- a/.github/workflows/run-improc-tests.yml +++ b/.github/workflows/run-improc-tests.yml @@ -59,4 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=tests/improc docker compose run runtests diff --git a/.github/workflows/run-model-tests-1.yml b/.github/workflows/run-model-tests-1.yml index a7487536..fb610eee 100644 --- a/.github/workflows/run-model-tests-1.yml +++ b/.github/workflows/run-model-tests-1.yml @@ -59,5 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" shopt -s nullglob TEST_SUBFOLDER=$(ls tests/models/test_{a..l}*.py) docker compose run runtests diff --git a/.github/workflows/run-model-tests-2.yml b/.github/workflows/run-model-tests-2.yml index c2d0eace..3158b7ba 100644 --- a/.github/workflows/run-model-tests-2.yml +++ b/.github/workflows/run-model-tests-2.yml @@ -59,5 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" shopt -s nullglob TEST_SUBFOLDER=$(ls tests/models/test_{m..z}*.py) docker compose run runtests diff --git a/.github/workflows/run-pipeline-tests-1.yml b/.github/workflows/run-pipeline-tests-1.yml index 38132c7e..702fc61e 100644 --- a/.github/workflows/run-pipeline-tests-1.yml +++ b/.github/workflows/run-pipeline-tests-1.yml @@ -59,5 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" shopt -s nullglob TEST_SUBFOLDER=$(ls tests/pipeline/test_{a..o}*.py) docker compose run runtests diff --git a/.github/workflows/run-pipeline-tests-2.yml b/.github/workflows/run-pipeline-tests-2.yml index a94c2422..461739d8 100644 --- a/.github/workflows/run-pipeline-tests-2.yml +++ b/.github/workflows/run-pipeline-tests-2.yml @@ -59,5 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" shopt -s nullglob TEST_SUBFOLDER=$(ls tests/pipeline/test_{p..z}*.py) docker compose run runtests diff --git a/.github/workflows/run-util-tests.yml b/.github/workflows/run-util-tests.yml index 7c626aeb..591bc250 100644 --- a/.github/workflows/run-util-tests.yml +++ b/.github/workflows/run-util-tests.yml @@ -59,4 +59,10 @@ jobs: - name: run test run: | + # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=tests/util docker compose run runtests diff --git a/pipeline/subtraction.py b/pipeline/subtraction.py index f66318bc..dd854486 100644 --- a/pipeline/subtraction.py +++ b/pipeline/subtraction.py @@ -268,43 +268,49 @@ def run(self, *args, **kwargs): sub_image.provenance_id = prov.id sub_image.coordinates_to_alignment_target() # make sure the WCS is aligned to the correct image - # make sure to grab the correct aligned images - new_image = [im for im in sub_image.aligned_images if im.mjd == sub_image.new_image.mjd] - if len(new_image) != 1: - raise ValueError('Cannot find the new image in the aligned images') - new_image = new_image[0] - - ref_image = [im for im in sub_image.aligned_images if im.mjd == sub_image.ref_image.mjd] - if len(ref_image) != 1: - raise ValueError('Cannot find the reference image in the aligned images') - ref_image = ref_image[0] - - if self.pars.method == 'naive': - outdict = self._subtract_naive(new_image, ref_image) - elif self.pars.method == 'hotpants': - outdict = self._subtract_hotpants(new_image, ref_image) - elif self.pars.method == 'zogy': - outdict = self._subtract_zogy(new_image, ref_image) - else: - raise ValueError(f'Unknown subtraction method {self.pars.method}') - - sub_image.data = outdict['outim'] - sub_image.weight = outdict['outwt'] - sub_image.flags = outdict['outfl'] - if 'score' in outdict: - sub_image.score = outdict['score'] - if 'alpha' in outdict: + # Need to make sure the upstream images are loaded into this session before + # we disconnect it from the database. (We don't want to hold the database + # connection open through all the slow processes below.) + upstream_images = sub_image.upstream_images + + if self.has_recalculated: + # make sure to grab the correct aligned images + new_image = [im for im in sub_image.aligned_images if im.mjd == sub_image.new_image.mjd] + if len(new_image) != 1: + raise ValueError('Cannot find the new image in the aligned images') + new_image = new_image[0] + + ref_image = [im for im in sub_image.aligned_images if im.mjd == sub_image.ref_image.mjd] + if len(ref_image) != 1: + raise ValueError('Cannot find the reference image in the aligned images') + ref_image = ref_image[0] + + if self.pars.method == 'naive': + outdict = self._subtract_naive(new_image, ref_image) + elif self.pars.method == 'hotpants': + outdict = self._subtract_hotpants(new_image, ref_image) + elif self.pars.method == 'zogy': + outdict = self._subtract_zogy(new_image, ref_image) + else: + raise ValueError(f'Unknown subtraction method {self.pars.method}') + + sub_image.data = outdict['outim'] + sub_image.weight = outdict['outwt'] + sub_image.flags = outdict['outfl'] + if 'score' in outdict: + sub_image.score = outdict['score'] + if 'alpha' in outdict: + sub_image.psfflux = outdict['alpha'] + if 'alpha_err' in outdict: + sub_image.psffluxerr = outdict['alpha_err'] + if 'psf' in outdict: + # TODO: clip the array to be a cutout around the PSF, right now it is same shape as image! + sub_image.zogy_psf = outdict['psf'] # not saved, can be useful for testing / source detection + if 'alpha' in outdict and 'alpha_err' in outdict: sub_image.psfflux = outdict['alpha'] - if 'alpha_err' in outdict: sub_image.psffluxerr = outdict['alpha_err'] - if 'psf' in outdict: - # TODO: clip the array to be a cutout around the PSF, right now it is same shape as image! - sub_image.zogy_psf = outdict['psf'] # not saved, can be useful for testing / source detection - if 'alpha' in outdict and 'alpha_err' in outdict: - sub_image.psfflux = outdict['alpha'] - sub_image.psffluxerr = outdict['alpha_err'] - - sub_image.subtraction_output = outdict # save the full output for debugging + + sub_image.subtraction_output = outdict # save the full output for debugging if sub_image._upstream_bitflag is None: sub_image._upstream_bitflag = 0 diff --git a/tests/fixtures/pipeline_objects.py b/tests/fixtures/pipeline_objects.py index dd48ddae..fa4c2269 100644 --- a/tests/fixtures/pipeline_objects.py +++ b/tests/fixtures/pipeline_objects.py @@ -369,7 +369,7 @@ def make_datastore( code_version = args[0].provenance.code_version ds = DataStore(*args) # make a new datastore - if ( cache_dir is not None ) and ( cache_base_name is not None ) and ( not os.getenv( "LIMIT_CACHE_USE" ) ): + if ( cache_dir is not None ) and ( cache_base_name is not None ) and ( not os.getenv( "LIMIT_CACHE_USAGE" ) ): ds.cache_base_name = os.path.join(cache_dir, cache_base_name) # save this for testing purposes p = pipeline_factory() @@ -691,13 +691,17 @@ def make_datastore( ds = p.extractor.run(ds, session) ds.sources.save(overwrite=True) - if cache_dir is not None and cache_base_name is not None: + if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and + ( cache_dir is not None ) and ( cache_base_name is not None ) + ): output_path = copy_to_cache(ds.sources, cache_dir) if cache_dir is not None and cache_base_name is not None and output_path != sources_cache_path: warnings.warn(f'cache path {sources_cache_path} does not match output path {output_path}') ds.psf.save(overwrite=True) - if cache_dir is not None and cache_base_name is not None: + if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and + ( cache_dir is not None ) and ( cache_base_name is not None ) + ): output_path = copy_to_cache(ds.psf, cache_dir) if cache_dir is not None and cache_base_name is not None and output_path != psf_cache_path: warnings.warn(f'cache path {psf_cache_path} does not match output path {output_path}') @@ -706,7 +710,9 @@ def make_datastore( ds = p.backgrounder.run(ds, session) ds.bg.save(overwrite=True) - if cache_dir is not None and cache_base_name is not None: + if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and + ( cache_dir is not None ) and ( cache_base_name is not None ) + ): output_path = copy_to_cache(ds.bg, cache_dir) if cache_dir is not None and cache_base_name is not None and output_path != bg_cache_path: warnings.warn(f'cache path {bg_cache_path} does not match output path {output_path}') diff --git a/tests/models/test_image.py b/tests/models/test_image.py index f74f151d..882ac708 100644 --- a/tests/models/test_image.py +++ b/tests/models/test_image.py @@ -1387,7 +1387,8 @@ def test_image_products_are_deleted(ptf_datastore, data_dir, archive): assert not os.path.isfile(file) -@pytest.mark.flaky(max_runs=3) +# @pytest.mark.flaky(max_runs=3) +@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") def test_free( decam_exposure, decam_raw_image, ptf_ref ): proc = psutil.Process() origmem = proc.memory_info() diff --git a/tests/models/test_psf.py b/tests/models/test_psf.py index 2138f2f0..7547e264 100644 --- a/tests/models/test_psf.py +++ b/tests/models/test_psf.py @@ -344,7 +344,8 @@ def test_save_psf( ztf_datastore_uncommitted, provenance_base, provenance_extra im.delete_from_disk_and_database(session=session) -@pytest.mark.flaky(max_runs=3) +# @pytest.mark.flaky(max_runs=3) +@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") def test_free( decam_datastore ): ds = decam_datastore ds.get_psf() diff --git a/tests/models/test_source_list.py b/tests/models/test_source_list.py index a355edec..36eef87e 100644 --- a/tests/models/test_source_list.py +++ b/tests/models/test_source_list.py @@ -269,7 +269,8 @@ def test_calc_apercor( decam_datastore ): # assert sources.calc_aper_cor( aper_num=2, inf_aper_num=7 ) == pytest.approx( -0.024, abs=0.001 ) -@pytest.mark.flaky(max_runs=3) +# @pytest.mark.flaky(max_runs=3) +@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") def test_free( decam_datastore ): ds = decam_datastore ds.get_sources() diff --git a/tests/webap_secrets/seechange_webap_config.py b/tests/webap_secrets/seechange_webap_config.py index 0539807b..6a0e5e99 100644 --- a/tests/webap_secrets/seechange_webap_config.py +++ b/tests/webap_secrets/seechange_webap_config.py @@ -1,3 +1,4 @@ +import pathlib PG_HOST = 'seechange_postgres' PG_PORT = 5432 PG_USER = 'postgres' From 0cbe1ef1eca6e5ee1d8d62643d7a2fba989c3906 Mon Sep 17 00:00:00 2001 From: whohensee <106775295+whohensee@users.noreply.github.com> Date: Fri, 28 Jun 2024 14:33:12 -0700 Subject: [PATCH 2/3] Rework Cutouts and Measurements (#302) Changed the data-product structure so that there is one Cutouts per SourceList, which has an attribute `co_dict` that is a dictionary (actually a class that inherits from dictionary) of dictionaries containing data for all cutouts used by measurements. This data is saved and loaded from disk using the filepath attribute which is saved to the database. Added attributes to Measurements that can access this data (eg. `Measurements.sub_data`) through the Cutouts relationship, loading the relevant data when needed from disk. --- ...6d07485_rework_cutouts_and_measurements.py | 66 ++ models/base.py | 5 +- models/cutouts.py | 601 ++++-------------- models/enums_and_bitflags.py | 8 +- models/measurements.py | 185 +++++- models/source_list.py | 4 +- pipeline/cutting.py | 69 +- pipeline/data_store.py | 36 +- pipeline/measuring.py | 57 +- tests/fixtures/pipeline_objects.py | 18 +- tests/fixtures/simulated.py | 2 +- tests/models/test_cutouts.py | 135 ++-- tests/models/test_image.py | 1 - tests/models/test_measurements.py | 2 +- tests/models/test_objects.py | 4 +- tests/models/test_ptf.py | 2 +- tests/pipeline/test_measuring.py | 118 ++-- tests/pipeline/test_pipeline.py | 61 +- 18 files changed, 562 insertions(+), 812 deletions(-) create mode 100644 alembic/versions/2024_06_28_1757-7384c6d07485_rework_cutouts_and_measurements.py diff --git a/alembic/versions/2024_06_28_1757-7384c6d07485_rework_cutouts_and_measurements.py b/alembic/versions/2024_06_28_1757-7384c6d07485_rework_cutouts_and_measurements.py new file mode 100644 index 00000000..27cec09a --- /dev/null +++ b/alembic/versions/2024_06_28_1757-7384c6d07485_rework_cutouts_and_measurements.py @@ -0,0 +1,66 @@ +"""rework cutouts and measurements + +Revision ID: 7384c6d07485 +Revises: a375526c8260 +Create Date: 2024-06-28 17:57:44.173607 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '7384c6d07485' +down_revision = 'a375526c8260' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('_cutouts_index_sources_provenance_uc', 'cutouts', type_='unique') + op.drop_index('ix_cutouts_ecllat', table_name='cutouts') + op.drop_index('ix_cutouts_gallat', table_name='cutouts') + op.drop_index('ix_cutouts_filepath', table_name='cutouts') + op.create_index(op.f('ix_cutouts_filepath'), 'cutouts', ['filepath'], unique=True) + op.create_unique_constraint('_cutouts_sources_provenance_uc', 'cutouts', ['sources_id', 'provenance_id']) + op.drop_column('cutouts', 'ecllon') + op.drop_column('cutouts', 'ra') + op.drop_column('cutouts', 'gallat') + op.drop_column('cutouts', 'index_in_sources') + op.drop_column('cutouts', 'y') + op.drop_column('cutouts', 'gallon') + op.drop_column('cutouts', 'dec') + op.drop_column('cutouts', 'x') + op.drop_column('cutouts', 'ecllat') + op.add_column('measurements', sa.Column('index_in_sources', sa.Integer(), nullable=False)) + op.add_column('measurements', sa.Column('center_x_pixel', sa.Integer(), nullable=False)) + op.add_column('measurements', sa.Column('center_y_pixel', sa.Integer(), nullable=False)) + op.drop_constraint('_measurements_cutouts_provenance_uc', 'measurements', type_='unique') + op.create_unique_constraint('_measurements_cutouts_provenance_uc', 'measurements', ['cutouts_id', 'index_in_sources', 'provenance_id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('_measurements_cutouts_provenance_uc', 'measurements', type_='unique') + op.create_unique_constraint('_measurements_cutouts_provenance_uc', 'measurements', ['cutouts_id', 'provenance_id']) + op.drop_column('measurements', 'center_y_pixel') + op.drop_column('measurements', 'center_x_pixel') + op.drop_column('measurements', 'index_in_sources') + op.add_column('cutouts', sa.Column('ecllat', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True)) + op.add_column('cutouts', sa.Column('x', sa.INTEGER(), autoincrement=False, nullable=False)) + op.add_column('cutouts', sa.Column('dec', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False)) + op.add_column('cutouts', sa.Column('gallon', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True)) + op.add_column('cutouts', sa.Column('y', sa.INTEGER(), autoincrement=False, nullable=False)) + op.add_column('cutouts', sa.Column('index_in_sources', sa.INTEGER(), autoincrement=False, nullable=False)) + op.add_column('cutouts', sa.Column('gallat', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True)) + op.add_column('cutouts', sa.Column('ra', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False)) + op.add_column('cutouts', sa.Column('ecllon', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True)) + op.drop_constraint('_cutouts_sources_provenance_uc', 'cutouts', type_='unique') + op.drop_index(op.f('ix_cutouts_filepath'), table_name='cutouts') + op.create_index('ix_cutouts_filepath', 'cutouts', ['filepath'], unique=False) + op.create_index('ix_cutouts_gallat', 'cutouts', ['gallat'], unique=False) + op.create_index('ix_cutouts_ecllat', 'cutouts', ['ecllat'], unique=False) + op.create_unique_constraint('_cutouts_index_sources_provenance_uc', 'cutouts', ['index_in_sources', 'sources_id', 'provenance_id']) + # ### end Alembic commands ### diff --git a/models/base.py b/models/base.py index f7cc1d70..614d9627 100644 --- a/models/base.py +++ b/models/base.py @@ -671,14 +671,11 @@ def safe_mkdir(cls, path): @declared_attr def filepath(cls): - uniqueness = True - if cls.__name__ in ['Cutouts']: - uniqueness = False return sa.Column( sa.Text, nullable=False, index=True, - unique=uniqueness, + unique=True, doc="Base path (relative to the data root) for a stored file" ) diff --git a/models/cutouts.py b/models/cutouts.py index a4df997c..c0f686de 100644 --- a/models/cutouts.py +++ b/models/cutouts.py @@ -17,21 +17,39 @@ SeeChangeBase, AutoIDMixin, FileOnDiskMixin, - SpatiallyIndexed, HasBitFlagBadness, ) -from models.enums_and_bitflags import CutoutsFormatConverter, cutouts_badness_inverse +from models.enums_and_bitflags import CutoutsFormatConverter from models.source_list import SourceList -class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed, HasBitFlagBadness): +class Co_Dict(dict): + """Cutouts Dictionary used in Cutouts to store dictionaries which hold data arrays + for individual cutouts. Acts as a normal dictionary, except when a key is passed + using bracket notation (such as "co_dict[source_index_7]"), if that key is not present + in the Co_dict then it will search on disk for the requested data, and if found + will silently load that data and return it. + Must be assigned a Cutouts object to its cutouts attribute so that it knows + how to look for data. + """ + def __init__(self, *args, **kwargs): + self.cutouts = None # this must be assigned before use + super().__init__(self, *args, **kwargs) + + def __getitem__(self, key): + if key not in self.keys(): + if self.cutouts.filepath is not None: + self.cutouts.load_one_co_dict(key) # no change if not found + return super().__getitem__(key) + +class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, HasBitFlagBadness): __tablename__ = 'cutouts' - # a unique constraint on the provenance and the source list, but also on the index in the list + # a unique constraint on the provenance and the source list __table_args__ = ( UniqueConstraint( - 'index_in_sources', 'sources_id', 'provenance_id', name='_cutouts_index_sources_provenance_uc' + 'sources_id', 'provenance_id', name='_cutouts_sources_provenance_uc' ), ) @@ -71,27 +89,9 @@ def format(self, value): doc="The source list (of detections in the difference image) this cutouts object is associated with. " ) - index_in_sources = sa.Column( - sa.Integer, - nullable=False, - doc="Index of this cutout in the source list (of detections in the difference image). " - ) - sub_image_id = association_proxy('sources', 'image_id') sub_image = association_proxy('sources', 'image') - x = sa.Column( - sa.Integer, - nullable=False, - doc="X pixel coordinate of the center of the cutout. " - ) - - y = sa.Column( - sa.Integer, - nullable=False, - doc="Y pixel coordinate of the center of the cutout. " - ) - provenance_id = sa.Column( sa.ForeignKey('provenances.id', ondelete="CASCADE", name='cutouts_provenance_id_fkey'), nullable=False, @@ -147,6 +147,9 @@ def __init__(self, *args, **kwargs): self._new_weight = None self._new_flags = None + self.co_dict = Co_Dict() + self.co_dict.cutouts = self + self._bitflag = 0 # manually set all properties (columns or not) @@ -175,24 +178,19 @@ def init_on_load(self): self._new_weight = None self._new_flags = None + self.co_dict = Co_Dict() + self.co_dict.cutouts = self + def __repr__(self): return ( f"" ) - def __setattr__(self, key, value): - if key in ['x', 'y'] and value is not None: - value = int(round(value)) - - super().__setattr__(key, value) - @staticmethod - def get_data_attributes(include_optional=True): - names = ['source_row'] + def get_data_dict_attributes(include_optional=True): + names = [] for im in ['sub', 'ref', 'new']: for att in ['data', 'weight', 'flags']: names.append(f'{im}_{att}') @@ -202,33 +200,24 @@ def get_data_attributes(include_optional=True): return names - @property - def has_data(self): - for att in self.get_data_attributes(include_optional=False): - if getattr(self, att) is None: - return False - return True - - @property - def sub_nandata(self): - if self.sub_data is None or self.sub_flags is None: - return None - return np.where(self.sub_flags > 0, np.nan, self.sub_data) - - @property - def ref_nandata(self): - if self.ref_data is None or self.ref_flags is None: - return None - return np.where(self.ref_flags > 0, np.nan, self.ref_data) + def load_all_co_data(self): + """Intended method for a Cutouts object to ensure that the data for all + sources is loaded into its co_dict attribute. Will only actually load + from disk if any subdictionaries (one per source in SourceList) are missing. - @property - def new_nandata(self): - if self.new_data is None or self.new_flags is None: - return None - return np.where(self.new_flags > 0, np.nan, self.new_data) + Should be used before, for example, iterating over the dictionary as in + the creation of Measurements objects. Not necessary for accessing + individual subdictionaries however, because the Co_Dict class can lazy + load those as they are requested (eg. co_dict["source_index_0"]). + """ + if self.sources.num_sources is None: + raise ValueError("The detections of this cutouts has no num_sources attr") + proper_length = self.sources.num_sources + if len(self.co_dict) != proper_length and self.filepath is not None: + self.load() @staticmethod - def from_detections(detections, source_index, provenance=None, **kwargs): + def from_detections(detections, provenance=None, **kwargs): """Create a Cutout object from a row in the SourceList. The SourceList must have a valid image attribute, and that image should have exactly two @@ -239,8 +228,6 @@ def from_detections(detections, source_index, provenance=None, **kwargs): ---------- detections: SourceList The source list from which to create the cutout. - source_index: int - The index of the source in the source list from which to create the cutout. provenance: Provenance, optional The provenance of the cutout. If not given, will leave as None (to be filled externally). kwargs: dict @@ -255,23 +242,8 @@ def from_detections(detections, source_index, provenance=None, **kwargs): """ cutout = Cutouts() cutout.sources = detections - cutout.index_in_sources = source_index - cutout.source_row = dict(Table(detections.data)[source_index]) - for key, value in cutout.source_row.items(): - if isinstance(value, np.number): - cutout.source_row[key] = value.item() # convert numpy number to python primitive - cutout.x = detections.x[source_index] - cutout.y = detections.y[source_index] - cutout.ra = cutout.source_row['ra'] - cutout.dec = cutout.source_row['dec'] - cutout.calculate_coordinates() cutout.provenance = provenance - # add the data, weight, and flags to the cutout from kwargs - for im in ['sub', 'ref', 'new']: - for att in ['data', 'weight', 'flags']: - setattr(cutout, f'{im}_{att}', kwargs.get(f'{im}_{att}', None)) - # update the bitflag cutout._upstream_bitflag = detections.bitflag @@ -302,11 +274,14 @@ def invent_filepath(self): return filename - def _save_dataset_to_hdf5(self, file, groupname): - """Save the dataset from this Cutouts object into an HDF5 group for an open file. + def _save_dataset_dict_to_hdf5(self, co_subdict, file, groupname): + """Save the one co_subdict from the co_dict of this Cutouts + into an HDF5 group for an open file. Parameters ---------- + co_subdict: dict + The subdict containing the data for a single cutout file: h5py.File The open HDF5 file to save to. groupname: str @@ -315,32 +290,20 @@ def _save_dataset_to_hdf5(self, file, groupname): if groupname in file: del file[groupname] - # handle the data arrays - for att in self.get_data_attributes(): - if att == 'source_row': - continue + for key in self.get_data_dict_attributes(): + data = co_subdict.get(key) - data = getattr(self, f'_{att}') # get the private attribute so as not to trigger a load upon hitting None if data is not None: file.create_dataset( - f'{groupname}/{att}', + f'{groupname}/{key}', data=data, shape=data.shape, dtype=data.dtype, compression='gzip' ) - # handle the source_row dictionary - target = file[groupname].attrs - for key in target.keys(): # first clear the existing keys - del target[key] - - # then add the new ones - for key, value in self.source_row.items(): - target[key] = value - - def save(self, filename=None, **kwargs): - """Save a single Cutouts object into a file. + def save(self, filename=None, overwrite=True, **kwargs): + """Save the data of this Cutouts object into a file. Parameters ---------- @@ -349,105 +312,77 @@ def save(self, filename=None, **kwargs): kwargs: dict Any additional keyword arguments to pass to the FileOnDiskMixin.save method. """ - raise NotImplementedError('Saving only a single cutout into a file is not supported. Use save_list instead.') - - if not self.has_data: - raise RuntimeError("The Cutouts data is not loaded. Cannot save.") - - if filename is not None: - self.filepath = filename - if self.filepath is None: - self.filepath = self.invent_filepath() + if len(self.co_dict) == 0: + return None # do nothing - fullname = self.get_fullpath() - self.safe_mkdir(os.path.dirname(fullname)) - - if self.format == 'hdf5': - with h5py.File(fullname, 'a') as file: - self._save_dataset_to_hdf5(file, f'source_{self.index_in_sources}') - elif self.format == 'fits': - raise NotImplementedError('Saving cutouts to fits is not yet implemented.') - elif self.format in ['jpg', 'png']: - raise NotImplementedError('Saving cutouts to jpg or png is not yet implemented.') - else: - raise TypeError(f"Unable to save cutouts file of type {self.format}") - - # make sure to also save using the FileOnDiskMixin method - super().save(fullname, **kwargs) - - @classmethod - def save_list(cls, cutouts_list, filename=None, overwrite=True, **kwargs): - """Save a list of Cutouts objects into a file. - - Parameters - ---------- - cutouts_list: list of Cutouts - The list of Cutouts objects to save. - filename: str, optional - The (relative/full path) filename to save to. If not given, will use the default filename. - overwrite: bool - If True, will overwrite the file if it already exists. - If False, will raise an error if the file already exists. - kwargs: dict - Any additional keyword arguments to pass to the File - """ - if not isinstance(cutouts_list, list): - raise TypeError("The input must be a list of Cutouts objects.") - if len(cutouts_list) == 0: - return # silently do nothing + proper_length = self.sources.num_sources + if len(self.co_dict) != proper_length: + raise ValueError(f"Trying to save cutouts dict with {len(self.co_dict)}" + f" subdicts, but SourceList has {proper_length} sources") - for cutout in cutouts_list: - if not isinstance(cutout, cls): - raise TypeError("The input must be a list of Cutouts objects.") - if not cutout.has_data: - raise RuntimeError("The Cutouts data is not loaded. Cannot save.") + for key, value in self.co_dict.items(): + if not isinstance(value, dict): + raise TypeError("Each entry of co_dict must be a dictionary") if filename is None: - filename = cutouts_list[0].invent_filepath() + filename = self.invent_filepath() + + self.filepath = filename - fullname = os.path.join(cutouts_list[0].local_path, filename) - cutouts_list[0].safe_mkdir(os.path.dirname(fullname)) + fullname = os.path.join(self.local_path, filename) + self.safe_mkdir(os.path.dirname(fullname)) if not overwrite and os.path.isfile(fullname): raise FileExistsError(f"The file {fullname} already exists and overwrite is False.") - if cutouts_list[0].format == 'hdf5': + if self.format == 'hdf5': with h5py.File(fullname, 'a') as file: - for cutout in cutouts_list: - cutout._save_dataset_to_hdf5(file, f'source_{cutout.index_in_sources}') - cutout.filepath = filename - elif cutouts_list[0].format == 'fits': + for key, value in self.co_dict.items(): + self._save_dataset_dict_to_hdf5(value, file, key) + elif self.format == 'fits': raise NotImplementedError('Saving cutouts to fits is not yet implemented.') - elif cutouts_list[0].format in ['jpg', 'png']: + elif self.format in ['jpg', 'png']: raise NotImplementedError('Saving cutouts to jpg or png is not yet implemented.') else: - raise TypeError(f"Unable to save cutouts file of type {cutouts_list[0].format}") + raise TypeError(f"Unable to save cutouts file of type {self.format}") # make sure to also save using the FileOnDiskMixin method - FileOnDiskMixin.save(cutouts_list[0], fullname, overwrite=overwrite, **kwargs) - - # after saving one object as a FileOnDiskMixin, all the others should have the same md5sum - if cutouts_list[0].md5sum is not None: - for cutout in cutouts_list: - cutout.md5sum = cutouts_list[0].md5sum + FileOnDiskMixin.save(self, fullname, overwrite=overwrite, **kwargs) - def _load_dataset_from_hdf5(self, file, groupname): - """Load the dataset from an HDF5 group into this Cutouts object. + def _load_dataset_dict_from_hdf5(self, file, groupname): + """Load the dataset from an HDF5 group into one co_subdict and return. Parameters ---------- file: h5py.File The open HDF5 file to load from. groupname: str - The name of the group to load from. This should be "source_" + The name of the group to load from. This should be "source_index_" """ - for att in self.get_data_attributes(): - if att == 'source_row': - self.source_row = dict(file[groupname].attrs) - elif att in file[groupname]: - setattr(self, att, np.array(file[f'{groupname}/{att}'])) - self.format = 'hdf5' + co_subdict = {} + found_data = False + for att in self.get_data_dict_attributes(): # remove source index for dict soon + if att in file[groupname]: + found_data = True + co_subdict[att] = np.array(file[f'{groupname}/{att}']) + if found_data: + return co_subdict + + def load_one_co_dict(self, groupname, filepath=None): + """Load data subdict for a single cutout into this Cutouts co_dict. This allows + a measurement to request only the information relevant to that object, rather + than populating the entire dictionary when we only need one subdict. + """ + + if filepath is None: + filepath = self.get_fullpath() + + with h5py.File(filepath, 'r') as file: + co_subdict = self._load_dataset_dict_from_hdf5(file, groupname) + if co_subdict is not None: + self.co_dict[groupname] = co_subdict + return None def load(self, filepath=None): """Load the data for this cutout from a file. @@ -458,337 +393,31 @@ def load(self, filepath=None): The (relative/full path) filename to load from. If not given, will use self.get_fullpath() to get the filename. """ + if filepath is None: filepath = self.get_fullpath() - if self.format == 'hdf5': - with h5py.File(filepath, 'r') as file: - self._load_dataset_from_hdf5(file, f'source_{self.index_in_sources}') - elif self.format == 'fits': - raise NotImplementedError('Loading cutouts from fits is not yet implemented.') - elif self.format in ['jpg', 'png']: - raise NotImplementedError('Loading cutouts from jpg or png is not yet implemented.') - else: - raise TypeError(f"Unable to load cutouts file of type {self.format}") - - @classmethod - def from_file(cls, filepath, source_number, **kwargs): - """Create a Cutouts object from a file. - - Will try to guess the format based on the file extension. - - Parameters - ---------- - filepath: str - The (relative/full path) filename to load from. - source_number: int - The index of the source in the source list from which to create the cutout. - This relates to the internal storage in file. For HDF5 files, the group - for this object will be named "source_{source_number}". - kwargs: dict - Any additional keyword arguments to pass to the Cutouts constructor. - E.g., if you happen to know some database values for this object, - like the ID of related objects or the bitflag, you can pass them here. - """ - cutout = cls(**kwargs) - fmt = os.path.splitext(filepath)[1][1:] - if fmt == 'h5': - fmt = 'hdf5' - - cutout.format = fmt - cutout.index_in_sources = source_number - cutout.load(filepath) - - for att in ['ra', 'dec', 'x', 'y']: - if att in cutout.source_row: - setattr(cutout, att, cutout.source_row[att]) - - cutout.calculate_coordinates() - - if filepath.startswith(cutout.local_path): - filepath = filepath[len(cutout.local_path) + 1:] - cutout.filepath = filepath - - # TODO: should also load the MD5sum automatically? - - return cutout - - @classmethod - def load_list(cls, filepath, cutout_list=None): - """Load all Cutouts object that were saved to a file - - Note that these cutouts are not loaded from the database, - so they will be missing important relationships like provenance and sources. - If cutout_list is given, it must match the cutouts on the file, - so that each cutouts object will be loaded the data from file, - but retain its database relationships. - - Parameters - ---------- - filepath: str - The (relative/full path) filename to load from. - The file format is determined by the extension. - cutout_list: list of Cutouts, optional - If given, will load the data from the file into these objects. - - Returns - ------- - cutouts: Cutouts - The list of cutouts loaded from the file. - """ - ext = os.path.splitext(filepath)[1][1:] - if ext == 'h5': - format = 'hdf5' - else: - format = ext - - if filepath.startswith(Cutouts.local_path): - rel_filepath = filepath[len(Cutouts.local_path) + 1:] - - cutouts = [] - - if format == 'hdf5': - with h5py.File(filepath, 'r') as file: - for groupname in file.keys(): - if groupname.startswith('source_'): - number = int(groupname.split('_')[1]) - if cutout_list is None: - cutout = cls() - cutout.format = format - cutout.index_in_sources = number - else: - cutout = [c for c in cutout_list if c.index_in_sources == number] - if len(cutout) != 1: - raise ValueError(f"Could not find a unique cutout with index {number} in the list.") - cutout = cutout[0] - - cutout._load_dataset_from_hdf5(file, groupname) - cutout.filepath = rel_filepath - for att in ['ra', 'dec', 'x', 'y']: - if att in cutout.source_row: - setattr(cutout, att, cutout.source_row[att]) - - cutout.calculate_coordinates() - - cutouts.append(cutout) - - elif format == 'fits': - raise NotImplementedError('Loading cutouts from fits is not yet implemented.') - elif format in ['jpg', 'png']: - raise NotImplementedError('Loading cutouts from jpg or png is not yet implemented.') - else: - raise TypeError(f"Unable to load cutouts file of type {format}") - - cutouts.sort(key=lambda x: x.index_in_sources) - return cutouts - - def remove_data_from_disk(self, remove_folders=True, remove_downstreams=False): - """Delete the data from local disk, if it exists. - Will remove the dataset for this specific cutout from the file, - and remove the file if this is the last cutout in the file. - If remove_folders=True, will also remove any folders - if they are empty after the deletion. - This function will not remove database rows or archive files, - only cleanup local storage for this object and its downstreams. - - To remove both the files and the database entry, use - delete_from_disk_and_database() instead. - - Parameters - ---------- - remove_folders: bool - If True, will remove any folders on the path to the files - associated to this object, if they are empty. - remove_downstreams: bool - This is not used, but kept here for backward compatibility with the base class. - """ - raise NotImplementedError( - 'Currently there is no support for removing one Cutout at a time. Use delete_list instead.' - ) + if filepath is None: + raise ValueError("Could not find filepath to load") - if self.filepath is not None: - # get the filepath, but don't check if the file exists! - for f in self.get_fullpath(as_list=True, nofile=True): - if os.path.exists(f): - need_to_delete = False - if self.format == 'hdf5': - with h5py.File(f, 'a') as file: - del file[f'source_{self.index_in_sources}'] - if len(file) == 0: - need_to_delete = True - elif self.format == 'fits': - raise NotImplementedError('Removing cutouts from fits is not yet implemented.') - elif self.format in ['jpg', 'png']: - raise NotImplementedError('Removing cutouts from jpg or png is not yet implemented.') - else: - raise TypeError(f"Unable to remove cutouts file of type {self.format}") - - if need_to_delete: - os.remove(f) - if remove_folders: - folder = f - for i in range(10): - folder = os.path.dirname(folder) - if len(os.listdir(folder)) == 0: - os.rmdir(folder) - else: - break - - def delete_from_archive(self, remove_downstreams=False): - """Delete the file from the archive, if it exists. - Will only - This will not remove the file from local disk, nor - from the database. Use delete_from_disk_and_database() - to do that. + self.co_dict = Co_Dict() + self.co_dict.cutouts = self - Parameters - ---------- - remove_downstreams: bool - If True, will also remove any downstream data. - Will recursively call get_downstreams() and find any objects - that have remove_data_from_disk() implemented, and call it. - Default is False. - """ - raise NotImplementedError( - 'Currently archive does not support removing one Cutout at a time, use delete_list instead.' - ) - if self.filepath is not None: - if self.filepath_extensions is None: - self.archive.delete( self.filepath, okifmissing=True ) - else: - for ext in self.filepath_extensions: - self.archive.delete( f"{self.filepath}{ext}", okifmissing=True ) - - # make sure these are set to null just in case we fail - # to commit later on, we will at least know something is wrong - self.md5sum = None - self.md5sum_extensions = None + if os.path.exists(filepath): + if self.format == 'hdf5': + with h5py.File(filepath, 'r') as file: + # quirk: the resulting dict is sorted alphabetically... likely harmless + for groupname in file: + self.co_dict[groupname] = self._load_dataset_dict_from_hdf5(file, groupname) def get_upstreams(self, session=None): """Get the detections SourceList that was used to make this cutout. """ with SmartSession(session) as session: return session.scalars(sa.select(SourceList).where(SourceList.id == self.sources_id)).all() - + def get_downstreams(self, session=None, siblings=False): """Get the downstream Measurements that were made from this Cutouts object. """ from models.measurements import Measurements with SmartSession(session) as session: return session.scalars(sa.select(Measurements).where(Measurements.cutouts_id == self.id)).all() - - @classmethod - def merge_list(cls, cutouts_list, session): - """Merge (or add) the list of Cutouts to the given session. """ - if cutouts_list is None or len(cutouts_list) == 0: - return cutouts_list - - sources = session.merge(cutouts_list[0].sources) - for i, cutout in enumerate(cutouts_list): - cutouts_list[i].sources = sources - cutouts_list[i] = session.merge(cutouts_list[i]) - - return cutouts_list - - @classmethod - def delete_list(cls, cutouts_list, remove_local=True, archive=True, database=True, session=None, commit=True): - """ - Remove a list of Cutouts objects from local disk and/or the archive and/or the database. - This removes the file that includes all the cutouts. - Can only delete cutouts that share the same filepath. - WARNING: this will not check that the file contains ONLY the cutouts on the list! - So, if the list contains a subset of the cutouts on file, the file is still deleted. - - Parameters - ---------- - cutouts_list: list of Cutouts - The list of Cutouts objects to remove. - remove_local: bool - If True, will remove the file from local disk. - archive: bool - If True, will remove the file from the archive. - database: bool - If True, will remove the cutouts from the database. - session: Session, optional - The database session to use. If not given, will create a new session. - commit: bool - If True, will commit the changes to the database. - If False, will not commit the changes to the database. - If session is not given, commit must be True. - """ - if database and session is None and not commit: - raise ValueError('If session is not given, commit must be True.') - - filepath = set([c.filepath for c in cutouts_list]) - if len(filepath) > 1: - raise ValueError( - f'All cutouts must share the same filepath to be deleted together. Got: {filepath}' - ) - - if remove_local: - fullpath = cutouts_list[0].get_fullpath() - if fullpath is not None and os.path.isfile(fullpath): - os.remove(fullpath) - - if archive: - if cutouts_list[0].filepath is not None: - cutouts_list[0].archive.delete(cutouts_list[0].filepath, okifmissing=True) - - if database: - with SmartSession(session) as session: - for cutout in cutouts_list: - cutout.delete_from_database(session=session, commit=False) - if commit: - session.commit() - - def check_equals(self, other): - """Compare if two cutouts have the same data. """ - if not isinstance(other, Cutouts): - return super().__eq__(other) # any other comparisons use the base class - - attributes = self.get_data_attributes() - attributes += ['ra', 'dec', 'x', 'y', 'filepath', 'format'] - - for att in attributes: - if isinstance(getattr(self, att), np.ndarray): - if not np.array_equal(getattr(self, att), getattr(other, att)): - return False - else: # other attributes get compared directly - if getattr(self, att) != getattr(other, att): - return False - - return True - - def _get_inverse_badness(self): - return cutouts_badness_inverse - - -# use these two functions to quickly add the "property" accessor methods -def load_attribute(object, att): - """Load the data for a given attribute of the object.""" - if not hasattr(object, f'_{att}'): - raise AttributeError(f"The object {object} does not have the attribute {att}.") - if getattr(object, f'_{att}') is None: - if object.filepath is None: - return None # objects just now created and not saved cannot lazy load data! - object.load() # can lazy-load all data - - # after data is filled, should be able to just return it - return getattr(object, f'_{att}') - - -def set_attribute(object, att, value): - """Set the value of the attribute on the object. """ - setattr(object, f'_{att}', value) - - -# add "@property" functions to all the data attributes -for att in Cutouts.get_data_attributes(): - setattr( - Cutouts, - att, - property( - fget=lambda self, att=att: load_attribute(self, att), - fset=lambda self, value, att=att: set_attribute(self, att, value), - ) - ) - diff --git a/models/enums_and_bitflags.py b/models/enums_and_bitflags.py index b3673d57..c5d7a76a 100644 --- a/models/enums_and_bitflags.py +++ b/models/enums_and_bitflags.py @@ -410,8 +410,8 @@ def string_to_bitflag(value, dictionary): bg_badness_inverse = {EnumConverter.c(v): k for k, v in bg_badness_dict.items()} -# these are the ways a Cutouts object is allowed to be bad -cutouts_badness_dict = { +# these are the ways a Measurements object is allowed to be bad +measurements_badness_dict = { 41: 'cosmic ray', 42: 'ghost', 43: 'satellite', @@ -419,13 +419,13 @@ def string_to_bitflag(value, dictionary): 45: 'bad pixel', 46: 'bleed trail', } -cutouts_badness_inverse = {EnumConverter.c(v): k for k, v in cutouts_badness_dict.items()} +measurements_badness_inverse = {EnumConverter.c(v): k for k, v in measurements_badness_dict.items()} # join the badness: data_badness_dict = {} data_badness_dict.update(image_badness_dict) -data_badness_dict.update(cutouts_badness_dict) +data_badness_dict.update(measurements_badness_dict) data_badness_dict.update(source_list_badness_dict) data_badness_dict.update(psf_badness_dict) data_badness_dict.update(bg_badness_dict) diff --git a/models/measurements.py b/models/measurements.py index 6b42fda8..c7127d71 100644 --- a/models/measurements.py +++ b/models/measurements.py @@ -10,6 +10,7 @@ from models.base import Base, SeeChangeBase, SmartSession, AutoIDMixin, SpatiallyIndexed, HasBitFlagBadness from models.cutouts import Cutouts +from models.enums_and_bitflags import measurements_badness_inverse from improc.photometry import get_circle @@ -19,7 +20,7 @@ class Measurements(Base, AutoIDMixin, SpatiallyIndexed, HasBitFlagBadness): __tablename__ = 'measurements' __table_args__ = ( - UniqueConstraint('cutouts_id', 'provenance_id', name='_measurements_cutouts_provenance_uc'), + UniqueConstraint('cutouts_id', 'index_in_sources', 'provenance_id', name='_measurements_cutouts_provenance_uc'), sa.Index("ix_measurements_scores_gin", "disqualifier_scores", postgresql_using="gin"), ) @@ -38,6 +39,13 @@ class Measurements(Base, AutoIDMixin, SpatiallyIndexed, HasBitFlagBadness): doc="The cutouts object that this measurements object is associated with. " ) + index_in_sources = sa.Column( + sa.Integer, + nullable=False, + doc="Index of the data for this Measurements" + "in the source list (of detections in the difference image). " + ) + object_id = sa.Column( sa.ForeignKey('objects.id', ondelete="CASCADE", name='measurements_object_id_fkey'), nullable=False, # every saved Measurements object must have an associated Object @@ -178,23 +186,23 @@ def magnitude_err(self): @property def lim_mag(self): - return self.cutouts.sources.image.new_image.lim_mag_estimate # TODO: improve this when done with issue #143 + return self.sources.image.new_image.lim_mag_estimate # TODO: improve this when done with issue #143 @property def zp(self): - return self.cutouts.sources.image.new_image.zp + return self.sources.image.new_image.zp @property def fwhm_pixels(self): - return self.cutouts.sources.image.get_psf().fwhm_pixels + return self.sources.image.get_psf().fwhm_pixels @property def psf(self): - return self.cutouts.sources.image.get_psf().get_clip(x=self.cutouts.x, y=self.cutouts.y) + return self.sources.image.get_psf().get_clip(x=self.center_x_pixel, y=self.center_y_pixel) @property def pixel_scale(self): - return self.cutouts.sources.image.new_image.wcs.get_pixel_scale() + return self.sources.image.new_image.wcs.get_pixel_scale() @property def sources(self): @@ -204,15 +212,15 @@ def sources(self): @property def image(self): - if self.cutouts is None or self.cutouts.sources is None: + if self.cutouts is None or self.sources is None: return None - return self.cutouts.sources.image + return self.sources.image @property def instrument_object(self): - if self.cutouts is None or self.cutouts.sources is None or self.cutouts.sources.image is None: + if self.cutouts is None or self.sources is None or self.sources.image is None: return None - return self.cutouts.sources.image.instrument_object + return self.sources.image.instrument_object bkg_mean = sa.Column( sa.REAL, @@ -245,6 +253,20 @@ def instrument_object(self): doc="Areas of the apertures used for calculating flux. Remove a * background from the flux measurement. " ) + center_x_pixel = sa.Column( + sa.Integer, + nullable=False, + doc="X pixel coordinate of the center of the cutout (in the full image coordinates)," + "rounded to nearest integer pixel. " + ) + + center_y_pixel = sa.Column( + sa.Integer, + nullable=False, + doc="Y pixel coordinate of the center of the cutout (in the full image coordinates)," + "rounded to nearest integer pixel. " + ) + offset_x = sa.Column( sa.REAL, nullable=False, @@ -297,10 +319,43 @@ def instrument_object(self): "The higher the score, the more likely the measurement is to be an artefact. " ) + @property + def sub_nandata(self): + if self.sub_data is None or self.sub_flags is None: + return None + return np.where(self.sub_flags > 0, np.nan, self.sub_data) + + @property + def ref_nandata(self): + if self.ref_data is None or self.ref_flags is None: + return None + return np.where(self.ref_flags > 0, np.nan, self.ref_data) + + @property + def new_nandata(self): + if self.new_data is None or self.new_flags is None: + return None + return np.where(self.new_flags > 0, np.nan, self.new_data) + def __init__(self, **kwargs): SeeChangeBase.__init__(self) # don't pass kwargs as they could contain non-column key-values HasBitFlagBadness.__init__(self) - self._cutouts_list_index = None # helper (transient) attribute that helps find the right cutouts in a list + + self.index_in_sources = None + + self._sub_data = None + self._sub_weight = None + self._sub_flags = None + self._sub_psfflux = None + self._sub_psffluxerr = None + + self._ref_data = None + self._ref_weight = None + self._ref_flags = None + + self._new_data = None + self._new_weight = None + self._new_flags = None # manually set all properties (columns or not) for key, value in kwargs.items(): @@ -309,21 +364,61 @@ def __init__(self, **kwargs): self.calculate_coordinates() + @orm.reconstructor + def init_on_load(self): + Base.init_on_load(self) + + self._sub_data = None + self._sub_weight = None + self._sub_flags = None + self._sub_psfflux = None + self._sub_psffluxerr = None + + self._ref_data = None + self._ref_weight = None + self._ref_flags = None + + self._new_data = None + self._new_weight = None + self._new_flags = None + def __repr__(self): return ( f"" + f"at x,y= {self.center_x_pixel}, {self.center_y_pixel}>" ) def __setattr__(self, key, value): if key in ['flux_apertures', 'flux_apertures_err', 'aper_radii']: value = np.array(value) + if key in ['center_x_pixel', 'center_y_pixel'] and value is not None: + value = int(np.round(value)) + super().__setattr__(key, value) + def get_data_from_cutouts(self): + """Populates this object with the cutout data arrays used in + calculations. This allows us to use, for example, self.sub_data + without having to look constantly back into the related Cutouts. + + Importantly, the data for this measurements should have already + been loaded by the Co_Dict class + """ + groupname = f'source_index_{self.index_in_sources}' + + if not self.cutouts.co_dict.get(groupname): + raise ValueError(f"No subdict found for {groupname}") + + co_data_dict = self.cutouts.co_dict[groupname] # get just the subdict with data for this + + for att in Cutouts.get_data_dict_attributes(): + setattr(self, att, co_data_dict.get(att)) + + def get_filter_description(self, number=None): """Use the number of the filter in the filter bank to get a string describing it. @@ -342,12 +437,12 @@ def get_filter_description(self, number=None): raise ValueError('Filter number must be non-negative.') if self.provenance is None: raise ValueError('No provenance for this measurement, cannot recover the parameters used. ') - if self.cutouts is None or self.cutouts.sources is None or self.cutouts.sources.image is None: + if self.cutouts is None or self.sources is None or self.sources.image is None: raise ValueError('No cutouts for this measurement, cannot recover the PSF width. ') mult = self.provenance.parameters['width_filter_multipliers'] angles = np.arange(-90.0, 90.0, self.provenance.parameters['streak_filter_angle_step']) - fwhm = self.cutouts.sources.image.get_psf().fwhm_pixels + fwhm = self.sources.image.get_psf().fwhm_pixels if number == 0: return f'PSF match (FWHM= 1.00 x {fwhm:.2f})' @@ -360,21 +455,6 @@ def get_filter_description(self, number=None): raise ValueError('Filter number too high for the filter bank. ') - def find_cutouts_in_list(self, cutouts_list): - """Given a list of cutouts, find the one that matches this object. """ - # this is faster, and works without needing DB indices to be set - if self._cutouts_list_index is not None: - return cutouts_list[self._cutouts_list_index] - - # after loading from DB (or merging) we must use the cutouts_id to associate these - if self.cutouts_id is not None: - for i, cutouts in enumerate(cutouts_list): - if cutouts.id == self.cutouts_id: - self._cutouts_list_index = i - return cutouts - - raise ValueError('Cutouts not found in the list. ') - def associate_object(self, session=None): """Find or create a new object and associate it with this measurement. @@ -446,15 +526,15 @@ def get_flux_at_point(self, ra, dec, aperture=None): if aperture == 'psf': aperture = -1 - im = self.cutouts.sub_nandata # the cutouts image we are working with (includes NaNs for bad pixels) + im = self.sub_nandata # the cutouts image we are working with (includes NaNs for bad pixels) - wcs = self.cutouts.sources.image.new_image.wcs.wcs + wcs = self.sources.image.new_image.wcs.wcs # these are the coordinates relative to the center of the cutouts image_pixel_x = wcs.world_to_pixel_values(ra, dec)[0] image_pixel_y = wcs.world_to_pixel_values(ra, dec)[1] - offset_x = image_pixel_x - self.cutouts.x - offset_y = image_pixel_y - self.cutouts.y + offset_x = image_pixel_x - self.center_x_pixel + offset_y = image_pixel_y - self.center_y_pixel if abs(offset_x) > im.shape[1] / 2 or abs(offset_y) > im.shape[0] / 2: return np.nan, np.nan, np.nan # quietly return NaNs for large offsets, they will fail the cuts anyway... @@ -464,7 +544,7 @@ def get_flux_at_point(self, ra, dec, aperture=None): if aperture == -1: # get the subtraction PSF or (if unavailable) the new image PSF - psf = self.cutouts.sources.image.get_psf() + psf = self.sources.image.get_psf() psf_clip = psf.get_clip(x=image_pixel_x, y=image_pixel_y) offset_ix = int(np.round(offset_x)) offset_iy = int(np.round(offset_y)) @@ -505,6 +585,9 @@ def get_downstreams(self, session=None, siblings=False): """Get the downstreams of this Measurements""" return [] + def _get_inverse_badness(self): + return measurements_badness_inverse + @classmethod def delete_list(cls, measurements_list, session=None, commit=True): """ @@ -530,3 +613,35 @@ def delete_list(cls, measurements_list, session=None, commit=True): if commit: session.commit() +# use these three functions to quickly add the "property" accessor methods +def load_attribute(object, att): + """Load the data for a given attribute of the object. Load from Cutouts, but + if the data needs to be loaded from disk, ONLY load the subdict that contains + data for this object, not all objects in the Cutouts.""" + if not hasattr(object, f'_{att}'): + raise AttributeError(f"The object {object} does not have the attribute {att}.") + if getattr(object, f'_{att}') is None: + if len(object.cutouts.co_dict) == 0 and object.cutouts.filepath is None: + return None # objects just now created and not saved cannot lazy load data! + + groupname = f'source_index_{object.index_in_sources}' + if object.cutouts.co_dict[groupname] is not None: # will check disk as Co_Dict + object.get_data_from_cutouts() + + # after data is filled, should be able to just return it + return getattr(object, f'_{att}') + +def set_attribute(object, att, value): + """Set the value of the attribute on the object. """ + setattr(object, f'_{att}', value) + +# add "@property" functions to all the data attributes +for att in Cutouts.get_data_dict_attributes(): + setattr( + Measurements, + att, + property( + fget=lambda self, att=att: load_attribute(self, att), + fset=lambda self, value, att=att: set_attribute(self, att, value), + ) + ) diff --git a/models/source_list.py b/models/source_list.py index 8bc19bbc..26819f47 100644 --- a/models/source_list.py +++ b/models/source_list.py @@ -191,7 +191,7 @@ def merge_all(self, session): """ new_sources = self.safe_merge(session=session) session.flush() - for att in ['wcs', 'zp']: + for att in ['wcs', 'zp', 'cutouts']: sub_obj = getattr(self, att, None) if sub_obj is not None: sub_obj.sources = new_sources # make sure to first point this relationship back to new_sources @@ -200,7 +200,7 @@ def merge_all(self, session): sub_obj = sub_obj.safe_merge(session=session) setattr(new_sources, att, sub_obj) - for att in ['cutouts', 'measurements']: + for att in ['measurements']: sub_obj = getattr(self, att, None) if sub_obj is not None: new_list = [] diff --git a/pipeline/cutting.py b/pipeline/cutting.py index 704c28d1..b98affeb 100644 --- a/pipeline/cutting.py +++ b/pipeline/cutting.py @@ -73,9 +73,12 @@ def run(self, *args, **kwargs): prov = ds.get_provenance('cutting', self.pars.get_critical_pars(), session=session) # try to find some measurements in memory or in the database: - cutout_list = ds.get_cutouts(prov, session=session) + cutouts = ds.get_cutouts(prov, session=session) + if cutouts is not None: + cutouts.load_all_co_data() + + if cutouts is None or len(cutouts.co_dict) == 0: - if cutout_list is None or len(cutout_list) == 0: # must create a new list of Cutouts self.has_recalculated = True # use the latest source list in the data store, # or load using the provenance given in the @@ -88,7 +91,6 @@ def run(self, *args, **kwargs): f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}' ) - cutout_list = [] x = detections.x y = detections.y sz = self.pars.cutout_size @@ -112,42 +114,42 @@ def run(self, *args, **kwargs): new_stamps_weight = make_cutouts(ds.sub_image.new_aligned_image.weight, x, y, sz, fillvalue=0) new_stamps_flags = make_cutouts(ds.sub_image.new_aligned_image.flags, x, y, sz, fillvalue=0) + cutouts = Cutouts.from_detections(detections, provenance=prov) + + cutouts._upstream_bitflag = 0 + cutouts._upstream_bitflag |= detections.bitflag + for i, source in enumerate(detections.data): - # get the cutouts - cutout = Cutouts.from_detections(detections, i, provenance=prov) - cutout.sub_data = sub_stamps_data[i] - cutout.sub_weight = sub_stamps_weight[i] - cutout.sub_flags = sub_stamps_flags[i] + data_dict = {} + data_dict["sub_data"] = sub_stamps_data[i] + data_dict["sub_weight"] = sub_stamps_weight[i] + data_dict["sub_flags"] = sub_stamps_flags[i] # TODO: figure out if we can actually use this flux (maybe renormalize it) # if sub_stamps_psfflux is not None and sub_stamps_psffluxerr is not None: - # cutout.sub_psfflux = sub_stamps_psfflux[i] - # cutout.sub_psffluxerr = sub_stamps_psffluxerr[i] - - cutout.ref_data = ref_stamps_data[i] - cutout.ref_weight = ref_stamps_weight[i] - cutout.ref_flags = ref_stamps_flags[i] - - cutout.new_data = new_stamps_data[i] - cutout.new_weight = new_stamps_weight[i] - cutout.new_flags = new_stamps_flags[i] - - cutout._upstream_bitflag = 0 - cutout._upstream_bitflag |= detections.bitflag - - cutout_list.append(cutout) - - # add the resulting list to the data store - for cutout in cutout_list: - if cutout.provenance is None: - cutout.provenance = prov - else: - if cutout.provenance.id != prov.id: - raise ValueError( - f'Provenance mismatch for cutout {cutout.provenance.id[:6]} ' + # data_dict['sub_psfflux'] = sub_stamps_psfflux[i] + # data_dict['sub_psffluxerr'] = sub_stamps_psffluxerr[i] + + data_dict["ref_data"] = ref_stamps_data[i] + data_dict["ref_weight"] = ref_stamps_weight[i] + data_dict["ref_flags"] = ref_stamps_flags[i] + + data_dict["new_data"] = new_stamps_data[i] + data_dict["new_weight"] = new_stamps_weight[i] + data_dict["new_flags"] = new_stamps_flags[i] + cutouts.co_dict[f"source_index_{i}"] = data_dict + + + # add the resulting Cutouts to the data store + if cutouts.provenance is None: + cutouts.provenance = prov + else: + if cutouts.provenance.id != prov.id: + raise ValueError( + f'Provenance mismatch for cutout {cutouts.provenance.id[:6]} ' f'and preset provenance {prov.id[:6]}!' ) - ds.cutouts = cutout_list + ds.cutouts = cutouts ds.runtimes['cutting'] = time.perf_counter() - t_start if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): @@ -158,4 +160,3 @@ def run(self, *args, **kwargs): ds.catch_exception(e) finally: # make sure datastore is returned to be used in the next step return ds - diff --git a/pipeline/data_store.py b/pipeline/data_store.py index addb4cb1..62b5c9b1 100644 --- a/pipeline/data_store.py +++ b/pipeline/data_store.py @@ -389,11 +389,8 @@ def __setattr__(self, key, value): if key == 'detections' and not isinstance(value, SourceList): raise ValueError(f'detections must be a SourceList object, got {type(value)}') - if key == 'cutouts' and not isinstance(value, list): - raise ValueError(f'cutouts must be a list of Cutout objects, got {type(value)}') - - if key == 'cutouts' and not all([isinstance(c, Cutouts) for c in value]): - raise ValueError(f'cutouts must be a list of Cutouts objects, got list with {[type(c) for c in value]}') + if key == 'cutouts' and not isinstance(value, Cutouts): + raise ValueError(f'cutouts must be a Cutouts object, got {type(value)}') if key == 'measurements' and not isinstance(value, list): raise ValueError(f'measurements must be a list of Measurements objects, got {type(value)}') @@ -1285,16 +1282,16 @@ def get_cutouts(self, provenance=None, session=None): if provenance is None: # try to get the provenance from the prov_tree provenance = self._get_provenance_for_an_upstream(process_name, session) - # not in memory, look for it on the DB if self.cutouts is not None: - if len(self.cutouts) == 0: + self.cutouts.load_all_co_data() + if len(self.cutouts.co_dict) == 0: self.cutouts = None # TODO: what about images that actually don't have any detections? # make sure the cutouts have the correct provenance if self.cutouts is not None: - if self.cutouts[0].provenance is None: + if self.cutouts.provenance is None: raise ValueError('Cutouts have no provenance!') - if provenance is not None and provenance.id != self.cutouts[0].provenance.id: + if provenance is not None and provenance.id != self.cutouts.provenance.id: self.cutouts = None # not in memory, look for it on the DB @@ -1316,7 +1313,7 @@ def get_cutouts(self, provenance=None, session=None): Cutouts.sources_id == sub_image.sources.id, Cutouts.provenance_id == provenance.id, ) - ).all() + ).first() return self.cutouts @@ -1360,11 +1357,10 @@ def get_measurements(self, provenance=None, session=None): if self.measurements is None: with SmartSession(session, self.session) as session: cutouts = self.get_cutouts(session=session) - cutout_ids = [c.id for c in cutouts] self.measurements = session.scalars( sa.select(Measurements).where( - Measurements.cutouts_id.in_(cutout_ids), + Measurements.cutouts_id == cutouts.id, Measurements.provenance_id == provenance.id, ) ).all() @@ -1485,11 +1481,6 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False, if obj is None: continue - if isinstance(obj, list) and len(obj) > 0: # handle cutouts and measurements - if hasattr(obj[0], 'save_list'): - obj[0].save_list(obj, overwrite=overwrite, exists_ok=exists_ok, no_archive=no_archive) - continue - SCLogger.debug( f'save_and_commit considering a {obj.__class__.__name__} with filepath ' f'{obj.filepath if isinstance(obj,FileOnDiskMixin) else ""}' ) @@ -1569,19 +1560,14 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False, if self.detections is not None: more_products = 'detections' if self.cutouts is not None: - if self.measurements is not None: # keep track of which cutouts goes to which measurements - for m in self.measurements: - idx = [c.index_in_sources for c in self.cutouts].index(m.cutouts.index_in_sources) - m._cutouts_list_index = idx - for cutout in self.cutouts: - cutout.sources = self.detections - self.cutouts = Cutouts.merge_list(self.cutouts, session) + self.cutouts.sources = self.detections + self.cutouts = session.merge(self.cutouts) more_products += ', cutouts' if self.measurements is not None: for i, m in enumerate(self.measurements): # use the new, merged cutouts - self.measurements[i].cutouts = self.measurements[i].find_cutouts_in_list(self.cutouts) + self.measurements[i].cutouts = self.cutouts self.measurements[i].associate_object(session) self.measurements[i] = session.merge(self.measurements[i]) self.measurements[i].object.measurements.append(self.measurements[i]) diff --git a/pipeline/measuring.py b/pipeline/measuring.py index aefedccf..0eab1912 100644 --- a/pipeline/measuring.py +++ b/pipeline/measuring.py @@ -5,6 +5,8 @@ from scipy import signal +from astropy.table import Table + from improc.photometry import iterative_cutouts_photometry from improc.tools import make_gaussian @@ -156,17 +158,11 @@ def run(self, *args, **kwargs): """ self.has_recalculated = False try: # first make sure we get back a datastore, even an empty one - # most likely to get a Cutouts object or list of Cutouts if isinstance(args[0], Cutouts): - new_args = [args[0]] # make it a list if we got a single Cutouts object for some reason - new_args += list(args[1:]) - args = tuple(new_args) - - if isinstance(args[0], list) and all([isinstance(c, Cutouts) for c in args[0]]): args, kwargs, session = parse_session(*args, **kwargs) ds = DataStore() ds.cutouts = args[0] - ds.detections = ds.cutouts[0].sources + ds.detections = ds.cutouts.sources ds.sub_image = ds.detections.image ds.image = ds.sub_image.new_image else: @@ -201,38 +197,42 @@ def run(self, *args, **kwargs): raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') cutouts = ds.get_cutouts(session=session) + if cutouts is not None: + cutouts.load_all_co_data() # prepare the filter bank for this batch of cutouts - if self._filter_psf_fwhm is None or self._filter_psf_fwhm != cutouts[0].sources.image.get_psf().fwhm_pixels: - self.make_filter_bank(cutouts[0].sub_data.shape[0], cutouts[0].sources.image.get_psf().fwhm_pixels) + if self._filter_psf_fwhm is None or self._filter_psf_fwhm != cutouts.sources.image.get_psf().fwhm_pixels: + self.make_filter_bank(cutouts.co_dict["source_index_0"]["sub_data"].shape[0], cutouts.sources.image.get_psf().fwhm_pixels) # go over each cutouts object and produce a measurements object measurements_list = [] - for i, c in enumerate(cutouts): - m = Measurements(cutouts=c) - # make sure to remember which cutout belongs to this measurement, - # before either of them is in the DB and then use the cutouts_id instead - m._cutouts_list_index = i - m.best_aperture = c.sources.best_aper_num + for key, co_subdict in cutouts.co_dict.items(): + m = Measurements(cutouts=cutouts) + m.index_in_sources = int(key[13:]) # grab just the number from "source_index_xxx" + + m.best_aperture = cutouts.sources.best_aper_num + + m.center_x_pixel = cutouts.sources.x[m.index_in_sources] # These will be rounded by Measurements.__setattr__ + m.center_y_pixel = cutouts.sources.y[m.index_in_sources] - m.aper_radii = c.sources.image.new_image.zp.aper_cor_radii # zero point corrected aperture radii + m.aper_radii = cutouts.sources.image.new_image.zp.aper_cor_radii # zero point corrected aperture radii ignore_bits = 0 for badness in self.pars.bad_pixel_exclude: ignore_bits |= 2 ** BitFlagConverter.convert(badness) # remove the bad pixels that we want to ignore - flags = c.sub_flags.astype('uint16') & ~np.array(ignore_bits).astype('uint16') + flags = co_subdict['sub_flags'].astype('uint16') & ~np.array(ignore_bits).astype('uint16') annulus_radii_pixels = self.pars.annulus_radii if self.pars.annulus_units == 'fwhm': - fwhm = c.source.image.get_psf().fwhm_pixels + fwhm = cutouts.source.image.get_psf().fwhm_pixels annulus_radii_pixels = [rad * fwhm for rad in annulus_radii_pixels] # TODO: consider if there are any additional parameters that photometry needs output = iterative_cutouts_photometry( - c.sub_data, - c.sub_weight, + co_subdict['sub_data'], + co_subdict['sub_weight'], flags, radii=m.aper_radii, annulus=annulus_radii_pixels, @@ -253,11 +253,12 @@ def run(self, *args, **kwargs): m.position_angle = output['angle'] # update the coordinates using the centroid offsets - x = c.x + m.offset_x - y = c.y + m.offset_y + x = m.center_x_pixel + m.offset_x + y = m.center_y_pixel + m.offset_y ra, dec = m.cutouts.sources.image.new_image.wcs.wcs.pixel_to_world_values(x, y) m.ra = float(ra) m.dec = float(dec) + m.calculate_coordinates() # PSF photometry: # Two options: use the PSF flux from ZOGY, or use the new image PSF to measure the flux. @@ -294,10 +295,10 @@ def run(self, *args, **kwargs): # Apply analytic cuts to each stamp image, to rule out artefacts. m.disqualifier_scores = {} if m.bkg_mean != 0 and m.bkg_std > 0.1: - norm_data = (c.sub_nandata - m.bkg_mean) / m.bkg_std # normalize + norm_data = (m.sub_nandata - m.bkg_mean) / m.bkg_std # normalize else: warnings.warn(f'Background mean= {m.bkg_mean}, std= {m.bkg_std}, normalization skipped!') - norm_data = c.sub_nandata # no good background measurement, do not normalize! + norm_data = m.sub_nandata # no good background measurement, do not normalize! positives = np.sum(norm_data > self.pars.outlier_sigma) negatives = np.sum(norm_data < -self.pars.outlier_sigma) @@ -308,9 +309,9 @@ def run(self, *args, **kwargs): else: m.disqualifier_scores['negatives'] = negatives / positives - x, y = np.meshgrid(range(c.sub_data.shape[0]), range(c.sub_data.shape[1])) - x = x - c.sub_data.shape[1] // 2 - m.offset_x - y = y - c.sub_data.shape[0] // 2 - m.offset_y + x, y = np.meshgrid(range(m.sub_data.shape[0]), range(m.sub_data.shape[1])) + x = x - m.sub_data.shape[1] // 2 - m.offset_x + y = y - m.sub_data.shape[0] // 2 - m.offset_y r = np.sqrt(x ** 2 + y ** 2) bad_pixel_inclusion = r <= self.pars.bad_pixel_radius + 0.5 m.disqualifier_scores['bad pixels'] = np.sum(flags[bad_pixel_inclusion] > 0) @@ -330,7 +331,7 @@ def run(self, *args, **kwargs): # TODO: add additional disqualifiers m._upstream_bitflag = 0 - m._upstream_bitflag |= c.bitflag + m._upstream_bitflag |= cutouts.bitflag ignore_bits = 0 for badness in self.pars.bad_flag_exclude: diff --git a/tests/fixtures/pipeline_objects.py b/tests/fixtures/pipeline_objects.py index fa4c2269..ef12a7a7 100644 --- a/tests/fixtures/pipeline_objects.py +++ b/tests/fixtures/pipeline_objects.py @@ -900,22 +900,22 @@ def make_datastore( cache_name = os.path.join(cache_dir, cache_sub_name + f'.cutouts_{prov.id[:6]}.h5') if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ): SCLogger.debug('loading cutouts from cache. ') - ds.cutouts = copy_list_from_cache(Cutouts, cache_dir, cache_name) - ds.cutouts = Cutouts.load_list(os.path.join(ds.cutouts[0].local_path, ds.cutouts[0].filepath)) - [setattr(c, 'provenance', prov) for c in ds.cutouts] - [setattr(c, 'sources', ds.detections) for c in ds.cutouts] - Cutouts.save_list(ds.cutouts) # make sure to save to archive as well + ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name) + setattr(ds.cutouts, 'provenance', prov) + setattr(ds.cutouts, 'sources', ds.detections) + ds.cutouts.load_all_co_data() # sources must be set first + ds.cutouts.save() else: # cannot find cutouts on cache ds = p.cutter.run(ds, session) - Cutouts.save_list(ds.cutouts) + ds.cutouts.save() if not os.getenv( "LIMIT_CACHE_USAGE" ): - copy_list_to_cache(ds.cutouts, cache_dir) + copy_to_cache(ds.cutouts, cache_dir) ############ measuring to create measurements ############ prov = Provenance( code_version=code_version, process='measuring', - upstreams=[ds.cutouts[0].provenance], + upstreams=[ds.cutouts.provenance], parameters=p.measurer.pars.get_critical_pars(), is_testing=True, ) @@ -929,7 +929,7 @@ def make_datastore( SCLogger.debug('loading measurements from cache. ') ds.all_measurements = copy_list_from_cache(Measurements, cache_dir, cache_name) [setattr(m, 'provenance', prov) for m in ds.all_measurements] - [setattr(m, 'cutouts', c) for m, c in zip(ds.all_measurements, ds.cutouts)] + [setattr(m, 'cutouts', ds.cutouts) for m in ds.all_measurements] ds.measurements = [] for m in ds.all_measurements: diff --git a/tests/fixtures/simulated.py b/tests/fixtures/simulated.py index 2ccf3d31..228a1b24 100644 --- a/tests/fixtures/simulated.py +++ b/tests/fixtures/simulated.py @@ -607,7 +607,7 @@ def sim_sub_image_list( ds = cutter.run(sub.sources) sub.sources.cutouts = ds.cutouts - Cutouts.save_list(ds.cutouts) + ds.cutouts.save() sub = sub.merge_all(session) ds.detections = sub.sources diff --git a/tests/models/test_cutouts.py b/tests/models/test_cutouts.py index 2bb3b0e6..2f2d41b3 100644 --- a/tests/models/test_cutouts.py +++ b/tests/models/test_cutouts.py @@ -1,6 +1,6 @@ import os -import h5py import uuid +import h5py import numpy as np import pytest @@ -10,103 +10,100 @@ from models.base import SmartSession from models.cutouts import Cutouts - def test_make_save_load_cutouts(decam_detection_list, cutter): try: cutter.pars.test_parameter = uuid.uuid4().hex ds = cutter.run(decam_detection_list) + assert cutter.has_recalculated - assert isinstance(ds.cutouts, list) - assert len(ds.cutouts) > 1 - assert isinstance(ds.cutouts[0], Cutouts) - - c = ds.cutouts[0] - assert c.sub_image == decam_detection_list.image - assert c.ref_image == decam_detection_list.image.ref_aligned_image - assert c.new_image == decam_detection_list.image.new_aligned_image - - assert isinstance(c.sub_data, np.ndarray) - assert isinstance(c.sub_weight, np.ndarray) - assert isinstance(c.sub_flags, np.ndarray) - assert isinstance(c.ref_data, np.ndarray) - assert isinstance(c.ref_weight, np.ndarray) - assert isinstance(c.ref_flags, np.ndarray) - assert isinstance(c.new_data, np.ndarray) - assert isinstance(c.new_weight, np.ndarray) - assert isinstance(c.new_flags, np.ndarray) - assert isinstance(c.source_row, dict) - assert c.bitflag is not None + assert isinstance(ds.cutouts, Cutouts) + assert len(ds.cutouts.co_dict) == ds.cutouts.sources.num_sources + + subdict_key = "source_index_0" + co_subdict = ds.cutouts.co_dict[subdict_key] + + assert ds.cutouts.sub_image == decam_detection_list.image + assert ds.cutouts.ref_image == decam_detection_list.image.ref_aligned_image + assert ds.cutouts.new_image == decam_detection_list.image.new_aligned_image + + assert isinstance(co_subdict["sub_data"], np.ndarray) + assert isinstance(co_subdict["sub_weight"], np.ndarray) + assert isinstance(co_subdict["sub_flags"], np.ndarray) + assert isinstance(co_subdict["ref_data"], np.ndarray) + assert isinstance(co_subdict["ref_weight"], np.ndarray) + assert isinstance(co_subdict["ref_flags"], np.ndarray) + assert isinstance(co_subdict["new_data"], np.ndarray) + assert isinstance(co_subdict["new_weight"], np.ndarray) + assert isinstance(co_subdict["new_flags"], np.ndarray) + assert ds.cutouts.bitflag is not None # set the bitflag just to see if it is loaded or not - c.bitflag = 2 ** 41 # should be Cosmic Ray + ds.cutouts.bitflag = 2 ** 41 # should be Cosmic Ray - # save an individual cutout - Cutouts.save_list([c]) + # save the Cutouts + ds.cutouts.save() # open the file manually and compare - with h5py.File(c.get_fullpath(), 'r') as file: - assert 'source_0' in file + with h5py.File(ds.cutouts.get_fullpath(), 'r') as file: for im in ['sub', 'ref', 'new']: for att in ['data', 'weight', 'flags']: - assert f'{im}_{att}' in file['source_0'] - assert np.array_equal(getattr(c, f'{im}_{att}'), file['source_0'][f'{im}_{att}']) - assert dict(file['source_0'].attrs) == c.source_row + assert f'{im}_{att}' in file[subdict_key] + assert np.array_equal(co_subdict.get(f'{im}_{att}'), + file[subdict_key][f'{im}_{att}']) - # load it from file and compare - c2 = Cutouts.from_file(c.get_fullpath(), source_number=0) - assert c.check_equals(c2) - - assert c2.bitflag == 0 # should not load all column data from file (e.g., bitflag) - - # save a second cutout to the same file - Cutouts.save_list(ds.cutouts[1:2]) - assert ds.cutouts[1].filepath == c.filepath - - # change the value of one of the arrays - c.sub_data[0, 0] = 100 - # make sure we can re-save - Cutouts.save_list([c]) + # load a cutouts from file and compare + c2 = Cutouts() + c2.filepath = ds.cutouts.filepath + c2.sources = ds.cutouts.sources # necessary for co_dict + c2.load_all_co_data() # explicitly load co_dict - with h5py.File(c.get_fullpath(), 'r') as file: - assert np.array_equal(c.sub_data, file['source_0']['sub_data']) - assert file['source_0']['sub_data'][0, 0] == 100 # change has been propagated + co_subdict2 = c2.co_dict[subdict_key] - # save the whole list of cutouts - Cutouts.save_list(ds.cutouts) + for im in ['sub', 'ref', 'new']: + for att in ['data', 'weight', 'flags']: + assert np.array_equal(co_subdict.get(f'{im}_{att}'), + co_subdict2.get(f'{im}_{att}')) - # load it from file and compare - loaded_cutouts = Cutouts.load_list(c.get_fullpath()) + assert c2.bitflag == 0 # should not load all column data from file - for cut1, cut2 in zip(ds.cutouts, loaded_cutouts): - assert cut1.check_equals(cut2) + # change the value of one of the arrays + ds.cutouts.co_dict[subdict_key]['sub_data'][0, 0] = 100 + co_subdict2['sub_data'][0, 0] = 100 # for comparison later - # make sure that deleting one cutout does not delete the file - with pytest.raises(NotImplementedError, match='no support for removing one Cutout at a time'): - # TODO: fix this if we ever bring back this functionality - ds.cutouts[1].remove_data_from_disk() - assert os.path.isfile(ds.cutouts[0].get_fullpath()) + # make sure we can re-save + ds.cutouts.save() - # delete one file from the archive, should still keep the file: - # TODO: this is not yet implemented! see issue #207 - # ds.cutouts[1].delete_from_archive() - # TODO: check that the file still exists on the archive + with h5py.File(ds.cutouts.get_fullpath(), 'r') as file: + assert np.array_equal(ds.cutouts.co_dict[subdict_key]['sub_data'], + file[subdict_key]['sub_data']) + assert file[subdict_key]['sub_data'][0, 0] == 100 # change has propagated # check that we can add the cutouts to the database with SmartSession() as session: - ds.cutouts = Cutouts.merge_list(ds.cutouts, session=session) + ds.cutouts = session.merge(ds.cutouts) + session.commit() + ds.cutouts.load_all_co_data() # need to re-load after merge assert ds.cutouts is not None - assert len(ds.cutouts) > 0 + assert len(ds.cutouts.co_dict) > 0 with SmartSession() as session: loaded_cutouts = session.scalars( - sa.select(Cutouts).where(Cutouts.provenance_id == ds.cutouts[0].provenance.id) + sa.select(Cutouts).where(Cutouts.provenance_id == ds.cutouts.provenance.id) ).all() - for cut1, cut2 in zip(ds.cutouts, loaded_cutouts): - assert cut1.check_equals(cut2) + assert len(loaded_cutouts) == 1 + loaded_cutouts = loaded_cutouts[0] + + # make sure data is correct + loaded_cutouts.load_all_co_data() + co_subdict = loaded_cutouts.co_dict[subdict_key] + for im in ['sub', 'ref', 'new']: + for att in ['data', 'weight', 'flags']: + assert np.array_equal(co_subdict.get(f'{im}_{att}'), + co_subdict2.get(f'{im}_{att}')) + finally: if 'ds' in locals() and ds.cutouts is not None: - Cutouts.delete_list(ds.cutouts) - + ds.cutouts.delete_from_disk_and_database() diff --git a/tests/models/test_image.py b/tests/models/test_image.py index 882ac708..d2d2cfca 100644 --- a/tests/models/test_image.py +++ b/tests/models/test_image.py @@ -1345,7 +1345,6 @@ def test_image_multifile(sim_image_uncommitted, provenance_base, test_config): test_config.set_value('storage.images.single_file', single_fileness) -@pytest.mark.skip(reason="This test is way too slow (see Issue #291") def test_image_products_are_deleted(ptf_datastore, data_dir, archive): ds = ptf_datastore # shorthand diff --git a/tests/models/test_measurements.py b/tests/models/test_measurements.py index 116962a7..f02f4e4d 100644 --- a/tests/models/test_measurements.py +++ b/tests/models/test_measurements.py @@ -18,7 +18,7 @@ def test_measurements_attributes(measurer, ptf_datastore, test_config): aper_radii = test_config.value('extraction.sources.apertures') ds = measurer.run(ptf_datastore.cutouts) # check that the measurer actually loaded the measurements from db, and not recalculated - assert len(ds.measurements) <= len(ds.cutouts) # not all cutouts have saved measurements + assert len(ds.measurements) <= len(ds.cutouts.co_dict) # not all cutouts have saved measurements assert len(ds.measurements) == len(ptf_datastore.measurements) assert ds.measurements[0].from_db assert not measurer.has_recalculated diff --git a/tests/models/test_objects.py b/tests/models/test_objects.py index 192deae6..db23200a 100644 --- a/tests/models/test_objects.py +++ b/tests/models/test_objects.py @@ -38,8 +38,8 @@ def test_lightcurves_from_measurements(sim_lightcurves): for m in lc: measured_flux.append(m.flux_apertures[3] - m.bkg_mean * m.area_apertures[3]) - expected_flux.append(m.sources.data['flux'][m.cutouts.index_in_sources]) - expected_error.append(m.sources.data['flux_err'][m.cutouts.index_in_sources]) + expected_flux.append(m.sources.data['flux'][m.index_in_sources]) + expected_error.append(m.sources.data['flux_err'][m.index_in_sources]) assert len(expected_flux) == len(measured_flux) for i in range(len(measured_flux)): diff --git a/tests/models/test_ptf.py b/tests/models/test_ptf.py index 28a8661b..c9625ef9 100644 --- a/tests/models/test_ptf.py +++ b/tests/models/test_ptf.py @@ -24,7 +24,7 @@ def test_ptf_datastore(ptf_datastore): assert isinstance(ptf_datastore.zp, ZeroPoint) assert isinstance(ptf_datastore.sub_image, Image) assert isinstance(ptf_datastore.detections, SourceList) - assert all([isinstance(c, Cutouts) for c in ptf_datastore.cutouts]) + assert isinstance(ptf_datastore.cutouts, Cutouts) assert all([isinstance(m, Measurements) for m in ptf_datastore.measurements]) # using that bad row of pixels from the mask image diff --git a/tests/pipeline/test_measuring.py b/tests/pipeline/test_measuring.py index 24457dc4..0b1a6487 100644 --- a/tests/pipeline/test_measuring.py +++ b/tests/pipeline/test_measuring.py @@ -16,85 +16,80 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): measurer.pars.bad_pixel_exclude = ['saturated'] # ignore saturated pixels measurer.pars.bad_flag_exclude = ['satellite'] # ignore satellite cutouts - sz = decam_cutouts[0].sub_data.shape - fwhm = decam_cutouts[0].sources.image.get_psf().fwhm_pixels + decam_cutouts.load_all_co_data() + sz = decam_cutouts.co_dict["source_index_0"]["sub_data"].shape + fwhm = decam_cutouts.sources.image.get_psf().fwhm_pixels # clear any flags for the fake data we are using for i in range(14): - decam_cutouts[i].sub_flags = np.zeros_like(decam_cutouts[i].sub_flags) + decam_cutouts.co_dict[f"source_index_{i}"]["sub_flags"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_{i}"]["sub_flags"]) # decam_cutouts[i].filepath = None # make sure the cutouts don't re-load the original data # delta function - decam_cutouts[0].sub_data = np.zeros_like(decam_cutouts[0].sub_data) - decam_cutouts[0].sub_data[sz[0] // 2, sz[1] // 2] = 100.0 + decam_cutouts.co_dict[f"source_index_0"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_0"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_0"]["sub_data"][sz[0] // 2, sz[1] // 2] = 100.0 # shifted delta function - decam_cutouts[1].sub_data = np.zeros_like(decam_cutouts[0].sub_data) - decam_cutouts[1].sub_data[sz[0] // 2 + 2, sz[1] // 2 + 3] = 200.0 + decam_cutouts.co_dict[f"source_index_1"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_0"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_1"]["sub_data"][sz[0] // 2 + 2, sz[1] // 2 + 3] = 200.0 # gaussian - decam_cutouts[2].sub_data = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355, norm=1) * 1000 + decam_cutouts.co_dict[f"source_index_2"]["sub_data"] = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355, norm=1) * 1000 # shifted gaussian - decam_cutouts[3].sub_data = make_gaussian( + decam_cutouts.co_dict[f"source_index_3"]["sub_data"] = make_gaussian( imsize=sz[0], sigma_x=fwhm / 2.355, norm=1, offset_x=-2, offset_y=-3 ) * 500 # dipole - decam_cutouts[4].sub_data = np.zeros_like(decam_cutouts[4].sub_data) - decam_cutouts[4].sub_data += make_gaussian( + decam_cutouts.co_dict[f"source_index_4"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_4"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_4"]["sub_data"] += make_gaussian( imsize=sz[0], sigma_x=fwhm / 2.355, norm=1, offset_x=-1, offset_y=-0.8 ) * 500 - decam_cutouts[4].sub_data -= make_gaussian( + decam_cutouts.co_dict[f"source_index_4"]["sub_data"] -= make_gaussian( imsize=sz[0], sigma_x=fwhm / 2.355, norm=1, offset_x=1, offset_y=0.8 ) * 500 # shifted gaussian with noise - decam_cutouts[5].sub_data = decam_cutouts[3].sub_data + np.random.normal(0, 1, size=sz) + decam_cutouts.co_dict[f"source_index_5"]["sub_data"] = decam_cutouts.co_dict[f"source_index_3"]["sub_data"] + np.random.normal(0, 1, size=sz) # dipole with noise - decam_cutouts[6].sub_data = decam_cutouts[4].sub_data + np.random.normal(0, 1, size=sz) + decam_cutouts.co_dict[f"source_index_6"]["sub_data"] = decam_cutouts.co_dict[f"source_index_4"]["sub_data"] + np.random.normal(0, 1, size=sz) # delta function with bad pixel - decam_cutouts[7].sub_data = np.zeros_like(decam_cutouts[0].sub_data) - decam_cutouts[7].sub_data[sz[0] // 2, sz[1] // 2] = 100.0 - decam_cutouts[7].sub_flags[sz[0] // 2 + 2, sz[1] // 2 + 2] = 1 # bad pixel + decam_cutouts.co_dict[f"source_index_7"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_0"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_7"]["sub_data"][sz[0] // 2, sz[1] // 2] = 100.0 + decam_cutouts.co_dict[f"source_index_7"]["sub_flags"][sz[0] // 2 + 2, sz[1] // 2 + 2] = 1 # bad pixel # delta function with bad pixel and saturated pixel - decam_cutouts[8].sub_data = np.zeros_like(decam_cutouts[0].sub_data) - decam_cutouts[8].sub_data[sz[0] // 2, sz[1] // 2] = 100.0 - decam_cutouts[8].sub_flags[sz[0] // 2 + 2, sz[1] // 2 + 1] = 1 # bad pixel - decam_cutouts[8].sub_flags[sz[0] // 2 - 2, sz[1] // 2 + 1] = 4 # saturated should be ignored! + decam_cutouts.co_dict[f"source_index_8"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_0"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_8"]["sub_data"][sz[0] // 2, sz[1] // 2] = 100.0 + decam_cutouts.co_dict[f"source_index_8"]["sub_flags"][sz[0] // 2 + 2, sz[1] // 2 + 1] = 1 # bad pixel + decam_cutouts.co_dict[f"source_index_8"]["sub_flags"][sz[0] // 2 - 2, sz[1] // 2 + 1] = 4 # saturated should be ignored! # delta function with offset that makes it far from the bad pixel - decam_cutouts[9].sub_data = np.zeros_like(decam_cutouts[0].sub_data) - decam_cutouts[9].sub_data[sz[0] // 2 + 3, sz[1] // 2 + 3] = 100.0 - decam_cutouts[9].sub_flags[sz[0] // 2 - 2, sz[1] // 2 - 2] = 1 # bad pixel + decam_cutouts.co_dict[f"source_index_9"]["sub_data"] = np.zeros_like(decam_cutouts.co_dict[f"source_index_0"]["sub_data"]) + decam_cutouts.co_dict[f"source_index_9"]["sub_data"][sz[0] // 2 + 3, sz[1] // 2 + 3] = 100.0 + decam_cutouts.co_dict[f"source_index_9"]["sub_flags"][sz[0] // 2 - 2, sz[1] // 2 - 2] = 1 # bad pixel # gaussian that is too wide - decam_cutouts[10].sub_data = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355 * 2, norm=1) * 1000 - decam_cutouts[10].sub_data += np.random.normal(0, 1, size=sz) + decam_cutouts.co_dict[f"source_index_10"]["sub_data"] = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355 * 2, norm=1) * 1000 + decam_cutouts.co_dict[f"source_index_10"]["sub_data"] += np.random.normal(0, 1, size=sz) # streak - decam_cutouts[11].sub_data = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355, sigma_y=20, rotation=25, norm=1) - decam_cutouts[11].sub_data *= 1000 - decam_cutouts[11].sub_data += np.random.normal(0, 1, size=sz) - - # a regular cutout but we'll put some bad flag on the cutout - decam_cutouts[12].badness = 'cosmic ray' - - # a regular cutout with a bad flag that we are ignoring: - decam_cutouts[13].badness = 'satellite' + decam_cutouts.co_dict[f"source_index_11"]["sub_data"] = make_gaussian(imsize=sz[0], sigma_x=fwhm / 2.355, sigma_y=20, rotation=25, norm=1) + decam_cutouts.co_dict[f"source_index_11"]["sub_data"] *= 1000 + decam_cutouts.co_dict[f"source_index_11"]["sub_data"] += np.random.normal(0, 1, size=sz) # run the measurer ds = measurer.run(decam_cutouts) - assert len(ds.all_measurements) == len(ds.cutouts) + assert len(ds.all_measurements) == len(ds.cutouts.co_dict) # verify all scores have been assigned for score in measurer.pars.analytical_cuts: assert score in ds.measurements[0].disqualifier_scores - m = ds.all_measurements[0] # delta function + m = [m for m in ds.all_measurements if m.index_in_sources == 0][0] # delta function assert m.disqualifier_scores['negatives'] == 0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] < 0.01 @@ -108,7 +103,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): for i in range(3): # check only the last apertures, that are smaller than cutout square assert m.area_apertures[i] == pytest.approx(np.pi * (m.aper_radii[i] + 0.5) ** 2, rel=0.1) - m = ds.all_measurements[1] # shifted delta function + m = [m for m in ds.all_measurements if m.index_in_sources == 1][0] # shifted delta function assert m.disqualifier_scores['negatives'] == 0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] == pytest.approx(np.sqrt(2 ** 2 + 3 ** 2), abs=0.1) @@ -120,7 +115,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_mean == 0 assert m.bkg_std == 0 - m = ds.all_measurements[2] # gaussian + m = [m for m in ds.all_measurements if m.index_in_sources == 2][0] # gaussian assert m.disqualifier_scores['negatives'] < 1.0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] < 0.1 @@ -136,7 +131,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): # TODO: add test for PSF flux when it is implemented - m = ds.all_measurements[3] # shifted gaussian + m = [m for m in ds.all_measurements if m.index_in_sources == 3][0] # shifted gaussian assert m.disqualifier_scores['negatives'] < 1.0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] == pytest.approx(np.sqrt(2 ** 2 + 3 ** 2), abs=1.0) @@ -149,7 +144,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_mean == pytest.approx(0, abs=0.01) assert m.bkg_std == pytest.approx(0, abs=0.01) - m = ds.all_measurements[4] # dipole + m = [m for m in ds.all_measurements if m.index_in_sources == 4][0] # dipole assert m.disqualifier_scores['negatives'] == pytest.approx(1.0, abs=0.1) assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] > 100 @@ -161,7 +156,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_std == 0 assert m.bkg_std == 0 - m = ds.all_measurements[5] # shifted gaussian with noise + m = [m for m in ds.all_measurements if m.index_in_sources == 5][0] # shifted gaussian with noise assert m.disqualifier_scores['negatives'] < 1.0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] == pytest.approx(np.sqrt(2 ** 2 + 3 ** 2), rel=0.1) @@ -172,33 +167,33 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): for i in range(1, len(m.flux_apertures)): assert m.flux_apertures[i] == pytest.approx(500, rel=0.1) - m = ds.all_measurements[6] # dipole with noise + m = [m for m in ds.all_measurements if m.index_in_sources == 6][0] # dipole with noise assert m.disqualifier_scores['negatives'] == pytest.approx(1.0, abs=0.2) assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] > 1 assert m.disqualifier_scores['filter bank'] > 0 - m = ds.all_measurements[7] # delta function with bad pixel + m = [m for m in ds.all_measurements if m.index_in_sources == 7][0] # delta function with bad pixel assert m.disqualifier_scores['negatives'] == 0 assert m.disqualifier_scores['bad pixels'] == 1 assert m.disqualifier_scores['offsets'] < 0.01 assert m.disqualifier_scores['filter bank'] == 1 assert m.get_filter_description() == f'PSF mismatch (FWHM= 0.25 x {fwhm:.2f})' - m = ds.all_measurements[8] # delta function with bad pixel and saturated pixel + m = [m for m in ds.all_measurements if m.index_in_sources == 8][0] # delta function with bad pixel and saturated pixel assert m.disqualifier_scores['negatives'] == 0 assert m.disqualifier_scores['bad pixels'] == 1 # we set to ignore the saturated pixel! assert m.disqualifier_scores['offsets'] < 0.01 assert m.disqualifier_scores['filter bank'] == 1 assert m.get_filter_description() == f'PSF mismatch (FWHM= 0.25 x {fwhm:.2f})' - m = ds.all_measurements[9] # delta function with offset that makes it far from the bad pixel + m = [m for m in ds.all_measurements if m.index_in_sources == 9][0] # delta function with offset that makes it far from the bad pixel assert m.disqualifier_scores['negatives'] == 0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] == pytest.approx(np.sqrt(3 ** 2 + 3 ** 2), abs=0.1) assert m.disqualifier_scores['filter bank'] == 1 - m = ds.all_measurements[10] # gaussian that is too wide + m = [m for m in ds.all_measurements if m.index_in_sources == 10][0] # gaussian that is too wide assert m.disqualifier_scores['negatives'] < 1.0 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] < 0.5 @@ -213,7 +208,7 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_mean == pytest.approx(0, abs=0.2) assert m.bkg_std == pytest.approx(1.0, abs=0.2) - m = ds.all_measurements[11] # streak + m = [m for m in ds.all_measurements if m.index_in_sources == 11][0] # streak assert m.disqualifier_scores['negatives'] < 0.5 assert m.disqualifier_scores['bad pixels'] == 0 assert m.disqualifier_scores['offsets'] < 0.7 @@ -222,32 +217,6 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): assert m.bkg_mean < 0.5 assert m.bkg_std < 3.0 - m = ds.all_measurements[12] # regular cutout with a bad flag - assert m.disqualifier_scores['bad_flag'] == 2 ** 41 # this is the bit for 'cosmic ray' - - m = ds.all_measurements[13] # regular cutout with a bad flag that we are ignoring - assert m.disqualifier_scores['bad_flag'] == 0 # we've included the satellite flag in the ignore list - - # check that coordinates have been modified: - for i in range(14): - m = ds.all_measurements[i] - if m.offset_x != 0 and m.offset_y != 0: - assert m.ra != m.cutouts.ra - assert m.dec != m.cutouts.dec - - -def test_propagate_badness(decam_datastore): - ds = decam_datastore - with SmartSession() as session: - ds.measurements[0].badness = 'cosmic ray' - # find the index of the cutout that corresponds to the measurement - idx = [i for i, c in enumerate(ds.cutouts) if c.id == ds.measurements[0].cutouts_id][0] - ds.cutouts[idx].badness = 'cosmic ray' - ds.cutouts[idx].update_downstream_badness(session=session) - m = session.merge(ds.measurements[0]) - - assert m.badness == 'cosmic ray' # note that this does not change disqualifier_scores! - def test_warnings_and_exceptions(decam_datastore, measurer): measurer.pars.inject_warnings = 1 @@ -264,4 +233,3 @@ def test_warnings_and_exceptions(decam_datastore, measurer): ds.reraise() assert "Exception injected by pipeline parameters in process 'measuring'." in str(excinfo.value) ds.read_exception() - diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 4ab0bfb3..556233ca 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -119,21 +119,19 @@ def check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session assert det is not None assert ds.detections.id == det.id - # find the Cutouts list + # find the Cutouts cutouts = session.scalars( sa.select(Cutouts).where( Cutouts.sources_id == det.id, - Cutouts.provenance_id == ds.cutouts[0].provenance_id, + Cutouts.provenance_id == ds.cutouts.provenance_id, ) - ).all() - assert len(cutouts) > 0 - assert len(ds.cutouts) == len(cutouts) - assert set([c.id for c in ds.cutouts]) == set([c.id for c in cutouts]) + ).first() + assert ds.cutouts.id == cutouts.id # Measurements measurements = session.scalars( sa.select(Measurements).where( - Measurements.cutouts_id.in_([c.id for c in cutouts]), + Measurements.cutouts_id == cutouts.id, Measurements.provenance_id == ds.measurements[0].provenance_id, ) ).all() @@ -302,8 +300,7 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.zp._upstream_bitflag == 2 assert ds.sub_image._upstream_bitflag == 2 assert ds.detections._upstream_bitflag == 2 - for cutout in ds.cutouts: # cutouts is a list of cutout objects - assert cutout._upstream_bitflag == 2 + assert ds.cutouts._upstream_bitflag == 2 # test part 2: Add a second bitflag partway through and check it propagates to downstreams @@ -325,9 +322,10 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.zp._upstream_bitflag == desired_bitflag assert ds.sub_image._upstream_bitflag == desired_bitflag assert ds.detections._upstream_bitflag == desired_bitflag - for cutout in ds.cutouts: - assert cutout._upstream_bitflag == desired_bitflag - assert ds.image.bitflag == 2 # not in the downstream of sources + assert ds.cutouts._upstream_bitflag == desired_bitflag + for m in ds.measurements: + assert m._upstream_bitflag == desired_bitflag + assert ds.image.bitflag == 2 # not in the downstream of sources # test part 3: test update_downstream_badness() function by adding and removing flags # and observing propagation @@ -354,8 +352,9 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.zp.bitflag == desired_bitflag assert ds.sub_image.bitflag == desired_bitflag assert ds.detections.bitflag == desired_bitflag - for cutout in ds.cutouts: - assert cutout.bitflag == desired_bitflag + assert ds.cutouts.bitflag == desired_bitflag + for m in ds.measurements: + assert m.bitflag == desired_bitflag # remove the bitflag and check that it disappears in downstreams ds.image._bitflag = 0 # remove 'bad subtraction' @@ -371,8 +370,9 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.zp.bitflag == desired_bitflag assert ds.sub_image.bitflag == desired_bitflag assert ds.detections.bitflag == desired_bitflag - for cutout in ds.cutouts: - assert cutout.bitflag == desired_bitflag + assert ds.cutouts.bitflag == desired_bitflag + for m in ds.measurements: + assert m.bitflag == desired_bitflag finally: if 'ds' in locals(): @@ -425,14 +425,10 @@ def test_get_upstreams_and_downstreams(decam_exposure, decam_reference, decam_de ds.zp.id, ]) assert [upstream.id for upstream in ds.detections.get_upstreams(session)] == [ds.sub_image.id] - for cutout in ds.cutouts: - assert [upstream.id for upstream in cutout.get_upstreams(session)] == [ds.detections.id] - # measurements are a challenge to make sure the *right* measurement is with the right cutout - # for the time being, check that the measurements upstream is one of the cutouts - cutout_ids = np.unique([cutout.id for cutout in ds.cutouts]) + assert [upstream.id for upstream in ds.cutouts.get_upstreams(session)] == [ds.detections.id] + for measurement in ds.measurements: - m_upstream_ids = np.array([upstream.id for upstream in measurement.get_upstreams(session)]) - assert np.all(np.isin(m_upstream_ids, cutout_ids)) + assert [upstream.id for upstream in measurement.get_upstreams(session)] == [ds.cutouts.id] # test get_downstreams assert [downstream.id for downstream in ds.exposure.get_downstreams(session)] == [ds.image.id] @@ -449,14 +445,9 @@ def test_get_upstreams_and_downstreams(decam_exposure, decam_reference, decam_de assert [downstream.id for downstream in ds.wcs.get_downstreams(session)] == [ds.sub_image.id] assert [downstream.id for downstream in ds.zp.get_downstreams(session)] == [ds.sub_image.id] assert [downstream.id for downstream in ds.sub_image.get_downstreams(session)] == [ds.detections.id] - assert np.all(np.isin([downstream.id for downstream in ds.detections.get_downstreams(session)], cutout_ids)) - # basic test: check the downstreams of cutouts is one of the measurements - measurement_ids = np.unique([measurement.id for measurement in ds.measurements]) - for cutout in ds.cutouts: - c_downstream_ids = [downstream.id for downstream in cutout.get_downstreams(session)] - assert np.all(np.isin(c_downstream_ids, measurement_ids)) - for measurement in ds.measurements: - assert [downstream.id for downstream in measurement.get_downstreams(session)] == [] + assert [downstream.id for downstream in ds.detections.get_downstreams(session)] == [ds.cutouts.id] + measurement_ids = set([measurement.id for measurement in ds.measurements]) + assert set([downstream.id for downstream in ds.cutouts.get_downstreams(session)]) == measurement_ids finally: if 'ds' in locals(): @@ -478,8 +469,8 @@ def test_datastore_delete_everything(decam_datastore): sub_paths = sub.get_fullpath(as_list=True) det = decam_datastore.detections det_path = det.get_fullpath() - cutouts_list = decam_datastore.cutouts - cutouts_file_path = cutouts_list[0].get_fullpath() + cutouts = decam_datastore.cutouts + cutouts_file_path = cutouts.get_fullpath() measurements_list = decam_datastore.measurements # make sure we can delete everything @@ -508,7 +499,7 @@ def test_datastore_delete_everything(decam_datastore): assert session.scalars(sa.select(PSF).where(PSF.id == psf.id)).first() is None assert session.scalars(sa.select(Image).where(Image.id == sub.id)).first() is None assert session.scalars(sa.select(SourceList).where(SourceList.id == det.id)).first() is None - assert session.scalars(sa.select(Cutouts).where(Cutouts.id == cutouts_list[0].id)).first() is None + assert session.scalars(sa.select(Cutouts).where(Cutouts.id == cutouts.id)).first() is None if len(measurements_list) > 0: assert session.scalars( sa.select(Measurements).where(Measurements.id == measurements_list[0].id) @@ -531,7 +522,7 @@ def test_provenance_tree(pipeline_for_tests, decam_exposure, decam_datastore, de assert ds.zp.provenance_id == provs['extraction'].id assert ds.sub_image.provenance_id == provs['subtraction'].id assert ds.detections.provenance_id == provs['detection'].id - assert ds.cutouts[0].provenance_id == provs['cutting'].id + assert ds.cutouts.provenance_id == provs['cutting'].id assert ds.measurements[0].provenance_id == provs['measuring'].id with SmartSession() as session: From 02db29963d04416579e7d9e71371bc0e6913583f Mon Sep 17 00:00:00 2001 From: Guy Nir <37179063+guynir42@users.noreply.github.com> Date: Mon, 1 Jul 2024 23:38:20 +0300 Subject: [PATCH 3/3] Reference sets (#311) --- .github/workflows/run-improc-tests.yml | 2 + .github/workflows/run-model-tests-1.yml | 2 + .github/workflows/run-model-tests-2.yml | 2 + .github/workflows/run-pipeline-tests-1.yml | 2 + .github/workflows/run-pipeline-tests-2.yml | 3 + .github/workflows/run-util-tests.yml | 3 +- ..._07_01_1135-370933973646_reference_sets.py | 69 ++ default_config.yaml | 180 +++-- docs/coadd_pipeline.md | 3 - docs/index.rst | 2 +- docs/overview.md | 4 +- docs/references.md | 107 +++ improc/alignment.py | 46 +- improc/inpainting.py | 2 +- improc/photometry.py | 3 +- improc/zogy.py | 3 + models/background.py | 16 +- models/base.py | 57 ++ models/cutouts.py | 1 + models/enums_and_bitflags.py | 14 + models/exposure.py | 22 +- models/image.py | 370 +++++++++- models/instrument.py | 1 + models/measurements.py | 11 +- models/provenance.py | 10 +- models/psf.py | 16 +- models/reference.py | 119 ++- models/refset.py | 69 ++ models/report.py | 15 +- models/source_list.py | 16 +- models/world_coordinates.py | 16 +- models/zero_point.py | 14 +- pipeline/astro_cal.py | 38 +- pipeline/backgrounding.py | 28 +- pipeline/coaddition.py | 101 +-- pipeline/cutting.py | 34 +- pipeline/data_store.py | 250 +++---- pipeline/detection.py | 60 +- pipeline/measuring.py | 78 +- pipeline/parameters.py | 6 + pipeline/photo_cal.py | 47 +- pipeline/preprocessing.py | 7 +- pipeline/ref_maker.py | 566 +++++++++++++++ pipeline/subtraction.py | 54 +- pipeline/top_level.py | 316 ++++---- tests/conftest.py | 32 +- tests/fixtures/datastore_factory.py | 605 ++++++++++++++++ tests/fixtures/decam.py | 156 ++-- tests/fixtures/pipeline_objects.py | 662 +---------------- tests/fixtures/ptf.py | 129 +++- tests/fixtures/simulated.py | 2 +- tests/improc/test_simulator.py | 9 +- tests/improc/test_sky_flat.py | 2 +- tests/improc/test_zogy.py | 4 +- tests/models/test_base.py | 49 +- tests/models/test_cutouts.py | 11 +- tests/models/test_decam.py | 27 +- tests/models/test_image.py | 596 +-------------- tests/models/test_image_propagation.py | 323 +++++++++ tests/models/test_image_querying.py | 685 ++++++++++++++++++ tests/models/test_measurements.py | 33 +- tests/models/test_psf.py | 9 +- tests/models/test_reports.py | 21 +- tests/models/test_source_list.py | 3 +- tests/models/test_world_coordinates.py | 14 +- tests/pipeline/test_astro_cal.py | 17 +- tests/pipeline/test_backgrounding.py | 13 +- tests/pipeline/test_coaddition.py | 11 +- .../test_compare_sextractor_to_photutils.py | 3 +- tests/pipeline/test_cutting.py | 13 +- tests/pipeline/test_detection.py | 14 +- tests/pipeline/test_extraction.py | 16 +- tests/pipeline/test_making_references.py | 258 +++++++ tests/pipeline/test_measuring.py | 13 +- tests/pipeline/test_photo_cal.py | 13 +- tests/pipeline/test_pipeline.py | 195 ++--- tests/pipeline/test_preprocessing.py | 14 +- tests/pipeline/test_reffinding.py | 51 -- tests/pipeline/test_subtraction.py | 20 +- tests/util/test_radec.py | 10 + util/radec.py | 16 +- util/util.py | 18 +- 82 files changed, 4584 insertions(+), 2268 deletions(-) create mode 100644 alembic/versions/2024_07_01_1135-370933973646_reference_sets.py delete mode 100644 docs/coadd_pipeline.md create mode 100644 docs/references.md create mode 100644 models/refset.py create mode 100644 pipeline/ref_maker.py create mode 100644 tests/fixtures/datastore_factory.py create mode 100644 tests/models/test_image_propagation.py create mode 100644 tests/models/test_image_querying.py create mode 100644 tests/pipeline/test_making_references.py delete mode 100644 tests/pipeline/test_reffinding.py diff --git a/.github/workflows/run-improc-tests.yml b/.github/workflows/run-improc-tests.yml index ba39f935..1793e015 100644 --- a/.github/workflows/run-improc-tests.yml +++ b/.github/workflows/run-improc-tests.yml @@ -59,10 +59,12 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some unneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=tests/improc docker compose run runtests diff --git a/.github/workflows/run-model-tests-1.yml b/.github/workflows/run-model-tests-1.yml index fb610eee..1b1e1f62 100644 --- a/.github/workflows/run-model-tests-1.yml +++ b/.github/workflows/run-model-tests-1.yml @@ -59,10 +59,12 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some unneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=$(ls tests/models/test_{a..l}*.py) docker compose run runtests diff --git a/.github/workflows/run-model-tests-2.yml b/.github/workflows/run-model-tests-2.yml index 3158b7ba..037a533f 100644 --- a/.github/workflows/run-model-tests-2.yml +++ b/.github/workflows/run-model-tests-2.yml @@ -59,10 +59,12 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some unneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=$(ls tests/models/test_{m..z}*.py) docker compose run runtests diff --git a/.github/workflows/run-pipeline-tests-1.yml b/.github/workflows/run-pipeline-tests-1.yml index 702fc61e..42d73646 100644 --- a/.github/workflows/run-pipeline-tests-1.yml +++ b/.github/workflows/run-pipeline-tests-1.yml @@ -59,10 +59,12 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some uneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=$(ls tests/pipeline/test_{a..o}*.py) docker compose run runtests diff --git a/.github/workflows/run-pipeline-tests-2.yml b/.github/workflows/run-pipeline-tests-2.yml index 461739d8..02ba99d8 100644 --- a/.github/workflows/run-pipeline-tests-2.yml +++ b/.github/workflows/run-pipeline-tests-2.yml @@ -59,10 +59,13 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some unneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" + shopt -s nullglob TEST_SUBFOLDER=$(ls tests/pipeline/test_{p..z}*.py) docker compose run runtests + diff --git a/.github/workflows/run-util-tests.yml b/.github/workflows/run-util-tests.yml index 591bc250..5e0ef6bd 100644 --- a/.github/workflows/run-util-tests.yml +++ b/.github/workflows/run-util-tests.yml @@ -59,10 +59,11 @@ jobs: - name: run test run: | + # try to save HDD space on the runner by removing some unneeded stuff # ref: https://github.com/actions/runner-images/issues/2840#issuecomment-790492173 sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" - shopt -s nullglob + TEST_SUBFOLDER=tests/util docker compose run runtests diff --git a/alembic/versions/2024_07_01_1135-370933973646_reference_sets.py b/alembic/versions/2024_07_01_1135-370933973646_reference_sets.py new file mode 100644 index 00000000..164b54d3 --- /dev/null +++ b/alembic/versions/2024_07_01_1135-370933973646_reference_sets.py @@ -0,0 +1,69 @@ +"""reference sets + +Revision ID: 370933973646 +Revises: a375526c8260 +Create Date: 2024-06-23 11:35:43.941095 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '370933973646' +down_revision = '7384c6d07485' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('refsets', + sa.Column('name', sa.Text(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('upstream_hash', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('modified', sa.DateTime(), nullable=False), + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_refsets_created_at'), 'refsets', ['created_at'], unique=False) + op.create_index(op.f('ix_refsets_id'), 'refsets', ['id'], unique=False) + op.create_index(op.f('ix_refsets_name'), 'refsets', ['name'], unique=True) + op.create_index(op.f('ix_refsets_upstream_hash'), 'refsets', ['upstream_hash'], unique=False) + op.create_table('refset_provenance_association', + sa.Column('provenance_id', sa.Text(), nullable=False), + sa.Column('refset_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['provenance_id'], ['provenances.id'], name='refset_provenances_association_provenance_id_fkey', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['refset_id'], ['refsets.id'], name='refsets_provenances_association_refset_id_fkey', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('provenance_id', 'refset_id') + ) + op.drop_index('ix_refs_validity_end', table_name='refs') + op.drop_index('ix_refs_validity_start', table_name='refs') + op.drop_column('refs', 'validity_start') + op.drop_column('refs', 'validity_end') + + op.add_column('images', sa.Column('airmass', sa.REAL(), nullable=True)) + op.create_index(op.f('ix_images_airmass'), 'images', ['airmass'], unique=False) + op.add_column('exposures', sa.Column('airmass', sa.REAL(), nullable=True)) + op.create_index(op.f('ix_exposures_airmass'), 'exposures', ['airmass'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_images_airmass'), table_name='images') + op.drop_column('images', 'airmass') + op.drop_index(op.f('ix_exposures_airmass'), table_name='exposures') + op.drop_column('exposures', 'airmass') + op.add_column('refs', sa.Column('validity_end', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) + op.add_column('refs', sa.Column('validity_start', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) + op.create_index('ix_refs_validity_start', 'refs', ['validity_start'], unique=False) + op.create_index('ix_refs_validity_end', 'refs', ['validity_end'], unique=False) + op.drop_table('refset_provenance_association') + op.drop_index(op.f('ix_refsets_upstream_hash'), table_name='refsets') + op.drop_index(op.f('ix_refsets_name'), table_name='refsets') + op.drop_index(op.f('ix_refsets_id'), table_name='refsets') + op.drop_index(op.f('ix_refsets_created_at'), table_name='refsets') + op.drop_table('refsets') + # ### end Alembic commands ### diff --git a/default_config.yaml b/default_config.yaml index 68658624..205ad576 100644 --- a/default_config.yaml +++ b/default_config.yaml @@ -76,13 +76,18 @@ catalog_gaiadr3: # For documentation on the parameters, see the Parameters subclass # in the file that defines each part of the pipeline -pipeline: {} +pipeline: + # save images and their products before the stage where we look for a reference and make a subtraction + save_before_subtraction: true + # automatically save all the products at the end of the pipeline run + save_at_finish: true preprocessing: + # these steps need to be done on the images: either they came like that or we do it in the pipeline steps_required: [ 'overscan', 'linearity', 'flat', 'fringe' ] extraction: - sources: + sources: # this part of the extraction parameters is for finding sources and calculating the PSF method: sextractor measure_psf: true apertures: [1.0, 2.0, 3.0, 5.0] @@ -92,80 +97,78 @@ extraction: separation_fwhm: 1.0 threshold: 3.0 subtraction: false - bg: + bg: # this part of the extraction parameters is for estimating the background format: map method: sep poly_order: 1 sep_box_size: 128 sep_filt_size: 3 - wcs: + wcs: # this part of the extraction parameters is for finding the world coordinates system (astrometric calibration) cross_match_catalog: gaia_dr3 solution_method: scamp - max_catalog_mag: [20.0] - mag_range_catalog: 4.0 + max_catalog_mag: [22.0] + mag_range_catalog: 6.0 min_catalog_stars: 50 max_sources_to_use: [2000, 1000, 500, 200] - zp: + zp: # this part of the extraction parameters is for finding the zero point (photometric calibration) cross_match_catalog: gaia_dr3 - max_catalog_mag: [20.0] - mag_range_catalog: 4.0 + max_catalog_mag: [22.0] + mag_range_catalog: 6.0 min_catalog_stars: 50 +# how to do the subtractions subtraction: method: zogy + # set refset to null to only make references (no subtraction will happen). + # to start running subtractions, first make a ref set and use the name of that + # in this field. + refset: null alignment: method: swarp to_index: new +# how to extract sources (detections) from the subtration image detection: - subtraction: true + subtraction: true # this sets up the Detector object to run on subtraction images method: filter # when using ZOGY subtraction, detection method must be "filter"! threshold: 5.0 +# how to make the cutouts cutting: cutout_size: 25 +# how to measure things like fluxes and centroids, and make analytical cuts measuring: annulus_radii: [10, 15] annulus_units: pixels use_annulus_for_centroids: true + # TODO: remove these in favor of the thresholds dict (and put None to not threshold any one of them)? Issue #319 analytical_cuts: ['negatives', 'bad pixels', 'offsets', 'filter bank'] - outlier_sigma: 3.0 - bad_pixel_radius: 3.0 - bad_pixel_exclude: [] - streak_filter_angle_step: 5.0 - width_filter_multipliers: [0.25, 2.0, 5.0, 10.0] - association_radius: 2.0 - thresholds: + outlier_sigma: 3.0 # how many times the noise rms counts as a positive/negative outlier + bad_pixel_radius: 3.0 # how far from the centroid counts as having a bad pixel + bad_pixel_exclude: [] # which types of bad pixels are ok to have near the source + streak_filter_angle_step: 5.0 # how many degrees to step through the angles for the streak filter + width_filter_multipliers: [0.25, 2.0, 5.0, 10.0] # try different width, if they trigger a high S/N, disqualify it + association_radius: 2.0 # when matching sources, how close counts as being from the same underlying object + thresholds: # any of the analytical cuts that score above these thresholds will be disqualified negatives: 0.3 bad pixels: 1 offsets: 5.0 filter bank: 1 - deletion_thresholds: + deletion_thresholds: # any of the analytical cuts that score above these thresholds will be deleted negatives: 0.3 bad pixels: 1 offsets: 5.0 filter bank: 1 -# Specific configuration for specific instruments. -# Instruments should override the two defaults from -# instrument_default; they may add additional -# configuration that their code needs. - -instrument_default: - calibratorset: nightly - flattype: sky - -# Config for astromatic utilities (sextractor, scamp, swarp, psfex) -astromatic: - # An absolute path to where astromatic config files are - config_dir: null - # A path relative to models/base/CODE_ROOT where the astromatic - # config files are. Ignored if config_dir is not null - config_subdir: data/astromatic_config - +# use these parameters when running the coaddition pipeline, e.g., for making weekly coadds +# the coaddition pipeline will load the regular configuration first, then the coaddition config: coaddition: + # the pipeline handles any top level configuration (e.g., how to choose images) + pipeline: + # how many days back you'd like to collect the images for + date_range: 7.0 coaddition: method: zogy noise_estimator: sep @@ -181,25 +184,111 @@ coaddition: ignore_flags: 0 # The following are used to override the regular "extraction" parameters extraction: - sources: + sources: # override the regular source and psf extraction parameters measure_psf: true threshold: 3.0 method: sextractor - background_method: zero - # The following are used to override the regular astrometric calibration parameters - wcs: + bg: # override the regular background estimation parameters + format: map + method: sep + wcs: # override the regular astrometric calibration parameters cross_match_catalog: gaia_dr3 solution_method: scamp max_catalog_mag: [22.0] mag_range_catalog: 6.0 min_catalog_stars: 50 - # The following are used to override the regular photometric calibration parameters - zp: + zp: # override the regular photometric calibration parameters cross_match_catalog: gaia_dr3 max_catalog_mag: [22.0] mag_range_catalog: 6.0 min_catalog_stars: 50 +# use these parameters to make references by coadding images +# the reference pipeline will load the regular configuration first, +# then the coaddition config, and only in the end, override with this: +referencing: + maker: # these handles the top level configuration (e.g., how to choose images that go into the reference maker) + name: best_references # the name of the reference set that you want to make references for + allow_append: true # can we add different image load criteria to an existing reference set with this name? + start_time: null # only grab images after this time + end_time: null # only grab images before this time + instrument: null # only grab images from this instrument (if list, will make cross-instrument references) + filter: null # only grab images with this filter + project: null # only grab images with this project + min_airmass: null # only grab images with airmass above this + max_airmass: null # only grab images with airmass below this + min_background: null # only grab images with background rms above this + max_background: null # only grab images with background rms below this + min_seeing: null # only grab images with seeing above this + max_seeing: null # only grab images with seeing below this + min_lim_mag: null # only grab images with limiting magnitude above this + max_lim_mag: null # only grab images with limiting magnitude below this + min_exp_time: null # only grab images with exposure time above this + max_exp_time: null # only grab images with exposure time below this + min_number: 7 # only create a reference if this many images can be found + max_number: 30 # do not add more than this number of images to the reference + seeing_quality_factor: 3.0 # linear coefficient for adding lim_mag and seeing to get the "image quality" + save_new_refs: true # should the new references be saved to disk and committed to the database? + pipeline: # The following are used to override the regular "extraction" parameters + extraction: + sources: # override the regular source and psf extraction parameters + measure_psf: true +# threshold: 3.0 +# method: sextractor +# bg: +# format: map +# method: sep +# poly_order: 1 +# sep_box_size: 128 +# sep_filt_size: 3 +# wcs: # override the regular astrometric calibration parameters +# cross_match_catalog: gaia_dr3 +# solution_method: scamp +# max_catalog_mag: [22.0] +# mag_range_catalog: 6.0 +# min_catalog_stars: 50 +# zp: # override the regular photometric calibration parameters +# cross_match_catalog: gaia_dr3 +# max_catalog_mag: [22.0] +# mag_range_catalog: 6.0 +# min_catalog_stars: 50 + + coaddition: # override the coaddition parameters in the general "coaddition" config + coaddition: # override the way coadds are made, from the general "coaddition" config + method: zogy + extraction: # override the coaddition/regular pipeline config, when extracting for the coadd images + sources: # override the regular source and psf extraction parameters and the coadd extraction parameters + measure_psf: true +# threshold: 3.0 +# method: sextractor +# bg: # override the regular background estimation parameters and the coadd extraction parameters +# format: map +# method: sep +# poly_order: 1 +# sep_box_size: 128 +# sep_filt_size: 3 +# wcs: # override the regular astrometric calibration parameters and the coadd extraction parameters +# cross_match_catalog: gaia_dr3 +# solution_method: scamp +# max_catalog_mag: [22.0] +# mag_range_catalog: 6.0 +# min_catalog_stars: 50 +# zp: # override the regular photometric calibration parameters and the coadd extraction parameters +# cross_match_catalog: gaia_dr3 +# max_catalog_mag: [22.0] +# mag_range_catalog: 6.0 +# min_catalog_stars: 50 + + +# Specific configuration for specific instruments. +# Instruments should override the two defaults from +# instrument_default; they may add additional +# configuration that their code needs. + +instrument_default: + calibratorset: nightly + flattype: sky + # DECam @@ -222,3 +311,12 @@ DECam: fringebase: DECamMasterCal_56876/fringecor/DECam_Master_20131115v1 flatbase: DECamMasterCal_56876/starflat/DECam_Master_20130829v3 bpmbase: DECamMasterCal_56876/bpm/DECam_Master_20140209v2_cd_ + + +# Config for astromatic utilities (sextractor, scamp, swarp, psfex) +astromatic: + # An absolute path to where astromatic config files are + config_dir: null + # A path relative to models/base/CODE_ROOT where the astromatic + # config files are. Ignored if config_dir is not null + config_subdir: data/astromatic_config diff --git a/docs/coadd_pipeline.md b/docs/coadd_pipeline.md deleted file mode 100644 index fbb013f0..00000000 --- a/docs/coadd_pipeline.md +++ /dev/null @@ -1,3 +0,0 @@ -## Coaddition and Reference Pipelines - -TBA \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index bfd16b04..d3842c01 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,7 +17,7 @@ Welcome to SeeChange's documentation! pipeline instruments versioning - coadd_pipeline + references testing contribution miscellaneous diff --git a/docs/overview.md b/docs/overview.md index ba782e1c..8c4ebaca 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -277,8 +277,8 @@ Only parameters that affect the product values are included. The upstreams are other `Provenance` objects defined for the data products that are an input to the current processing step. The flowchart of the different process steps is defined in `pipeline.datastore.UPSTREAM_STEPS`. -E.g., the upstreams for the `subtraction` object are `['preprocessing', 'extraction', 'reference']`. -Note that the `reference` upstream is replaced by the provenances +E.g., the upstreams for the `subtraction` object are `['preprocessing', 'extraction', 'referencing']`. +Note that the `referencing` upstream is replaced by the provenances of the reference's `preprocessing` and `extraction` steps. When a `Provenance` object has all the required inputs, it will produce a hash identifier diff --git a/docs/references.md b/docs/references.md new file mode 100644 index 00000000..4c04c03f --- /dev/null +++ b/docs/references.md @@ -0,0 +1,107 @@ +## Coaddition and Reference Pipeline + +Besides the regular pipeline that processes new images, we also have deep-coadd making and reference making pipelines. + +### Coaddition Pipeline + +To make a deeper coadd image (e.g., a weekly coadd), we must first collect all the images. +This can be done using something like the `Image.query_images()` function. +It essentially wraps around a bunch of SQL filters to get at the required images. + +```python +from models.base import SmartSession +from models.image import Image +stmt = Image.query_images(target='field_034', section_id='N1', filter='R', min_mjd=59000, max_mjd=59007, ...) +with SmartSession() as session: + images = session.scalars(stmt).all() +``` + +Notice that the `image_query()` does not return `Image` objects, but a SQLAlchemy statement. +You can append additional filters to the statement before executing it. +Use the images you've gotten (in any way you like) to make a coadd image. +Make sure each image has its extraction products loaded before starting the coaddition. + + +```python +from pipeline.coaddition import CoaddPipeline + +pipe = CoaddPipeline(...) # use kwargs to override the config parameters +coadd_image = pipe.run(images) # the output image would also have extraction products generated by the pipeline +``` + +Another option to use the coaddition pipeline is to provide a set of named parameters, +specifying the place in the sky, filter, and optional time range. +Note that if the time range is not given, the end time is assumed to be the current time, +and the start time is assumed to be 7 days before the end time (this is a configurable parameter +of the coadd pipeline). +```python +coadd_image = pipe.run(ra=123.4, dec=56.7, filter='g', min_mjd=59000, max_mjd=59007, ...) +``` + +Currently, the coadd pipeline only runs extraction on the coadd image +(i.e., makes source lists, PSF, background, wcs and zp). +However, it may be possible to add the other steps (detection, measuring, etc). +Note that in such a case two things need to happen: +1. The coadd pipeline will have to be given a reference set name, in the subtraction parameters (see below). +2. The coadd pipeline needs to be able to run, and save the coadd image + extraction products, so it can still + be used internally by the reference making pipeline. + +### Reference Pipeline + +The reference making pipeline builds on top of the coadd pipeline but also produces a `Reference` object, +and will place the `Provenance` of the new reference in a `RefSet` object. + +There are two sets of parameters (and code versions!) that determine the uniqueness of a reference set: +1. The data production parameters and code versions that go into producing the individual images, their products, + the coadd image made from them, and the extraction products of the coadd image. Together, these form the + `upstream_hash` of the `RefSet` (this is done by hashing together the coadd image and its products' provenance hashes). + Because those provenances include the individual images' provenances, any change you make along the data production + pipeline (including individual image processing by the regular pipeline or coaddition and extraction by the + coaddition pipeline) will change the `upstream_hash` value. Each `RefSet` has a unique `upstream_hash` value. + This is because those provenances (of the coadd) will go into the upstreams of the subtraction image. + To make sure the output provenance of the subtraction is completely defined by the choice of parameters + (in this case, the choice of the ref-set name) then the `RefSet` must be one-to-one with the `upstream_hash`. +2. The parameters (and code) going into the production of the `Reference` object itself. This includes mostly the + choice of which images to pick for the coaddition. This defines the provenances of the `Reference` object. + There could be multiple `Reference` provenances on a single `RefSet`, so long as they all have the same + upstream provenances (which will go into the `upstream_hash` and into the subtraction upstreams). + It is possible to have a ref-set use one set of search criteria for images and if that fails, use a second set, + usually with more relaxed limits on which images to pick. + For a given place in the sky, you would try to make a reference using set 1, and if that fails, try set 2. + There should always be exactly one reference for each place in the sky for each ref-set. + Once a reference is successfully made, it is always the one that is loaded (unless it is marked as "bad", + which we need to think about what would that mean if we changed a reference mid-survey). + The goal is to prevent changing of the reference in the middle of a survey, but allow the same ref-set to account + for places where the images have lower quality. + In principle, there could be many provenances (for the reference objects) appended to each `RefSet`, but it is + not recommended to add more than one or two. + +To produce a reference, use the `RefMaker` object: + +```python +from pipeline.ref_maker import RefMaker + +maker = RefMaker(maker={'name': 'new_refset', 'instruments': ['PTF']}, ...) # use kwargs to override config parameters +ref = maker.run(ra=123.4, dec=56.7, filter='g') +``` + +Note that we can specify the location in the sky using target/section ID as well. +The ref-maker will check if a `RefSet` already exists with that name, +and will try to append the referencing `Provenance`, if it is not already there. +You can disable appending new provenances to the ref-set by setting `allow_append=False` on the ref-maker. + +To produce the referencing provenance, the ref-maker will first create/load the provenances for the individual images, +including the preprocessing and extraction steps, and the coaddition and extraction on the coadd steps. +These provenances are used as the upstream of the reference provenance, and form the `upstream_hash`. +If the ref-maker has parameters that are inconsistent with the `upstream_hash` of an existing ref-set with the same +name, it will raise an error. + +Note that even though the ref-maker has the parameters (and even a full `Pipeline` object) it does not load exposures, +and it does not run preprocessing and extraction to produce the individual images. Those must already exist on the DB, +and must also have extraction products on the DB, all with the correct provenances. +On the other hand, the ref-maker can run the coaddition pipeline internally, if it doesn't find the coadd image. +If a coadd image and its products already exist with the correct provenance, they are loaded and re-used to produce +a reference. If a coadd and a reference already exist, they will be loaded directly. +Note that a `Reference` object's provenance contains the parameters used to choose which images will be entered into +the coaddition, but it does not name the `RefSet` it belongs to. Thus, multiple ref-sets can use the same Reference +and the same underlying coadd image. \ No newline at end of file diff --git a/improc/alignment.py b/improc/alignment.py index 5eca7df4..560bb33b 100644 --- a/improc/alignment.py +++ b/improc/alignment.py @@ -111,7 +111,7 @@ def __init__(self, **kwargs): critical=False, ) - self.enforce_no_new_attrs = True + self._enforce_no_new_attrs = True self.override( kwargs ) def get_process_name(self): @@ -577,11 +577,45 @@ def run( self, source_image, target_image ): bg.variance = source_image.bg.variance warped_image.bg = bg - warped_image.psf = source_image.psf - warped_image.zp = source_image.zp - warped_image.wcs = source_image.wcs - # TODO: what about SourceList? - # TODO: should these objects be copies of the products, or references to the same objects? + # make sure to copy these as new objects into the warped image + if source_image.sources is not None: + warped_image.sources = source_image.sources.copy() + if source_image.sources.data is not None: + warped_image.sources.data = source_image.sources.data.copy() + + warped_image.sources.image = warped_image + warped_image.sources.provenance = source_image.sources.provenance + warped_image.sources.filepath = None + warped_image.sources.md5sum = None + + if source_image.psf is not None: + warped_image.psf = source_image.psf.copy() + if source_image.psf.data is not None: + warped_image.psf.data = source_image.psf.data.copy() + if source_image.psf.header is not None: + warped_image.psf.header = source_image.psf.header.copy() + if source_image.psf.info is not None: + warped_image.psf.info = source_image.psf.info + + warped_image.psf.image = warped_image + warped_image.psf.provenance = warped_image.provenance + warped_image.psf.filepath = None + warped_image.psf.md5sum = None + + if warped_image.wcs is not None: + warped_image.wcs = source_image.wcs.copy() + if warped_image.wcs._wcs is not None: + warped_image.wcs._wcs = source_image.wcs._wcs.deepcopy() + + warped_image.wcs.sources = warped_image.sources + warped_image.wcs.provenance = source_image.wcs.provenance + warped_image.wcs.filepath = None + warped_image.wcs.md5sum = None + + warped_image.zp = source_image.zp.copy() + warped_image.zp.sources = warped_image.sources + warped_image.zp.provenance = source_image.zp.provenance + else: # Do the warp if self.pars.method == 'swarp': SCLogger.debug( 'Aligning with swarp' ) diff --git a/improc/inpainting.py b/improc/inpainting.py index 668dc4f4..69599437 100644 --- a/improc/inpainting.py +++ b/improc/inpainting.py @@ -58,7 +58,7 @@ def __init__(self, **kwargs): critical=True ) - self.enforce_no_new_attrs = True + self._enforce_no_new_attrs = True self.override( kwargs ) def get_process_name(self): diff --git a/improc/photometry.py b/improc/photometry.py index 3c9918e3..7f0d8ff9 100644 --- a/improc/photometry.py +++ b/improc/photometry.py @@ -224,11 +224,12 @@ def iterative_cutouts_photometry( bkg_estimate = 0.0 denominator = np.nansum(nandata - bkg_estimate) + # prevent division by zero and other rare cases epsilon = 0.01 if denominator == 0: denominator = epsilon elif abs(denominator) < epsilon: - denominator = epsilon * np.sign(denominator) # prevent division by zero and other rare cases + denominator = epsilon * np.sign(denominator) cx = np.nansum(xgrid * (nandata - bkg_estimate)) / denominator cy = np.nansum(ygrid * (nandata - bkg_estimate)) / denominator diff --git a/improc/zogy.py b/improc/zogy.py index 429d0c21..7bc56c1e 100644 --- a/improc/zogy.py +++ b/improc/zogy.py @@ -110,6 +110,8 @@ def zogy_subtract(image_ref, image_new, psf_ref, psf_new, noise_ref, noise_new, The source-noise-corrected translient score. translient_corr_sigma: numpy.ndarray The corrected translient score, converted to S/N units assuming a chi2 distribution. + zero_point: float + the flux based zero point estimate based on the input zero point and backgrounds """ if dy is None: @@ -273,6 +275,7 @@ def zogy_subtract(image_ref, image_new, psf_ref, psf_new, noise_ref, noise_new, translient_sigma=translient_sigma, translient_corr=translient_corr, translient_corr_sigma=translient_corr_sigma, + zero_point=F_D ) diff --git a/models/background.py b/models/background.py index 59803aad..d3a28463 100644 --- a/models/background.py +++ b/models/background.py @@ -384,13 +384,15 @@ def get_downstreams(self, session=None, siblings=False): from models.provenance import Provenance with SmartSession(session) as session: - subs = session.scalars( - sa.select(Image).where( - Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), - Image.upstream_images.any(Image.id == self.image_id), - ) - ).all() - output = subs + output = [] + if self.image_id is not None and self.provenance is not None: + subs = session.scalars( + sa.select(Image).where( + Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), + Image.upstream_images.any(Image.id == self.image_id), + ) + ).all() + output += subs if siblings: # There should be exactly one source list, wcs, and zp per PSF, with the same provenance diff --git a/models/base.py b/models/base.py index 614d9627..ac20b483 100644 --- a/models/base.py +++ b/models/base.py @@ -520,6 +520,15 @@ def to_json(self, filename): except: raise + def copy(self): + """Make a new instance of this object, with all column-based attributed (shallow) copied. """ + new = self.__class__() + for key in sa.inspect(self).mapper.columns.keys(): + value = getattr(self, key) + setattr(new, key, value) + + return new + Base = declarative_base(cls=SeeChangeBase) @@ -1614,6 +1623,54 @@ def find_containing( cls, siobj, session=None ): sess.execute( sa.text( "DROP TABLE temp_find_containing" ) ) return objs + @classmethod + def get_overlap_frac(cls, obj1, obj2): + """Calculate the overlap fraction between two objects that have four corners. + + Returns + ------- + overlap_frac: float + The fraction of obj1's area that is covered by the intersection of the objects + + WARNING: Right now this assumes that the images are aligned N/S and E/W. + TODO: areas of general quadrilaterals and intersections of general quadrilaterals. + + For the "image area", it uses + max(image E ra) - min(image W ra) ) * ( max(image N dec) - min( imageS dec) + (where "image E ra" refers to the corners of the image that are + on the eastern side, i.e. ra_corner_10 and ra_corner_11). This + will in general overestimate the image area, though the + overestimate will be small if the image is very close to + oriented square to the sky. + + For the "overlap area", it uses + ( min( image E ra, ref E ra ) - max( image W ra, ref W ra ) * + min( image N dec, ref N dec ) - max( image S dec, ref S dec ) ) + This will in general underestimate the overlap area, though the + underestimate will be small if both the image and reference + are oriented close to square to the sky. + + (RA ranges in all cases are scaled by cos(dec).) + + """ + dimra = (((obj1.ra_corner_10 + obj1.ra_corner_11) / 2. - + (obj1.ra_corner_00 + obj1.ra_corner_01) / 2. + ) / np.cos(obj1.dec * np.pi / 180.)) + dimdec = ((obj1.dec_corner_01 + obj1.dec_corner_11) / 2. - + (obj1.dec_corner_00 + obj1.dec_corner_10) / 2.) + r0 = max(obj2.ra_corner_00, obj2.ra_corner_01, + obj1.ra_corner_00, obj1.ra_corner_01) + r1 = min(obj2.ra_corner_10, obj2.ra_corner_10, + obj1.ra_corner_10, obj1.ra_corner_10) + d0 = max(obj2.dec_corner_00, obj2.dec_corner_10, + obj1.dec_corner_00, obj1.dec_corner_10) + d1 = min(obj2.dec_corner_01, obj2.dec_corner_11, + obj1.dec_corner_01, obj1.dec_corner_11) + dra = (r1 - r0) / np.cos((d1 + d0) / 2. * np.pi / 180.) + ddec = d1 - d0 + + return (dra * ddec) / (dimra * dimdec) + class HasBitFlagBadness: """A mixin class that adds a bitflag marking why this object is bad. """ diff --git a/models/cutouts.py b/models/cutouts.py index c0f686de..005002cd 100644 --- a/models/cutouts.py +++ b/models/cutouts.py @@ -42,6 +42,7 @@ def __getitem__(self, key): self.cutouts.load_one_co_dict(key) # no change if not found return super().__getitem__(key) + class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, HasBitFlagBadness): __tablename__ = 'cutouts' diff --git a/models/enums_and_bitflags.py b/models/enums_and_bitflags.py index c5d7a76a..ddbd077f 100644 --- a/models/enums_and_bitflags.py +++ b/models/enums_and_bitflags.py @@ -100,6 +100,20 @@ def convert( cls, value ): else: raise ValueError(f'{cls.__name__} must be integer/float key or string value, not {type(value)}') + @classmethod + def to_int(cls, value): + if isinstance(value, int): + return value + else: + return cls.convert(value) + + @classmethod + def to_string(cls, value): + if isinstance(value, str): + return value + else: + return cls.convert(value) + class FormatConverter( EnumConverter ): # This is the master format dictionary, that contains all file types for diff --git a/models/exposure.py b/models/exposure.py index fa4d2fe3..e0990897 100644 --- a/models/exposure.py +++ b/models/exposure.py @@ -33,10 +33,6 @@ from models.enums_and_bitflags import ( ImageFormatConverter, ImageTypeConverter, - image_badness_inverse, - data_badness_dict, - string_to_bitflag, - bitflag_to_string, ) # columns key names that must be loaded from the header for each Exposure @@ -49,7 +45,8 @@ 'exp_time', 'filter', 'telescope', - 'instrument' + 'instrument', + 'airmass', ] # these are header keywords that are not stored as columns of the Exposure table, @@ -256,6 +253,8 @@ def format(self, value): filter = sa.Column(sa.Text, nullable=True, index=True, doc="Name of the filter used to make this exposure. ") + airmass = sa.Column(sa.REAL, nullable=True, index=True, doc="Airmass taken from the header of the exposure. ") + @property def filter_short(self): if self.filter is None: @@ -368,7 +367,7 @@ def __init__(self, current_file=None, invent_filepath=True, **kwargs): if self.filepath is None: # in this case, the instrument must have been given if self.provenance is None: - self.make_provenance() # a default provenance for exposures + self.provenance = self.make_provenance(self.instrument) # a default provenance for exposures if invent_filepath: self.filepath = self.invent_filepath() @@ -380,7 +379,7 @@ def __init__(self, current_file=None, invent_filepath=True, **kwargs): # this can happen if the instrument is not given, but the filepath is if self.provenance is None: - self.make_provenance() # a default provenance for exposures + self.provenance = self.make_provenance(self.instrument) # a default provenance for exposures # instrument_obj is lazy loaded when first getting it if current_file is None: @@ -393,7 +392,8 @@ def __init__(self, current_file=None, invent_filepath=True, **kwargs): self.calculate_coordinates() # galactic and ecliptic coordinates - def make_provenance(self): + @classmethod + def make_provenance(cls, instrument): """Generate a Provenance for this exposure. The provenance will have only one parameter, @@ -405,13 +405,15 @@ def make_provenance(self): the upstream images when e.g., making a coadd). """ codeversion = Provenance.get_code_version() - self.provenance = Provenance( + prov = Provenance( code_version=codeversion, process='load_exposure', - parameters={'instrument': self.instrument}, + parameters={'instrument': instrument}, upstreams=[], ) + return prov + @sa.orm.reconstructor def init_on_load(self): Base.init_on_load(self) diff --git a/models/image.py b/models/image.py index 515c0e29..91ea4db6 100644 --- a/models/image.py +++ b/models/image.py @@ -6,6 +6,7 @@ import sqlalchemy as sa from sqlalchemy import orm + from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.exc import DetachedInstanceError @@ -17,7 +18,9 @@ import astropy.coordinates import astropy.units as u -from util.util import read_fits_image, save_fits_image_file +from util.util import read_fits_image, save_fits_image_file, parse_dateobs, listify +from util.radec import parse_ra_hms_to_deg, parse_dec_dms_to_deg + from models.base import ( Base, @@ -29,6 +32,7 @@ FourCorners, HasBitFlagBadness, ) +from models.provenance import Provenance from models.exposure import Exposure from models.instrument import get_instrument_instance from models.enums_and_bitflags import ( @@ -395,6 +399,13 @@ def preprocessing_done(self, value): doc='Has the sky been subtracted from this image. ' ) + airmass = sa.Column( + sa.REAL, + nullable=True, + index=True, + doc='Airmass of the observation. ' + ) + fwhm_estimate = sa.Column( sa.REAL, nullable=True, @@ -517,7 +528,7 @@ def __setattr__(self, key, value): @orm.reconstructor def init_on_load(self): - Base.init_on_load(self) + SeeChangeBase.init_on_load(self) FileOnDiskMixin.init_on_load(self) self.raw_data = None self._header = None @@ -557,9 +568,9 @@ def merge_all(self, session): if self.sources is not None: self.sources.image = new_image - self.sources.image_id = new_image.id self.sources.provenance_id = self.sources.provenance.id if self.sources.provenance is not None else None - new_image.sources = self.sources.merge_all(session=session) + new_sources = self.sources.merge_all(session=session) + new_image.sources = new_sources if new_image.sources.wcs is not None: new_image.wcs = new_image.sources.wcs @@ -675,7 +686,6 @@ def set_corners_from_header_wcs( self, wcs=None, setradec=False ): self.ecllat = sc.barycentrictrueecliptic.lat.deg self.ecllon = sc.barycentrictrueecliptic.lon.deg - @classmethod def from_exposure(cls, exposure, section_id): """ @@ -708,6 +718,7 @@ def from_exposure(cls, exposure, section_id): 'mjd', 'end_mjd', 'exp_time', + 'airmass', 'instrument', 'telescope', 'filter', @@ -833,6 +844,7 @@ def copy_image(cls, image): 'preproc_bitflag', 'astro_cal_done', 'sky_sub_done', + 'airmass', 'fwhm_estimate', 'zero_point_estimate', 'lim_mag_estimate', @@ -1041,7 +1053,7 @@ def from_new_and_ref(cls, new_image, ref_image): # get some more attributes from the new image for att in ['section_id', 'instrument', 'telescope', 'project', 'target', - 'exp_time', 'mjd', 'end_mjd', 'info', 'header', + 'exp_time', 'airmass', 'mjd', 'end_mjd', 'info', 'header', 'gallon', 'gallat', 'ecllon', 'ecllat', 'ra', 'dec', 'ra_corner_00', 'ra_corner_01', 'ra_corner_10', 'ra_corner_11', 'dec_corner_00', 'dec_corner_01', 'dec_corner_10', 'dec_corner_11' ]: @@ -1169,7 +1181,7 @@ def _get_alignment_target_image(self): def coordinates_to_alignment_target(self): """Make sure the coordinates (RA,dec, corners and WCS) all match the alignment target image. """ target = self._get_alignment_target_image() - for att in ['ra', 'dec', 'wcs', + for att in ['ra', 'dec', 'ra_corner_00', 'ra_corner_01', 'ra_corner_10', 'ra_corner_11', 'dec_corner_00', 'dec_corner_01', 'dec_corner_10', 'dec_corner_11' ]: self.__setattr__(att, getattr(target, att)) @@ -1570,6 +1582,79 @@ def free( self, free_derived_products=True, free_aligned=True, only_free=None ): for alim in self._aligned_images: alim.free( free_derived_products=free_derived_products, only_free=only_free ) + def load_products(self, provenances, session=None, must_find_all=True): + """Load the products associated with this image, using a list of provenances. + + Parameters + ---------- + provenances: single Provenance or list of Provenance objects + A list to go over, that can contain any number of Provenance objects. + Will search the database for matching objects to each provenance in turn, + and will assign them into "self" if found. + Note that it will keep the first successfully loaded product on the provenance list. + Will overwrite any existing products on the Image. + Will ignore provenances that do not match any of the products + (e.g., provenances for a different processing step). + session: SQLAlchemy session, optional + The session to use for the database queries. + If not provided, will open a session internally + and close it when the function exits. + + """ + from models.source_list import SourceList + from models.psf import PSF + from models.background import Background + from models.world_coordinates import WorldCoordinates + from models.zero_point import ZeroPoint + + if self.id is None: + raise ValueError('Cannot load products for an image without an ID!') + + provenances = listify(provenances) + if not provenances: + raise ValueError('Need at least one provenance to load products! ') + + sources = psf = bg = wcs = zp = None + with SmartSession(session) as session: + for p in provenances: + if sources is None: + sources = session.scalars( + sa.select(SourceList).where(SourceList.image_id == self.id, SourceList.provenance_id == p.id) + ).first() + if psf is None: + psf = session.scalars( + sa.select(PSF).where(PSF.image_id == self.id, PSF.provenance_id == p.id) + ).first() + if bg is None: + bg = session.scalars( + sa.select(Background).where(Background.image_id == self.id, Background.provenance_id == p.id) + ).first() + + if sources is not None: + if wcs is None: + wcs = session.scalars( + sa.select(WorldCoordinates).where( + WorldCoordinates.sources_id == sources.id, WorldCoordinates.provenance_id == p.id + ) + ).first() + if zp is None: + zp = session.scalars( + sa.select(ZeroPoint).where( + ZeroPoint.sources_id == sources.id, ZeroPoint.provenance_id == p.id + ) + ).first() + + if sources is not None: + self.sources = sources + if psf is not None: + self.psf = psf + if bg is not None: + self.bg = bg + if wcs is not None: + self.wcs = wcs + if zp is not None: + self.zp = zp + def get_upstream_provenances(self): """Collect the provenances for all upstream objects. @@ -1894,6 +1979,266 @@ def get_downstreams(self, session=None, siblings=False): return downstreams + @staticmethod + def query_images( + ra=None, + dec=None, + target=None, + section_id=None, + project=None, + instrument=None, + filter=None, + min_mjd=None, + max_mjd=None, + min_dateobs=None, + max_dateobs=None, + min_exp_time=None, + max_exp_time=None, + min_seeing=None, + max_seeing=None, + min_lim_mag=None, + max_lim_mag=None, + min_airmass=None, + max_airmass=None, + min_background=None, + max_background=None, + min_zero_point=None, + max_zero_point=None, + order_by='latest', + seeing_quality_factor=3.0, + provenance_ids=None, + type=[1, 2, 3, 4], # TODO: is there a smarter way to only get science images? + ): + """Get a SQL alchemy statement object for Image objects, with some filters applied. + + This is a convenience method to get a statement object that can be further filtered. + If no parameters are given, will happily return all images (be careful with this). + It is highly recommended to supply ra/dec to find all images overlapping with that point. + + The images are sorted either by MJD or by image quality. + Quality is defined as sum of the limiting magnitude and the seeing, + multiplied by the negative "seeing_quality_factor" parameter: + = - * + This means that as the seeing FWHM is smaller, and the limiting magnitude + is bigger (fainter) the quality is higher. + Choose a higher seeing_quality_factor to give more weight to the seeing, + and less weight to the limiting magnitude. + + Parameters + ---------- + ra: float or str (optional) + The right ascension of the target in degrees or in HMS format. + Will find all images that contain this position. + If given, must also give dec. + dec: float or str (optional) + The declination of the target in degrees or in DMS format. + Will find all images that contain this position. + If given, must also give ra. + target: str or list of strings (optional) + Find images that have this target name (e.g., field ID or Object name). + If given as a list, will match all the target names in the list. + section_id: int/str or list of ints/strings (optional) + Find images with this section ID. + If given as a list, will match all the section IDs in the list. + project: str or list of strings (optional) + Find images from this project. + If given as a list, will match all the projects in the list. + instrument: str or list of str (optional) + Find images taken using this instrument. + Provide a list to match multiple instruments. + filter: str or list of str (optional) + Find images taken using this filter. + Provide a list to match multiple filters. + min_mjd: float (optional) + Find images taken after this MJD. + max_mjd: float (optional) + Find images taken before this MJD. + min_dateobs: str (optional) + Find images taken after this date (use ISOT format or a datetime object). + max_dateobs: str (optional) + Find images taken before this date (use ISOT format or a datetime object). + min_exp_time: float (optional) + Find images with exposure time longer than this (in seconds). + max_exp_time: float (optional) + Find images with exposure time shorter than this (in seconds). + min_seeing: float (optional) + Find images with seeing FWHM larger than this (in arcsec). + max_seeing: float (optional) + Find images with seeing FWHM smaller than this (in arcsec). + min_lim_mag: float (optional) + Find images with limiting magnitude larger (fainter) than this. + max_lim_mag: float (optional) + Find images with limiting magnitude smaller (brighter) than this. + min_airmass: float (optional) + Find images with airmass larger than this. + max_airmass: float (optional) + Find images with airmass smaller than this. + min_background: float (optional) + Find images with background rms higher than this. + max_background: float (optional) + Find images with background rms lower than this. + min_zero_point: float (optional) + Find images with zero point higher than this. + max_zero_point: float (optional) + Find images with zero point lower than this. + order_by: str, default 'latest' + Sort the images by 'earliest', 'latest' or 'quality'. + The 'earliest' and 'latest' order by MJD, in ascending/descending order, respectively. + The 'quality' option will try to order the images by quality, as defined above, + with the highest quality images first. + seeing_quality_factor: float, default 3.0 + The factor to multiply the seeing FWHM by in the quality calculation. + provenance_ids: str or list of strings + Find images with these provenance IDs. + type: integer or string or list of integers or strings, default [1,2,3,4] + List of integer converted types of images to search for. + This defaults to [1,2,3,4] which corresponds to the + science images, coadds and subtractions + (see enums_and_bitflags.ImageTypeConverter for more details). + Choose 1 to get only the regular (non-coadd, non-subtraction) images. + + Returns + ------- + stmt: SQL alchemy select statement + The statement to be executed to get the images. + Do session.scalars(stmt).all() to get the images. + Additional filtering can be done on the statement before executing it. + """ + stmt = sa.select(Image) + + # filter by coordinates being contained in the image + if ra is not None and dec is not None: + if isinstance(ra, str): + ra = parse_ra_hms_to_deg(ra) + if isinstance(dec, str): + dec = parse_dec_dms_to_deg(dec) + stmt = stmt.where(Image.containing(ra, dec)) + elif ra is not None or dec is not None: + raise ValueError("Both ra and dec must be provided to search by position.") + + # filter by target (e.g., field ID, object name) and possibly section ID and/or project + targets = listify(target) + if targets is not None: + stmt = stmt.where(Image.target.in_(targets)) + section_ids = listify(section_id) + if section_ids is not None: + stmt = stmt.where(Image.section_id.in_(section_ids)) + projects = listify(project) + if projects is not None: + stmt = stmt.where(Image.project.in_(projects)) + + # filter by filter and instrument + filters = listify(filter) + if filters is not None: + stmt = stmt.where(Image.filter.in_(filters)) + instruments = listify(instrument) + if instruments is not None: + stmt = stmt.where(Image.instrument.in_(instruments)) + + # filter by MJD or dateobs + if min_mjd is not None: + if min_dateobs is not None: + raise ValueError("Cannot filter by both minimal MJD and dateobs.") + stmt = stmt.where(Image.mjd >= min_mjd) + if max_mjd is not None: + if max_dateobs is not None: + raise ValueError("Cannot filter by both maximal MJD and dateobs.") + stmt = stmt.where(Image.mjd <= max_mjd) + if min_dateobs is not None: + min_dateobs = parse_dateobs(min_dateobs, output='mjd') + stmt = stmt.where(Image.mjd >= min_dateobs) + if max_dateobs is not None: + max_dateobs = parse_dateobs(max_dateobs, output='mjd') + stmt = stmt.where(Image.mjd <= max_dateobs) + + # filter by exposure time + if min_exp_time is not None: + stmt = stmt.where(Image.exp_time >= min_exp_time) + if max_exp_time is not None: + stmt = stmt.where(Image.exp_time <= max_exp_time) + + # filter by seeing FWHM + if min_seeing is not None: + stmt = stmt.where(Image.fwhm_estimate >= min_seeing) + if max_seeing is not None: + stmt = stmt.where(Image.fwhm_estimate <= max_seeing) + + # filter by limiting magnitude + if max_lim_mag is not None: + stmt = stmt.where(Image.lim_mag_estimate <= max_lim_mag) + if min_lim_mag is not None: + stmt = stmt.where(Image.lim_mag_estimate >= min_lim_mag) + + # filter by airmass + if max_airmass is not None: + stmt = stmt.where(Image.airmass <= max_airmass) + if min_airmass is not None: + stmt = stmt.where(Image.airmass >= min_airmass) + + # filter by background + if max_background is not None: + stmt = stmt.where(Image.bkg_rms_estimate <= max_background) + if min_background is not None: + stmt = stmt.where(Image.bkg_rms_estimate >= min_background) + + # filter by zero point + if max_zero_point is not None: + stmt = stmt.where(Image.zero_point_estimate <= max_zero_point) + if min_zero_point is not None: + stmt = stmt.where(Image.zero_point_estimate >= min_zero_point) + + # filter by provenances + provenance_ids = listify(provenance_ids) + if provenance_ids is not None: + stmt = stmt.where(Image.provenance_id.in_(provenance_ids)) + + # filter by image types + types = listify(type) + if types is not None: + int_types = [ImageTypeConverter.to_int(t) for t in types] + stmt = stmt.where(Image._type.in_(int_types)) + + # sort the images + if order_by == 'earliest': + stmt = stmt.order_by(Image.mjd) + elif order_by == 'latest': + stmt = stmt.order_by(sa.desc(Image.mjd)) + elif order_by == 'quality': + stmt = stmt.order_by( + sa.desc(Image.lim_mag_estimate - abs(seeing_quality_factor) * Image.fwhm_estimate) + ) + else: + raise ValueError(f'Unknown order_by parameter: {order_by}. Use "earliest", "latest" or "quality".') + + return stmt + + @staticmethod + def get_image_from_upstreams(images, prov_id=None, session=None): + """Finds the combined image that was made from exactly the list of images (with a given provenance). """ + with SmartSession(session) as session: + association = image_upstreams_association_table + + stmt = sa.select(Image).join( + association, Image.id == association.c.downstream_id + ).group_by(Image.id).having( + sa.func.count(association.c.upstream_id) == len(images) + ) + + if prov_id is not None: # pick only those with the right provenance id + if isinstance(prov_id, Provenance): + prov_id = prov_id.id + stmt = stmt.where(Image.provenance_id == prov_id) + + output = session.scalars(stmt).all() + if len(output) > 1: + raise ValueError( + f"More than one combined image found with provenance ID {prov_id} and upstreams {images}." + ) + elif len(output) == 0: + return None + + return output[0] # should usually return one Image or None + def get_psf(self): """Load the PSF object for this image. @@ -1905,6 +2250,17 @@ def get_psf(self): return self.new_image.psf return None + def get_wcs(self): + """Load the WCS object for this image. + + If it is a sub image, it will load the WCS from the new image. + """ + if self.wcs is not None: + return self.wcs + if self.new_image is not None: + return self.new_image.wcs + return None + @property def data(self): """The underlying pixel data array (2D float array). """ diff --git a/models/instrument.py b/models/instrument.py index a8afb278..d461e4cb 100644 --- a/models/instrument.py +++ b/models/instrument.py @@ -1054,6 +1054,7 @@ def _get_header_keyword_translations(cls): instrument=['INSTRUME', 'INSTRUMENT'], telescope=['TELESCOP', 'TELESCOPE'], gain=['GAIN'], + airmass=['AIRMASS'], ) return t # TODO: add more! diff --git a/models/measurements.py b/models/measurements.py index c7127d71..80f534f6 100644 --- a/models/measurements.py +++ b/models/measurements.py @@ -562,9 +562,14 @@ def get_flux_at_point(self, ra, dec, aperture=None): mask = np.zeros_like(im, dtype=float) mask[start_y:end_y, start_x:end_x] = psf_clip[start_y + dy:end_y + dy, start_x + dx:end_x + dx] mask[np.isnan(im)] = 0 # exclude bad pixels from the mask - flux = np.nansum(im * mask) / np.nansum(mask ** 2) - fluxerr = self.bkg_std / np.sqrt(np.nansum(mask ** 2)) - area = np.nansum(mask) / (np.nansum(mask ** 2)) + + mask_sum = np.nansum(mask ** 2) + if mask_sum > 0: + flux = np.nansum(im * mask) / np.nansum(mask ** 2) + fluxerr = self.bkg_std / np.sqrt(np.nansum(mask ** 2)) + area = np.nansum(mask) / (np.nansum(mask ** 2)) + else: + flux = fluxerr = area = np.nan else: radius = self.aper_radii[aperture] # get the aperture mask diff --git a/models/provenance.py b/models/provenance.py index 75ce1d8f..c4268333 100644 --- a/models/provenance.py +++ b/models/provenance.py @@ -297,8 +297,7 @@ def __setattr__(self, key, value): super().__setattr__(key, value) def update_id(self): - """ - Update the id using the code_version, parameters and upstream_hashes. + """Update the id using the code_version, process, parameters and upstream_hashes. """ if self.process is None or self.parameters is None or self.code_version is None: raise ValueError('Provenance must have process, code_version, and parameters defined. ') @@ -313,6 +312,13 @@ def update_id(self): self.id = base64.b32encode(hashlib.sha256(json_string.encode("utf-8")).digest()).decode()[:20] + def get_combined_upstream_hash(self): + """Make a single hash from the hashes of the upstreams. + This is useful for identifying RefSets. + """ + json_string = json.dumps(self.upstream_hashes, sort_keys=True) + return base64.b32encode(hashlib.sha256(json_string.encode("utf-8")).digest()).decode()[:20] + @classmethod def get_code_version(cls, session=None): """ diff --git a/models/psf.py b/models/psf.py index 36dde2fe..98a6ee3a 100644 --- a/models/psf.py +++ b/models/psf.py @@ -541,13 +541,15 @@ def get_downstreams(self, session=None, siblings=False): from models.provenance import Provenance with SmartSession(session) as session: - subs = session.scalars( - sa.select(Image).where( - Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), - Image.upstream_images.any(Image.id == self.image_id), - ) - ).all() - output = subs + output = [] + if self.image_id is not None and self.provenance is not None: + subs = session.scalars( + sa.select(Image).where( + Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), + Image.upstream_images.any(Image.id == self.image_id), + ) + ).all() + output += subs if siblings: # There should be exactly one source list, wcs, and zp per PSF, with the same provenance diff --git a/models/reference.py b/models/reference.py index e35bf6ee..6c5f7ed2 100644 --- a/models/reference.py +++ b/models/reference.py @@ -1,20 +1,23 @@ + import sqlalchemy as sa -from sqlalchemy import orm, func +from sqlalchemy import orm from models.base import Base, AutoIDMixin, SmartSession -from models.image import Image from models.provenance import Provenance +from models.image import Image from models.source_list import SourceList from models.psf import PSF from models.background import Background from models.world_coordinates import WorldCoordinates from models.zero_point import ZeroPoint +from util.util import listify + class Reference(Base, AutoIDMixin): """ A table that refers to each reference Image object, - based on the validity time range, and the object/field it is targeting. + based on the object/field it is targeting. The provenance of this table (tagged with the "reference" process) will have as its upstream IDs the provenance IDs of the image, the source list, the PSF, the WCS, and the zero point. @@ -74,24 +77,8 @@ class Reference(Base, AutoIDMixin): doc="Section ID of the reference image. " ) - # this allows choosing a different reference for images taken before/after the validity time range - validity_start = sa.Column( - sa.DateTime, - nullable=True, - index=True, - doc="The start of the validity time range of this reference image. " - ) - - validity_end = sa.Column( - sa.DateTime, - nullable=True, - index=True, - doc="The end of the validity time range of this reference image. " - ) - # this badness is in addition to the regular bitflag of the underlying products # it can be used to manually kill a reference and replace it with another one - # even if they share the same time validity range is_bad = sa.Column( sa.Boolean, nullable=False, @@ -170,8 +157,10 @@ def init_on_load(self): if this_object_session is not None: # if just loaded, should usually have a session! self.load_upstream_products(this_object_session) - def make_provenance(self): + def make_provenance(self, parameters=None): """Make a provenance for this reference image. """ + if parameters is None: + parameters = {} upstreams = [self.image.provenance] for att in ['image', 'sources', 'psf', 'bg', 'wcs', 'zp']: if getattr(self, att) is not None: @@ -181,8 +170,8 @@ def make_provenance(self): self.provenance = Provenance( code_version=self.image.provenance.code_version, - process='reference', - parameters={}, # do we need any parameters for a reference's provenance? + process='referencing', + parameters=parameters, upstreams=upstreams, ) @@ -301,3 +290,89 @@ def merge_all(self, session): new_ref.image = self.image.merge_all(session) return new_ref + + @classmethod + def get_references( + cls, + ra=None, + dec=None, + target=None, + section_id=None, + filter=None, + skip_bad=True, + provenance_ids=None, + session=None + ): + """Find all references in the specified part of the sky, with the given filter. + Can also match specific provenances and will (by default) not return bad references. + + Parameters + ---------- + ra: float or string, optional + Right ascension in degrees, or a hexagesimal string (in hours!). + If given, must also give the declination. + dec: float or string, optional + Declination in degrees, or a hexagesimal string (in degrees). + If given, must also give the right ascension. + target: string, optional + Name of the target object or field id. + If given, must also provide the section_id. + TODO: can we relax this requirement? Issue #320 + section_id: string, optional + Section ID of the reference image. + If given, must also provide the target. + filter: string, optional + Filter of the reference image. + If not given, will return references with any filter. + provenance_ids: list of strings or Provenance objects, optional + List of provenance IDs to match. + The references must have a provenance with one of these IDs. + If not given, will load all matching references with any provenance. + skip_bad: bool + Whether to skip bad references. Default is True. + session: Session, optional + The database session to use. + If not given, will open a session and close it at end of function. + + """ + if target is not None and section_id is not None: + if ra is not None or dec is not None: + raise ValueError('Cannot provide target/section_id and also ra/dec! ') + stmt = sa.select(cls).where( + cls.target == target, + cls.section_id == str(section_id), + ) + elif target is not None or section_id is not None: + raise ValueError("Must provide both target and section_id, or neither.") + + if ra is not None and dec is not None: + stmt = sa.select(cls).where( + cls.image.has(Image.containing(ra, dec)) + ) + elif ra is not None or dec is not None: + raise ValueError("Must provide both ra and dec, or neither.") + + if ra is None and target is None: # the above also implies the dec and section_id are also missing + raise ValueError("Must provide either ra and dec, or target and section_id.") + + if filter is not None: + stmt = stmt.where(cls.filter == filter) + + if skip_bad: + stmt = stmt.where(cls.is_bad.is_(False)) + + provenance_ids = listify(provenance_ids) + + if provenance_ids is not None: + for i, prov in enumerate(provenance_ids): + if isinstance(prov, Provenance): + provenance_ids[i] = prov.id + elif not isinstance(prov, str): + raise ValueError(f"Provenance ID must be a string or a Provenance object, not {type(prov)}.") + + stmt = stmt.where(cls.provenance_id.in_(provenance_ids)) + + with SmartSession(session) as session: + return session.scalars(stmt).all() + + diff --git a/models/refset.py b/models/refset.py new file mode 100644 index 00000000..1c876712 --- /dev/null +++ b/models/refset.py @@ -0,0 +1,69 @@ +import sqlalchemy as sa +from sqlalchemy import orm + +from models.base import Base, SeeChangeBase, AutoIDMixin, SmartSession +from models.provenance import Provenance + + +# provenance to refset association table: +refset_provenance_association_table = sa.Table( + 'refset_provenance_association', + Base.metadata, + sa.Column('provenance_id', + sa.Text, + sa.ForeignKey( + 'provenances.id', ondelete="CASCADE", name='refset_provenances_association_provenance_id_fkey' + ), + primary_key=True), + sa.Column('refset_id', + sa.Integer, + sa.ForeignKey('refsets.id', ondelete="CASCADE", name='refsets_provenances_association_refset_id_fkey'), + primary_key=True), +) + + +class RefSet(Base, AutoIDMixin): + __tablename__ = 'refsets' + + name = sa.Column( + sa.Text, + nullable=False, + index=True, + unique=True, + doc="Name of the reference set. " + ) + + description = sa.Column( + sa.Text, + nullable=True, + doc="Description of the reference set. " + ) + + upstream_hash = sa.Column( + sa.Text, + nullable=False, + index=True, + doc="Hash of the upstreams used to make the reference provenance. " + ) + + provenances = orm.relationship( + Provenance, + secondary=refset_provenance_association_table, + backref='refsets', # add refsets attribute to Provenance + order_by=Provenance.created_at, + cascade='all' + ) + + def __init__(self, **kwargs): + SeeChangeBase.__init__(self) # don't pass kwargs as they could contain non-column key-values + + # manually set all properties (columns or not) + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + + @orm.reconstructor + def init_on_load(self): + SeeChangeBase.init_on_load(self) + + diff --git a/models/report.py b/models/report.py index ead59d29..a3c7cc49 100644 --- a/models/report.py +++ b/models/report.py @@ -1,3 +1,4 @@ +import time import sqlalchemy as sa from sqlalchemy import orm @@ -313,7 +314,15 @@ def scan_datastore(self, ds, process_step=None, session=None): The session to use for committing the changes to the database. If not given, will open a session and close it at the end of the function. + + NOTE: it may be better not to provide the external session + to this function. That way it will only commit this report, + and not also save other objects that were pending on the session. """ + t0 = time.perf_counter() + if 'reporting' not in self.process_runtime: + self.process_runtime['reporting'] = 0.0 + # parse the error, if it exists, so we can get to other data products without raising exception = ds.read_exception() @@ -325,8 +334,8 @@ def scan_datastore(self, ds, process_step=None, session=None): self.products_committed = ds.products_committed # store the runtime and memory usage statistics - self.process_runtime = ds.runtimes # update with new dictionary - self.process_memory = ds.memory_usages # update with new dictionary + self.process_runtime.update(ds.runtimes) # update with new dictionary + self.process_memory.update(ds.memory_usages) # update with new dictionary if process_step is not None: # append the newest step to the progress bitflag @@ -349,6 +358,8 @@ def scan_datastore(self, ds, process_step=None, session=None): with SmartSession(session) as session: new_report = self.commit_to_database(session=session) + self.process_runtime['reporting'] += time.perf_counter() - t0 + if exception is not None: raise exception diff --git a/models/source_list.py b/models/source_list.py index 26819f47..45146503 100644 --- a/models/source_list.py +++ b/models/source_list.py @@ -762,13 +762,15 @@ def get_downstreams(self, session=None, siblings=False): from models.provenance import Provenance with SmartSession(session) as session: - subs = session.scalars( - sa.select(Image).where( - Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), - Image.upstream_images.any(Image.id == self.image_id), - ) - ).all() - output = subs + output = [] + if self.image_id is not None and self.provenance is not None: + subs = session.scalars( + sa.select(Image).where( + Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), + Image.upstream_images.any(Image.id == self.image_id), + ) + ).all() + output += subs if self.is_sub: cutouts = session.scalars(sa.select(Cutouts).where(Cutouts.sources_id == self.id)).all() diff --git a/models/world_coordinates.py b/models/world_coordinates.py index 7720a41e..56406398 100644 --- a/models/world_coordinates.py +++ b/models/world_coordinates.py @@ -117,13 +117,15 @@ def get_downstreams(self, session=None, siblings=False): from models.provenance import Provenance with (SmartSession(session) as session): - subs = session.scalars( - sa.select(Image).where( - Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), - Image.upstream_images.any(Image.id == self.sources.image_id), - ) - ).all() - output = subs + output = [] + if self.provenance is not None: + subs = session.scalars( + sa.select(Image).where( + Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)), + Image.upstream_images.any(Image.id == self.sources.image_id), + ) + ).all() + output += subs if siblings: sources = session.scalars(sa.select(SourceList).where(SourceList.id == self.sources_id)).all() diff --git a/models/zero_point.py b/models/zero_point.py index 0e8bdbcc..fc8639a1 100644 --- a/models/zero_point.py +++ b/models/zero_point.py @@ -156,12 +156,14 @@ def get_downstreams(self, session=None, siblings=False): from models.provenance import Provenance with SmartSession(session) as session: - subs = session.scalars( - sa.select(Image).where( - Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)) - ) - ).all() - output = subs + output = [] + if self.provenance is not None: + subs = session.scalars( + sa.select(Image).where( + Image.provenance.has(Provenance.upstreams.any(Provenance.id == self.provenance.id)) + ) + ).all() + output += subs if siblings: sources = session.scalars(sa.select(SourceList).where(SourceList.id == self.sources_id)).all() diff --git a/pipeline/astro_cal.py b/pipeline/astro_cal.py index 0f74a1b4..47cf106e 100644 --- a/pipeline/astro_cal.py +++ b/pipeline/astro_cal.py @@ -1,4 +1,3 @@ -import os import time import pathlib @@ -6,7 +5,7 @@ from util.exceptions import CatalogNotFoundError, SubprocessFailure, BadMatchException from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool from models.catalog_excerpt import CatalogExcerpt from models.world_coordinates import WorldCoordinates @@ -123,7 +122,7 @@ def __init__(self, **kwargs): 300, int, 'Timeout in seconds for scamp to run', - critical=True + critical=False ) self._enforce_no_new_attrs = True @@ -133,6 +132,9 @@ def __init__(self, **kwargs): def get_process_name(self): return 'astro_cal' + def require_siblings(self): + return True + class AstroCalibrator: def __init__(self, **kwargs): @@ -285,7 +287,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -311,16 +313,6 @@ def run(self, *args, **kwargs): else: raise ValueError( f'Unknown solution method {self.pars.solution_method}' ) - # update the upstream bitflag - sources = ds.get_sources( session=session ) - if sources is None: - raise ValueError( - f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}' - ) - if ds.wcs._upstream_bitflag is None: - ds.wcs._upstream_bitflag = 0 - ds.wcs._upstream_bitflag |= sources.bitflag - # If an astro cal wasn't previously run on this image, # update the image's ra/dec and corners attributes based on this new wcs if not image.astro_cal_done: @@ -328,10 +320,26 @@ def run(self, *args, **kwargs): image.astro_cal_done = True ds.runtimes['astro_cal'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['astro_cal'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB + # update the bitflag with the upstreams + sources = ds.get_sources(session=session) + if sources is None: + raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') + psf = ds.get_psf(session=session) + if psf is None: + raise ValueError(f'Cannot find a PSF corresponding to the datastore inputs: {ds.get_inputs()}') + bg = ds.get_background(session=session) + if bg is None: + raise ValueError(f'Cannot find a background corresponding to the datastore inputs: {ds.get_inputs()}') + + ds.wcs._upstream_bitflag = 0 + ds.wcs._upstream_bitflag |= sources.bitflag # includes badness from Image as well + ds.wcs._upstream_bitflag |= psf.bitflag + ds.wcs._upstream_bitflag |= bg.bitflag + except Exception as e: ds.catch_exception(e) finally: diff --git a/pipeline/backgrounding.py b/pipeline/backgrounding.py index 590dcef3..c8c4c934 100644 --- a/pipeline/backgrounding.py +++ b/pipeline/backgrounding.py @@ -1,4 +1,3 @@ -import os import time import numpy as np @@ -11,7 +10,7 @@ from models.background import Background from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool class ParsBackgrounder(Parameters): @@ -65,6 +64,9 @@ def __init__(self, **kwargs): def get_process_name(self): return 'backgrounding' + def require_siblings(self): + return True + class Backgrounder: def __init__(self, **kwargs): @@ -89,7 +91,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -145,21 +147,23 @@ def run(self, *args, **kwargs): ds.image.bkg_mean_estimate = float( bg.value ) ds.image.bkg_rms_estimate = float( bg.noise ) - bg._upstream_bitflag = 0 - bg._upstream_bitflag |= ds.image.bitflag - sources = ds.get_sources(session=session) - if sources is not None: - bg._upstream_bitflag |= sources.bitflag - + if sources is None: + raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') psf = ds.get_psf(session=session) - if psf is not None: - bg._upstream_bitflag |= psf.bitflag + if psf is None: + raise ValueError(f'Cannot find a PSF corresponding to the datastore inputs: {ds.get_inputs()}') + + bg._upstream_bitflag = 0 + bg._upstream_bitflag |= ds.image.bitflag + bg._upstream_bitflag |= sources.bitflag + bg._upstream_bitflag |= psf.bitflag ds.bg = bg ds.runtimes['backgrounding'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['backgrounding'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB diff --git a/pipeline/coaddition.py b/pipeline/coaddition.py index aec26ed7..acc44bdc 100644 --- a/pipeline/coaddition.py +++ b/pipeline/coaddition.py @@ -1,8 +1,6 @@ import numpy as np from numpy.fft import fft2, ifft2, fftshift -import sqlalchemy as sa - from astropy.time import Time from sep import Background @@ -18,7 +16,6 @@ from pipeline.astro_cal import AstroCalibrator from pipeline.photo_cal import PhotCalibrator from util.util import get_latest_provenance, parse_session -from util.radec import parse_ra_hms_to_deg, parse_dec_dms_to_deg from improc.bitmask_tools import dilate_bitflag from improc.inpainting import Inpainter @@ -74,7 +71,7 @@ def __init__(self, **kwargs): critical=True, ) - self.enforce_no_new_attrs = True + self._enforce_no_new_attrs = True self.override( kwargs ) def get_process_name(self): @@ -585,14 +582,9 @@ def parse_inputs(self, *args, **kwargs): raise ValueError('All unnamed arguments must be Image objects. ') if self.images is None: # get the images from the DB - # TODO: this feels like it could be a useful tool, maybe need to move it to Image class? Issue 188 # if no images were given, parse the named parameters ra = kwargs.get('ra', None) - if isinstance(ra, str): - ra = parse_ra_hms_to_deg(ra) dec = kwargs.get('dec', None) - if isinstance(dec, str): - dec = parse_dec_dms_to_deg(dec) target = kwargs.get('target', None) if target is None and (ra is None or dec is None): raise ValueError('Must give either target or RA and Dec. ') @@ -604,14 +596,10 @@ def parse_inputs(self, *args, **kwargs): if start_time is None: start_time = end_time - self.pars.date_range - if isinstance(end_time, str): - end_time = Time(end_time).mjd - if isinstance(start_time, str): - start_time = Time(start_time).mjd - instrument = kwargs.get('instrument', None) filter = kwargs.get('filter', None) section_id = str(kwargs.get('section_id', None)) + provenance_ids = kwargs.get('provenance_ids', None) if provenance_ids is None: prov = get_latest_provenance('preprocessing', session=session) @@ -619,19 +607,17 @@ def parse_inputs(self, *args, **kwargs): provenance_ids = listify(provenance_ids) with SmartSession(session) as dbsession: - stmt = sa.select(Image).where( - Image.mjd >= start_time, - Image.mjd <= end_time, - Image.instrument == instrument, - Image.filter == filter, - Image.section_id == section_id, - Image.provenance_id.in_(provenance_ids), - ) - - if target is not None: - stmt = stmt.where(Image.target == target) - else: - stmt = stmt.where(Image.containing( ra, dec )) + stmt = Image.query_images( + ra=ra, + dec=dec, + target=target, + section_id=section_id, + instrument=instrument, + filter=filter, + min_dateobs=start_time, + max_dateobs=end_time, + provenance_ids=provenance_ids + ) self.images = dbsession.scalars(stmt.order_by(Image.mjd.asc())).all() return session @@ -641,11 +627,34 @@ def run(self, *args, **kwargs): if self.images is None or len(self.images) == 0: raise ValueError('No images found matching the given parameters. ') + # use the images and their source lists to get a list of provenances and code versions + coadd_upstreams = set() + code_versions = set() + # assumes each image given to the coaddition pipline has sources loaded + for im in self.images: + coadd_upstreams.add(im.provenance) + coadd_upstreams.add(im.sources.provenance) + code_versions.add(im.provenance.code_version) + code_versions.add(im.sources.provenance.code_version) + + code_versions = list(code_versions) + code_versions.sort(key=lambda x: x.id) + code_version = code_versions[-1] # choose the most recent ID if there are multiple code versions + coadd_upstreams = list(coadd_upstreams) + self.datastore = DataStore() - self.datastore.prov_tree = self.make_provenance_tree(session=session) + self.datastore.prov_tree = self.make_provenance_tree(coadd_upstreams, code_version, session=session) + + # check if this exact coadd image already exists in the DB + with SmartSession(session) as dbsession: + coadd_prov = self.datastore.prov_tree['coaddition'] + coadd_image = Image.get_image_from_upstreams(self.images, coadd_prov, session=dbsession) - # the self.aligned_images is None unless you explicitly pass in the pre-aligned images to save time - self.datastore.image = self.coadder.run(self.images, self.aligned_images) + if coadd_image is not None: + self.datastore.image = coadd_image + else: + # the self.aligned_images is None unless you explicitly pass in the pre-aligned images to save time + self.datastore.image = self.coadder.run(self.images, self.aligned_images) # TODO: add the warnings/exception capturing, runtime/memory tracking (and Report making) as in top_level.py self.datastore = self.extractor.run(self.datastore) @@ -655,27 +664,15 @@ def run(self, *args, **kwargs): return self.datastore.image - def make_provenance_tree(self, session=None): + def make_provenance_tree(self, coadd_upstreams, code_version, session=None): """Make a (short) provenance tree to use when fetching the provenances of upstreams. """ with SmartSession(session) as session: - coadd_upstreams = set() - code_versions = set() - # assumes each image given to the coaddition pipline has sources, psf, background, wcs, zp, all loaded - for im in self.images: - coadd_upstreams.add(im.provenance) - coadd_upstreams.add(im.sources.provenance) - code_versions.add(im.provenance.code_version) - code_versions.add(im.sources.provenance.code_version) - - code_versions = list(code_versions) - code_versions.sort(key=lambda x: x.id) - code_version = code_versions[-1] # choose the most recent ID if there are multiple code versions pars_dict = self.coadder.pars.get_critical_pars() coadd_prov = Provenance( code_version=code_version, process='coaddition', - upstreams=list(coadd_upstreams), + upstreams=coadd_upstreams, parameters=pars_dict, is_testing="test_parameter" in pars_dict, # this is a flag for testing purposes ) @@ -694,3 +691,19 @@ def make_provenance_tree(self, session=None): return {'coaddition': coadd_prov, 'extraction': extract_prov} + def override_parameters(self, **kwargs): + """Override the parameters of this pipeline and its sub objects. """ + from pipeline.top_level import PROCESS_OBJECTS + + for key, value in kwargs.items(): + if key in PROCESS_OBJECTS: + if isinstance(PROCESS_OBJECTS[key], dict): + for sub_key, sub_value in PROCESS_OBJECTS[key].items(): + if sub_key in value: + getattr(self, sub_value).pars.override(value[sub_key]) + elif isinstance(PROCESS_OBJECTS[key], str): + getattr(self, PROCESS_OBJECTS[key]).pars.override(value) + elif key == 'coaddition': + self.coadder.pars.override(value) + else: + self.pars.override({key: value}) diff --git a/pipeline/cutting.py b/pipeline/cutting.py index b98affeb..8dfa04f9 100644 --- a/pipeline/cutting.py +++ b/pipeline/cutting.py @@ -1,4 +1,3 @@ -import os import time from improc.tools import make_cutouts @@ -9,7 +8,7 @@ from pipeline.parameters import Parameters from pipeline.data_store import DataStore -from util.util import parse_session, parse_bool +from util.util import parse_session, env_as_bool class ParsCutter(Parameters): @@ -63,7 +62,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -72,24 +71,21 @@ def run(self, *args, **kwargs): # get the provenance for this step: prov = ds.get_provenance('cutting', self.pars.get_critical_pars(), session=session) - # try to find some measurements in memory or in the database: + detections = ds.get_detections(session=session) + if detections is None: + raise ValueError( + f'Cannot find a detections source list corresponding to the datastore inputs: {ds.get_inputs()}' + ) + + # try to find some cutouts in memory or in the database: cutouts = ds.get_cutouts(prov, session=session) + if cutouts is not None: cutouts.load_all_co_data() if cutouts is None or len(cutouts.co_dict) == 0: - self.has_recalculated = True - # use the latest source list in the data store, - # or load using the provenance given in the - # data store's upstream_provs, or just use - # the most recent provenance for "detection" - detections = ds.get_detections(session=session) - - if detections is None: - raise ValueError( - f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}' - ) + # find detections in order to get the cutouts x = detections.x y = detections.y @@ -116,9 +112,6 @@ def run(self, *args, **kwargs): cutouts = Cutouts.from_detections(detections, provenance=prov) - cutouts._upstream_bitflag = 0 - cutouts._upstream_bitflag |= detections.bitflag - for i, source in enumerate(detections.data): data_dict = {} data_dict["sub_data"] = sub_stamps_data[i] @@ -138,6 +131,9 @@ def run(self, *args, **kwargs): data_dict["new_flags"] = new_stamps_flags[i] cutouts.co_dict[f"source_index_{i}"] = data_dict + # regardless of whether we loaded or calculated the cutouts, we need to update the bitflag + cutouts._upstream_bitflag = 0 + cutouts._upstream_bitflag |= detections.bitflag # add the resulting Cutouts to the data store if cutouts.provenance is None: @@ -152,7 +148,7 @@ def run(self, *args, **kwargs): ds.cutouts = cutouts ds.runtimes['cutting'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['cutting'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB diff --git a/pipeline/data_store.py b/pipeline/data_store.py index 62b5c9b1..4b3043a2 100644 --- a/pipeline/data_store.py +++ b/pipeline/data_store.py @@ -1,12 +1,11 @@ import warnings -import math import datetime import sqlalchemy as sa -from util.util import parse_session +from util.util import parse_session, listify from util.logger import SCLogger -from models.base import SmartSession, FileOnDiskMixin +from models.base import SmartSession, FileOnDiskMixin, FourCorners from models.provenance import CodeVersion, Provenance from models.exposure import Exposure from models.image import Image, image_upstreams_association_table @@ -24,7 +23,7 @@ 'exposure': [], # no upstreams 'preprocessing': ['exposure'], 'extraction': ['preprocessing'], - 'subtraction': ['reference', 'preprocessing', 'extraction'], + 'subtraction': ['referencing', 'preprocessing', 'extraction'], 'detection': ['subtraction'], 'cutting': ['detection'], 'measuring': ['cutting'], @@ -36,8 +35,8 @@ 'exposure': 'exposure', 'preprocessing': 'image', 'coaddition': 'image', - 'extraction': ['sources', 'psf', 'background', 'wcs', 'zp'], - 'reference': 'reference', + 'extraction': ['sources', 'psf', 'bg', 'wcs', 'zp'], + 'referencing': 'reference', 'subtraction': 'sub_image', 'detection': 'detections', 'cutting': 'cutouts', @@ -219,7 +218,7 @@ def parse_args(self, *args, **kwargs): self.image = val if self.image is not None: - for att in ['sources', 'psf', 'wcs', 'zp', 'detections', 'cutouts', 'measurements']: + for att in ['sources', 'psf', 'bg', 'wcs', 'zp', 'detections', 'cutouts', 'measurements']: if getattr(self.image, att, None) is not None: setattr(self, att, getattr(self.image, att)) @@ -255,7 +254,8 @@ def read_exception(self): def reraise(self): """If an exception is logged to the datastore, raise it. Otherwise pass. """ if self.exception is not None: - raise self.exception + e = self.read_exception() + raise e def __init__(self, *args, **kwargs): """ @@ -518,8 +518,8 @@ def get_provenance(self, process, pars_dict, session=None): if len(upstreams) != len(UPSTREAM_STEPS[process]): raise ValueError(f'Could not find all upstream provenances for process {process}.') - for u in upstreams: # check if "reference" is in the list, if so, replace it with its upstreams - if u.process == 'reference': + for u in upstreams: # check if "referencing" is in the list, if so, replace it with its upstreams + if u.process == 'referencing': upstreams.remove(u) for up in u.upstreams: upstreams.append(up) @@ -652,7 +652,8 @@ def get_image(self, provenance=None, session=None): self.image = None if self.exposure is not None and self.image.exposure_id != self.exposure.id: self.image = None - if self.section is not None and str(self.image.section_id) != self.section.identifier: + if ( self.section is not None and self.image is not None and + str(self.image.section_id) != self.section.identifier ): self.image = None if self.image is not None and provenance is not None and self.image.provenance.id != provenance.id: self.image = None @@ -733,7 +734,7 @@ def get_sources(self, provenance=None, session=None): if self.sources is None: with SmartSession(session, self.session) as session: image = self.get_image(session=session) - if image is not None: + if image is not None and provenance is not None: self.sources = session.scalars( sa.select(SourceList).where( SourceList.image_id == image.id, @@ -959,92 +960,36 @@ def get_zp(self, provenance=None, session=None): return self.zp - @classmethod - def _overlap_frac(cls, image, refim): - """Calculate the overlap fraction between image and refim. - - Parameters - ---------- - image: Image - The search image; we want to find a reference image that has - maximum overlap with this image. - - refim: Image - The reference image to check. - - Returns - ------- - overlap_frac: float - The fraction of image's area that is covered by the - intersection of image and refim. - - WARNING: Right now this assumes that the images are aligned N/S and - E/W. TODO: areas of general quadrilaterals and interssections of - general quadrilaterals. - - For the "image area", it uses - max(image E ra) - min(image W ra) ) * ( max(image N dec) - min( imageS dec) - (where "image E ra" refers to the corners of the image that are - on the eastern side, i.e. ra_corner_10 and ra_corner_11). This - will in general overestimate the image area, though the - overestimate will be small if the image is very close to - oriented square to the sky. - - For the "overlap area", it uses - ( min( image E ra, ref E ra ) - max( image W ra, ref W ra ) * - min( image N dec, ref N dec ) - max( image S dec, ref S dec ) ) - This will in general underestimate the overlap area, though the - underestimate will be small if both the image and reference - are oriented close to square to the sky. - - (RA ranges in all cases are scaled by cos(dec).) - - """ - - dimra = (((image.ra_corner_10 + image.ra_corner_11) / 2. - - (image.ra_corner_00 + image.ra_corner_01) / 2. - ) / math.cos(image.dec * math.pi / 180.)) - dimdec = ((image.dec_corner_01 + image.dec_corner_11) / 2. - - (image.dec_corner_00 + image.dec_corner_10) / 2.) - r0 = max(refim.ra_corner_00, refim.ra_corner_01, - image.ra_corner_00, image.ra_corner_01) - r1 = min(refim.ra_corner_10, refim.ra_corner_10, - image.ra_corner_10, image.ra_corner_10) - d0 = max(refim.dec_corner_00, refim.dec_corner_10, - image.dec_corner_00, image.dec_corner_10) - d1 = min(refim.dec_corner_01, refim.dec_corner_11, - image.dec_corner_01, image.dec_corner_11) - dra = (r1 - r0) / math.cos((d1 + d0) / 2. * math.pi / 180.) - ddec = d1 - d0 - - return (dra * ddec) / (dimra * dimdec) - - def get_reference(self, minovfrac=0.85, must_match_instrument=True, must_match_filter=True, - must_match_target=False, must_match_section=False, session=None ): + def get_reference(self, provenances=None, min_overlap=0.85, match_filter=True, + ignore_target_and_section=False, skip_bad=True, session=None ): """Get the reference for this image. Parameters ---------- - minovfrac: float, default 0.85 + provenances: list of provenance objects + A list of provenances to use to identify a reference. + Will check for existing references for each one of these provenances, + and will apply any additional criteria to each resulting reference, in turn, + until the first one qualifies and is the one returned + (i.e, it is possible to take the reference matching the first provenance + and never load the others). + If not given, will try to get the provenances from the prov_tree attribute. + If those are not given, or if no qualifying reference is found, will return None. + min_overlap: float, default 0.85 Area of overlap region must be at least this fraction of the area of the search image for the reference to be good. (Warning: calculation implicitly assumes that images are aligned N/S and E/W.) Make this <= 0 to not consider overlap fraction when finding a reference. - must_match_instrument: bool, default True - If True, only find a reference from the same instrument - as that of the DataStore's image. - must_match_filter: bool, default True + match_filter: bool, default True If True, only find a reference whose filter matches the DataStore's images' filter. - must_match_target: bool, default False - If True, only find a reference if the "target" field of the - reference image matches the "target" field of the image in - the DataStore. - must_match_section: bool, default False - If True, only find a reference if the "section_id" field of - the reference image matches that of the image in the - Datastore. + ignore_target_and_section: bool, default False + If False, will try to match based on the datastore image's target and + section_id parameters (if they are not None) and only use RA/dec to match + if they are missing. If True, will only use RA/dec to match. + skip_bad: bool, default True + If True, will skip references that are marked as bad. session: sqlalchemy.orm.session.Session or SmartSession An optional session to use for the database query. If not given, will use the session stored inside the @@ -1056,78 +1001,83 @@ def get_reference(self, minovfrac=0.85, must_match_instrument=True, must_match_f ref: Image object The reference image for this image, or None if no reference is found. - It will only return references whose validity date range - includes DataStore.image.observation_time. - - If minovfrac is given, it will return the reference that has the - highest ovfrac. (If, by unlikely chance, more than one have + If min_overlap is given, it will return the reference that has the + highest overlap fraction. (If, by unlikely chance, more than one have identical overlap fractions, an undeterministically chosen reference will be returned. Ideally, by construction, you will never have this situation in your database; you will only have a single valid reference image for a given instrument/filter/date that has an appreciable overlap with any possible image from that instrument. The software does not enforce this, however.) + """ + image = self.get_image(session=session) + if image is None: + return None # cannot find a reference without a new image to match - If minovfrac is not given, it will return the first reference found - that matches the other criteria. Be careful with this. + if provenances is None: # try to get it from the prov_tree + provenances = self._get_provenance_for_an_upstream('referencing') - """ - with SmartSession(session, self.session) as session: - image = self.get_image(session=session) - - if self.reference is not None: - ovfrac = self._overlap_frac( image, self.reference.image ) if minovfrac > 0. else 0. - if not ( - ( self.reference.validity_start is None or - self.reference.validity_start <= image.observation_time ) and - ( self.reference.validity_end is None or - self.reference.validity_end >= image.observation_time ) and - ( ( not must_match_instrument ) or ( self.reference.image.instrument - == image.instrument ) ) and - ( ( not must_match_filter) or ( self.reference.filter == image.filter ) ) and - ( ( not must_match_target ) or ( self.image.target == self.reference.target ) ) and - ( ( not must_match_section ) or ( self.image.section_id == self.reference.section_id ) ) and - ( self.reference.is_bad is False ) and - ( ( minovfrac <= 0. ) or ( ovfrac > minovfrac ) ) - ): - self.reference = None - - if self.reference is None: - q = ( session.query( Reference, Image ) - .filter( Reference.image_id == Image.id ) - .filter( sa.or_( Reference.validity_start.is_(None), - Reference.validity_start <= image.observation_time ) ) - .filter( sa.or_( Reference.validity_end.is_(None), - Reference.validity_end >= image.observation_time ) ) - .filter( Reference.is_bad.is_(False ) ) - ) - if minovfrac > 0.: - q = ( q - .filter( Image.ra >= min( image.ra_corner_00, image.ra_corner_01 ) ) - .filter( Image.ra <= max( image.ra_corner_10, image.ra_corner_11 ) ) - .filter( Image.dec >= min( image.dec_corner_00, image.dec_corner_10 ) ) - .filter( Image.dec <= max( image.dec_corner_01, image.dec_corner_11 ) ) - ) - if must_match_instrument: q = q.filter( Image.instrument == image.instrument ) - if must_match_filter: q = q.filter( Reference.filter == image.filter ) - if must_match_target: q = q.filter( Reference.target == image.target ) - if must_match_section: q = q.filter( Reference.section_id == image.section_id ) - - ref = None - if minovfrac <= 0.: - ref, refim = q.first() - else: - maxov = minovfrac - for curref, currefim in q.all(): - ovfrac = self._overlap_frac( image, currefim ) - if ovfrac > maxov: - maxov = ovfrac - ref = curref + provenances = listify(provenances) + + if provenances is None: + self.reference = None # cannot get a reference without any associated provenances + + # first, some checks to see if existing reference is ok + if self.reference is not None and provenances is not None: # check for a mismatch of reference to provenances + if self.reference.provenance_id not in [p.id for p in provenances]: + self.reference = None - if ref is None: - raise ValueError(f'No reference image found for image {image.id}') + if self.reference is not None and min_overlap is not None and min_overlap > 0: + ovfrac = FourCorners.get_overlap_frac(image, self.reference.image) + if ovfrac < min_overlap: + self.reference = None - self.reference = curref + if self.reference is not None and skip_bad: + if self.reference.is_bad: + self.reference = None + + if self.reference is not None and match_filter: + if self.reference.filter != image.filter: + self.reference = None + + if ( + self.reference is not None and not ignore_target_and_section and + image.target is not None and image.section_id is not None + ): + if self.reference.target != image.target or self.reference.section_id != image.section_id: + self.reference = None + + # if we have survived this long without losing the reference, can return it here: + if self.reference is not None: + return self.reference + + # No reference was found (or it didn't match other parameters) must find a new one + with SmartSession(session, self.session) as session: + if ignore_target_and_section or image.target is None or image.section_id is None: + arguments = dict(ra=image.ra, dec=image.dec) + else: + arguments = dict(target=image.target, section_id=image.section_id) + + if match_filter: + arguments['filter'] = image.filter + else: + arguments['filter'] = None + + arguments['skip_bad'] = skip_bad + arguments['provenance_ids'] = provenances + references = Reference.get_references(**arguments, session=session) + + self.reference = None + for ref in references: + if min_overlap is not None and min_overlap > 0: + ovfrac = FourCorners.get_overlap_frac(image, ref.image) + # print( + # f'ref.id= {ref.id}, ra_left= {ref.image.ra_corner_00:.2f}, ' + # f'ra_right= {ref.image.ra_corner_11:.2f}, ovfrac= {ovfrac}' + # ) + if ovfrac >= min_overlap: + self.reference = ref + break return self.reference @@ -1527,7 +1477,7 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False, self.image = self.image.merge_all(session) for att in ['sources', 'psf', 'bg', 'wcs', 'zp']: setattr(self, att, None) # avoid automatically appending to the image self's non-merged products - for att in ['exposure', 'sources', 'psf', 'wcs', 'zp']: + for att in ['exposure', 'sources', 'psf', 'bg', 'wcs', 'zp']: if getattr(self.image, att, None) is not None: setattr(self, att, getattr(self.image, att)) diff --git a/pipeline/detection.py b/pipeline/detection.py index 85307c7e..b8dd3c49 100644 --- a/pipeline/detection.py +++ b/pipeline/detection.py @@ -1,4 +1,3 @@ -import os import pathlib import random import subprocess @@ -17,7 +16,7 @@ from util.config import Config from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool from pipeline.parameters import Parameters from pipeline.data_store import DataStore @@ -26,7 +25,6 @@ from models.image import Image from models.source_list import SourceList from models.psf import PSF -from models.background import Background from improc.tools import sigma_clipping @@ -49,46 +47,6 @@ def __init__(self, **kwargs): critical=True ) - self.background_format = self.add_par( - 'background_format', - 'map', - str, - 'Format of the background; one of "map", "scalar", or "polynomial".', - critical=True - ) - - self.background_order = self.add_par( - 'background_order', - 2, - int, - 'Order of the polynomial background. Ignored unless background is "polynomial".', - critical=True - ) - - self.background_method = self.add_par( - 'background_method', - 'sep', - str, - 'Method to use for background subtraction; currently only "sep" is supported.', - critical=True - ) - - self.background_box_size = self.add_par( - 'background_box_size', - 128, - int, - 'Size of the box to use for background estimation in sep.', - critical=True - ) - - self.background_filt_size = self.add_par( - 'background_filt_size', - 3, - int, - 'Size of the filter to use for background estimation in sep.', - critical=True - ) - self.apers = self.add_par( 'apers', [1.0, 2.0, 3.0, 5.0], @@ -164,6 +122,12 @@ def __init__(self, **kwargs): def get_process_name(self): return 'detection' + def require_siblings(self): + if self.pars.subtraction: + return False + else: + return True + class Detector: """Extract sources (and possibly a psf) from images or subtraction images. @@ -262,7 +226,7 @@ def run(self, *args, **kwargs): if self.pars.subtraction: try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -309,7 +273,7 @@ def run(self, *args, **kwargs): ds.detections = detections ds.runtimes['detection'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['detection'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB @@ -322,7 +286,7 @@ def run(self, *args, **kwargs): prov = ds.get_provenance('extraction', self.pars.get_critical_pars(), session=session) try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -368,7 +332,7 @@ def run(self, *args, **kwargs): ds.image.fwhm_estimate = psf.fwhm_pixels # TODO: should we only write if the property is None? ds.runtimes['extraction'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['extraction'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB @@ -987,7 +951,7 @@ def extract_sources_filter(self, image): xys = ndimage.center_of_mass(abs(image.data), labels, all_idx) x = np.array([xy[1] for xy in xys]) y = np.array([xy[0] for xy in xys]) - coords = image.wcs.wcs.pixel_to_world(x, y) + coords = image.get_wcs().wcs.pixel_to_world(x, y) ra = [c.ra.value for c in coords] dec = [c.dec.value for c in coords] diff --git a/pipeline/measuring.py b/pipeline/measuring.py index 0eab1912..aca2ee93 100644 --- a/pipeline/measuring.py +++ b/pipeline/measuring.py @@ -1,4 +1,3 @@ -import os import time import warnings import numpy as np @@ -17,7 +16,7 @@ from pipeline.parameters import Parameters from pipeline.data_store import DataStore -from util.util import parse_session, parse_bool +from util.util import parse_session, env_as_bool class ParsMeasurer(Parameters): @@ -172,7 +171,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -181,24 +180,22 @@ def run(self, *args, **kwargs): # get the provenance for this step: prov = ds.get_provenance('measuring', self.pars.get_critical_pars(), session=session) + detections = ds.get_detections(session=session) + if detections is None: + raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') + + cutouts = ds.get_cutouts(session=session) + if cutouts is None: + raise ValueError(f'Cannot find cutouts corresponding to the datastore inputs: {ds.get_inputs()}') + else: + cutouts.load_all_co_data() + # try to find some measurements in memory or in the database: measurements_list = ds.get_measurements(prov, session=session) # note that if measurements_list is found, there will not be an all_measurements appended to datastore! if measurements_list is None or len(measurements_list) == 0: # must create a new list of Measurements self.has_recalculated = True - # use the latest source list in the data store, - # or load using the provenance given in the - # data store's upstream_provs, or just use - # the most recent provenance for "detection" - detections = ds.get_detections(session=session) - - if detections is None: - raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') - - cutouts = ds.get_cutouts(session=session) - if cutouts is not None: - cutouts.load_all_co_data() # prepare the filter bank for this batch of cutouts if self._filter_psf_fwhm is None or self._filter_psf_fwhm != cutouts.sources.image.get_psf().fwhm_pixels: @@ -297,7 +294,14 @@ def run(self, *args, **kwargs): if m.bkg_mean != 0 and m.bkg_std > 0.1: norm_data = (m.sub_nandata - m.bkg_mean) / m.bkg_std # normalize else: - warnings.warn(f'Background mean= {m.bkg_mean}, std= {m.bkg_std}, normalization skipped!') + # only provide this warning if the offset is within the image + # otherwise, this measurement is never going to pass any cuts + # and we don't want to spam the logs with this warning + if ( + abs(m.offset_x) < m.sub_data.shape[1] and + abs(m.offset_y) < m.sub_data.shape[0] + ): + warnings.warn(f'Background mean= {m.bkg_mean}, std= {m.bkg_std}, normalization skipped!') norm_data = m.sub_nandata # no good background measurement, do not normalize! positives = np.sum(norm_data > self.pars.outlier_sigma) @@ -330,31 +334,35 @@ def run(self, *args, **kwargs): # TODO: add additional disqualifiers - m._upstream_bitflag = 0 - m._upstream_bitflag |= cutouts.bitflag - - ignore_bits = 0 - for badness in self.pars.bad_flag_exclude: - ignore_bits |= 2 ** BadnessConverter.convert(badness) - - m.disqualifier_scores['bad_flag'] = np.bitwise_and( - np.array(m.bitflag).astype('uint64'), - ~np.array(ignore_bits).astype('uint64'), - ) - # make sure disqualifier scores don't have any numpy types for k, v in m.disqualifier_scores.items(): if isinstance(v, np.number): m.disqualifier_scores[k] = v.item() measurements_list.append(m) + else: + [setattr(m, 'cutouts', cutouts) for m in measurements_list] # update with newest cutouts + + saved_measurements = [] + for m in measurements_list: + # regardless of wether we created these now, or loaded from DB, + # the bitflag should be updated based on the most recent data + m._upstream_bitflag = 0 + m._upstream_bitflag |= m.cutouts.bitflag + + ignore_bits = 0 + for badness in self.pars.bad_flag_exclude: + ignore_bits |= 2 ** BadnessConverter.convert(badness) + + m.disqualifier_scores['bad_flag'] = int(np.bitwise_and( + np.array(m.bitflag).astype('uint64'), + ~np.array(ignore_bits).astype('uint64'), + )) - saved_measurements = [] - for m in measurements_list: - threshold_comparison = self.compare_measurement_to_thresholds(m) - if threshold_comparison != "delete": # all disqualifiers are below threshold - m.is_bad = threshold_comparison == "bad" - saved_measurements.append(m) + threshold_comparison = self.compare_measurement_to_thresholds(m) + if threshold_comparison != "delete": # all disqualifiers are below threshold + m.is_bad = threshold_comparison == "bad" + saved_measurements.append(m) # add the resulting measurements to the data store ds.all_measurements = measurements_list # debugging only @@ -363,7 +371,7 @@ def run(self, *args, **kwargs): ds.sub_image.measurements = saved_measurements ds.runtimes['measuring'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['measuring'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB diff --git a/pipeline/parameters.py b/pipeline/parameters.py index a58c3977..84397286 100644 --- a/pipeline/parameters.py +++ b/pipeline/parameters.py @@ -488,6 +488,10 @@ def add_siblings(self, siblings): self.__sibling_parameters__.update(siblings) + def require_siblings(self): + """If not overriden, returns False. For subclasses that depend on siblings, this should return True.""" + return False + def get_critical_pars(self, ignore_siblings=False): """ Get a dictionary of the critical parameters. @@ -507,6 +511,8 @@ def get_critical_pars(self, ignore_siblings=False): # if there is no dictionary, or it is empty (or if asked to ignore siblings) just return the critical parameters if ignore_siblings or not self.__sibling_parameters__: return self.to_dict(critical=True, hidden=True) + elif not self.__sibling_parameters__ and self.require_siblings(): + raise ValueError("This object requires sibling parameters, but none were provided. Use add_siblings().") else: # a dictionary based on keys in __sibling_parameters__ with critical pars sub-dictionaries return { key: value.get_critical_pars(ignore_siblings=True) diff --git a/pipeline/photo_cal.py b/pipeline/photo_cal.py index 365ef85f..2c5a4731 100644 --- a/pipeline/photo_cal.py +++ b/pipeline/photo_cal.py @@ -1,4 +1,3 @@ -import os import time import numpy as np @@ -14,7 +13,7 @@ from util.exceptions import BadMatchException from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool # TODO: Make max_catalog_mag and mag_range_catalog defaults be supplied # by the instrument, since there are going to be different sane defaults @@ -73,6 +72,9 @@ def __init__(self, **kwargs): def get_process_name(self): return 'photo_cal' + def require_siblings(self): + return True + class PhotCalibrator: def __init__(self, **kwargs): @@ -241,7 +243,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -250,6 +252,19 @@ def run(self, *args, **kwargs): # get the provenance for this step: prov = ds.get_provenance('extraction', self.pars.get_critical_pars(), session=session) + image = ds.get_image(session=session) + if image is None: + raise ValueError('Cannot find the image corresponding to the datastore inputs') + sources = ds.get_sources(session=session) + if sources is None: + raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') + psf = ds.get_psf(session=session) + if psf is None: + raise ValueError(f'Cannot find a psf corresponding to the datastore inputs: {ds.get_inputs()}') + wcs = ds.get_wcs(session=session) + if wcs is None: + raise ValueError(f'Cannot find a wcs for image {image.filepath}') + # try to find the world coordinates in memory or in the database: zp = ds.get_zp(prov, session=session) @@ -259,16 +274,6 @@ def run(self, *args, **kwargs): raise NotImplementedError( f"Currently only know how to calibrate to gaia_dr3, not " f"{self.pars.cross_match_catalog}" ) - image = ds.get_image(session=session) - - sources = ds.get_sources(session=session) - if sources is None: - raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') - - wcs = ds.get_wcs( session=session ) - if wcs is None: - raise ValueError( f'Cannot find a wcs for image {image.filepath}' ) - catname = self.pars.cross_match_catalog fetch_func = getattr(pipeline.catalog_tools, f'fetch_{catname}_excerpt') catexp = fetch_func( @@ -296,19 +301,21 @@ def run(self, *args, **kwargs): ds.zp = ZeroPoint( sources=ds.sources, provenance=prov, zp=zpval, dzp=dzpval, aper_cor_radii=sources.aper_rads, aper_cors=apercors ) - if ds.zp._upstream_bitflag is None: - ds.zp._upstream_bitflag = 0 - ds.zp._upstream_bitflag |= sources.bitflag - ds.zp._upstream_bitflag |= wcs.bitflag - ds.image.zero_point_estimate = ds.zp.zp # TODO: should we only write if the property is None? - # TODO: we should also add a limiting magnitude calculation here. + # TODO: I'm putting a stupid placeholder instead of actual limiting magnitude, please fix this! + ds.image.lim_mag_estimate = ds.zp.zp - 2.5 * np.log10(5.0 * ds.image.bkg_rms_estimate) ds.runtimes['photo_cal'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['photo_cal'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB + # update the bitflag with the upstreams + ds.zp._upstream_bitflag = 0 + ds.zp._upstream_bitflag |= sources.bitflag # includes badness from Image as well + ds.zp._upstream_bitflag |= psf.bitflag + ds.zp._upstream_bitflag |= wcs.bitflag + except Exception as e: ds.catch_exception(e) finally: diff --git a/pipeline/preprocessing.py b/pipeline/preprocessing.py index 1b70999d..1472006c 100644 --- a/pipeline/preprocessing.py +++ b/pipeline/preprocessing.py @@ -1,4 +1,3 @@ -import os import pathlib import time @@ -16,7 +15,7 @@ from util.config import Config from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool class ParsPreprocessor(Parameters): @@ -99,7 +98,7 @@ def run( self, *args, **kwargs ): try: # catch any exceptions and save them in the datastore t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -299,7 +298,7 @@ def run( self, *args, **kwargs ): ds.image = image ds.runtimes['preprocessing'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['preprocessing'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB diff --git a/pipeline/ref_maker.py b/pipeline/ref_maker.py new file mode 100644 index 00000000..411523f6 --- /dev/null +++ b/pipeline/ref_maker.py @@ -0,0 +1,566 @@ +import datetime +import time + +import numpy as np +import sqlalchemy as sa +from sqlalchemy.exc import IntegrityError + +from pipeline.parameters import Parameters +from pipeline.coaddition import CoaddPipeline +from pipeline.top_level import Pipeline + +from models.base import SmartSession +from models.provenance import Provenance +from models.reference import Reference +from models.exposure import Exposure +from models.image import Image +from models.refset import RefSet + +from util.config import Config +from util.logger import SCLogger +from util.util import parse_session, listify +from util.radec import parse_sexigesimal_degrees + + +class ParsRefMaker(Parameters): + def __init__(self, **kwargs): + super().__init__() + + self.name = self.add_par( + 'name', + 'default', + str, + 'Name of the reference set. ', + critical=False, # the name of the refset is not in the Reference provenance! + # this means multiple refsets can refer to the same Reference provenance + ) + + self.description = self.add_par( + 'description', + '', + str, + 'Description of the reference set. ', + critical=False, + ) + + self.allow_append = self.add_par( + 'allow_append', + True, + bool, + 'If True, will append new provenances to an existing reference set with the same name. ' + 'If False, will raise an error if a reference set with the same name ' + 'and a different provenance already exists', + critical=False, # can decide to turn this option on or off as an administrative decision + ) + + self.start_time = self.add_par( + 'start_time', + None, + (None, str, float, datetime.datetime), + 'Only use images taken after this time (inclusive). ' + 'Time format can be MJD float, ISOT string, or datetime object. ' + 'If None, will not limit the start time. ', + critical=True, + ) + + self.end_time = self.add_par( + 'end_time', + None, + (None, str, float, datetime.datetime), + 'Only use images taken before this time (inclusive). ' + 'Time format can be MJD float, ISOT string, or datetime object. ' + 'If None, will not limit the end time. ', + critical=True, + ) + + self.instruments = self.add_par( + 'instruments', + None, + (None, list), + 'Only use images from these instruments. If None, will use all instruments. ' + 'If given as a list, will use any of the instruments in the list. ' + 'In both these cases, cross-instrument references will be made. ' + 'To make sure single-instrument references are made, make a different refset ' + 'with a single item on this list, one for each instrument. ' + 'This does not have a default value, but you MUST supply a list with at least one instrument ' + 'in order to get a reference provenance and create a reference set. ', + critical=True, + ) + + self.filters = self.add_par( + 'filters', + None, + (None, list), + 'Only use images with these filters. If None, will not limit the filters. ' + 'If given as a list, will use any of the filters in the list. ' + 'For multiple instruments, can match any filter to any instrument. ', + critical=True, + ) + + self.projects = self.add_par( + 'projects', + None, + (None, list), + 'Only use images from these projects. If None, will not limit the projects. ' + 'If given as a list, will use any of the projects in the list. ', + critical=True, + ) + + self.__image_query_pars__ = ['airmass', 'background', 'seeing', 'lim_mag', 'exp_time'] + + for name in self.__image_query_pars__: + for min_max in ['min', 'max']: + self.add_limit_parameter(name, min_max) + + self.__docstrings__['min_lim_mag'] = ('Only use images with lim_mag larger (fainter) than this. ' + 'If None, will not limit the minimal lim_mag. ') + self.__docstrings__['max_lim_mag'] = ('Only use images with lim_mag smaller (brighter) than this. ' + 'If None, will not limit the maximal lim_mag. ') + + self.min_number = self.add_par( + 'min_number', + 1, + int, + 'Construct a reference only if there are at least this many images that pass all other criteria. ', + critical=True, + ) + + self.max_number = self.add_par( + 'max_number', + None, + (None, int), + 'If there are more than this many images, pick the ones with the highest "quality". ', + critical=True, + ) + + self.seeing_quality_factor = self.add_par( + 'seeing_quality_factor', + 3.0, + float, + 'linear combination coefficient for adding limiting magnitude and seeing FWHM ' + 'when calculating the "image quality" used to rank images. ', + critical=True, + ) + + self.save_new_refs = self.add_par( + 'save_new_refs', + True, + bool, + 'If True, will save the coadd image and commit it and the newly created reference to the database. ' + 'If False, will only return it. ', + critical=False, + ) + + self._enforce_no_new_attrs = True # lock against new parameters + + self.override(kwargs) + + def add_limit_parameter(self, name, min_max='min'): + """Add a parameter in a systematic way. """ + if min_max not in ['min', 'max']: + raise ValueError('min_max must be either "min" or "max"') + compare = 'larger' if min_max == 'min' else 'smaller' + setattr( + self, + f'{min_max}_{name}', + self.add_par( + f'{min_max}_{name}', + None, + (None, float), + f'Only use images with {name} {compare} than this value. ' + f'If None, will not limit the {min_max}imal {name}.', + critical=True, + ) + ) + + def get_process_name(self): + return 'referencing' + + +class RefMaker: + def __init__(self, **kwargs): + """ + Initialize a reference maker object. + + The possible keywords that can be given are: maker, pipeline, coaddition. Each should be a dictionary. + + The object will load the config file and use the following hierarchy to set the parameters: + - first loads up the regular pipeline parameters, namely those for preprocessing and extraction. + - override those with the parameters given by the "referencing" dictionary in the config file. + - override those with kwargs['pipeline'] that can have "preprocessing" or "extraction" keys. + - parameters for the coaddition step, and the extraction done on the coadd image are taken from "coaddition" + - those are overriden by the "referencing.coaddition" dictionary in the config file + - those are overriden by the kwargs['coaddition'] dictionary, if it exists. + - the parameters to the reference maker its (e.g., how to choose images) are given from the + config['referencing.maker'] dictionary and are overriden by the kwargs['maker'] dictionary. + + The maker contains a pipeline object, that doesn't do any work, but is instantiated so it can build up the + provenances of the images and their products, that go into the coaddition. + Those images need to already exist in the database before calling run(). + Pass kwargs into the pipeline object using kwargs['pipeline']. + TODO: what about multiple instruments that go into the coaddition? we'd need multiple pipeline objects + in order to have difference parameter sets for preprocessing/extraction for each instrument. + The maker also contains a coadd_pipeline object, that has two roles: one is to build the provenances of the + coadd image and the products of that image (extraction on the coadd) and the second is to actually + do the work of coadding the chosen images. + Pass kwargs into this object using kwargs['coaddition']. + The choice of which images are loaded into the reference coadd is determined by the parameters object of the + maker itself (and the provenances of the images and their products). + To set these parameters, use the "referencing.maker" dictionary in the config, or pass them in kwargs['maker']. + """ + # first break off some pieces of the kwargs dict + maker_overrides = kwargs.pop('maker', {}) # to set the parameters of the reference maker itself + pipe_overrides = kwargs.pop('pipeline', {}) # to allow overriding the regular image pipeline + coadd_overrides = kwargs.pop('coaddition', {}) # to allow overriding the coaddition pipeline + + if len(kwargs) > 0: + raise ValueError(f'Unknown parameters given to RefMaker: {kwargs.keys()}') + + # now read the config file + config = Config.get() + + # initialize an object to get the provenances of the regular images and their products + pipe_dict = config.value('referencing.pipeline', {}) # this is the reference pipeline override + pipe_dict.update(pipe_overrides) + self.pipeline = Pipeline(**pipe_dict) # internally loads regular pipeline config, overrides with pipe_dict + + coadd_dict = config.value('referencing.coaddition', {}) # allow overrides from config's referencing.coaddition + coadd_dict.update(coadd_overrides) # allow overrides from kwargs['coaddition'] + self.coadd_pipeline = CoaddPipeline(**coadd_dict) # coaddition parameters, overrides with coadd_dict + + maker_dict = config.value('referencing.maker') + maker_dict.update(maker_overrides) # user can provide override arguments in kwargs + self.pars = ParsRefMaker(**maker_dict) # initialize without the pipeline/coaddition parameters + + # first, make sure we can assemble the provenances up to extraction: + self.im_provs = None # the provenances used to make images going into the reference (these are coadds!) + self.ex_provs = None # the provenances used to make other products like SourceLists, that go into the reference + self.coadd_im_prov = None # the provenance used to make the coadd image + self.coadd_ex_prov = None # the provenance used to make the products of the coadd image + self.ref_upstream_hash = None # a hash identifying all upstreams of the reference provenance + self.ref_prov = None # the provenance of the reference itself + self.refset = None # the RefSet object that was found / created + + # these attributes tell us the place in the sky where we want to look for objects (given to run()) + # optionally it also specifies which filter we want the reference to be in + self.ra = None # in degrees + self.dec = None # in degrees + self.target = None # the name of the target / field ID / Object ID + self.section_id = None # a string with the section ID + self.filter = None # a string with the (short) name of the filter + + def setup_provenances(self, session=None): + """Make the provenances for the images and all their products, including the coadd image. + + These are used both to establish the provenance of the reference itself, + and to look for images and associated products (like SourceLists) when + building the reference. + """ + if self.pars.instruments is None or len(self.pars.instruments) == 0: + raise ValueError('No instruments given to RefMaker!') + + self.im_provs = [] + self.ex_provs = [] + + for inst in self.pars.instruments: + load_exposure = Exposure.make_provenance(inst) + pars = self.pipeline.preprocessor.pars.get_critical_pars() + preprocessing = Provenance( + process='preprocessing', + code_version=load_exposure.code_version, # TODO: allow loading versions for each process + parameters=pars, + upstreams=[load_exposure], + is_testing='test_parameter' in pars, + ) + pars = self.pipeline.extractor.pars.get_critical_pars() # includes parameters of siblings + extraction = Provenance( + process='extraction', + code_version=preprocessing.code_version, # TODO: allow loading versions for each process + parameters=pars, + upstreams=[preprocessing], + is_testing='test_parameter' in pars, + ) + + # the exposure provenance is not included in the reference provenance's upstreams + self.im_provs.append(preprocessing) + self.ex_provs.append(extraction) + + upstreams = self.im_provs + self.ex_provs # all the provenances that go into the coadd + # TODO: we are using preprocess.code_version but should really load the right code_version for this process. + coadd_provs = self.coadd_pipeline.make_provenance_tree(upstreams, preprocessing.code_version, session=session) + self.coadd_im_prov = coadd_provs['coaddition'] + self.coadd_ex_prov = coadd_provs['extraction'] + + pars = self.pars.get_critical_pars() + self.ref_prov = Provenance( + process=self.pars.get_process_name(), + code_version=self.coadd_im_prov.code_version, # TODO: allow loading versions for each process + parameters=pars, + upstreams=[self.coadd_im_prov, self.coadd_ex_prov], + is_testing='test_parameter' in pars, + ) + + # this hash uniquely identifies all the preprocessing and extraction hashes in this provenance's upstreams + self.ref_upstream_hash = self.ref_prov.get_combined_upstream_hash() + # NOTE: we could have just used the coadd_ex_prov hash, because that uniquely identifies the coadd_im_prov + # (it is an upstream) and through that the preprocessing and extraction provenances of the regular images. + # but I am not sure how the provenance tree will look like in the future, I am leaving this additional hash + # here to be safe. The important part is that this hash must be singular for each RefSet, so that the + # downstreams of the subtractions will have a well-defined provenance, one for each RefSet. + + def parse_arguments(self, *args, **kwargs): + """Figure out if the input parameters are given as coordinates or as target + section ID pairs. + + Possible combinations: + - float + float + string: interpreted as RA/Dec in degrees + - str + str: try to interpret as sexagesimal (RA as hours, Dec as degrees) + if it fails, will interpret as target + section ID + # TODO: can we identify a reference with only a target/field ID without a section ID? Issue #320 + In addition to the first two arguments, can also supply a filter name as a string + and can provide a session object as an argument (in any position) to be used and kept open + for the entire run. If not given a session, will open a new one and close it when done using it internally. + + Alternatively, can provide named arguments with the same combinations for either + (ra, dec) or (target, section_id) and filter. + + Returns + ------- + session: sqlalchemy.orm.session.Session object or None + The session object, if it was passed in as a positional argument. + If not given, the ref maker will just open and close sessions internally + when needed. + """ + self.ra = None + self.dec = None + self.target = None + self.section_id = None + self.filter = None + + args, kwargs, session = parse_session(*args, **kwargs) # first pick out any sessions + + if len(args) == 3: + if not isinstance(args[2], str): + raise ValueError('Third argument must be a string, the filter name!') + self.filter = args[2] + args = args[:2] # remove the last one + + if len(args) == 2: + if isinstance(args[0], (float, int, np.number)) and isinstance(args[1], (float, int, np.number)): + self.ra = float(args[0]) + self.dec = float(args[1]) + if isinstance(args[0], str) and isinstance(args[1], str): + try: + self.ra = parse_sexigesimal_degrees(args[0], hours=True) + self.dec = parse_sexigesimal_degrees(args[1], hours=False) + except ValueError: + self.target, self.section_id = args[0], args[1] + elif len(args) == 0: # parse kwargs instead! + if 'ra' in kwargs and 'dec' in kwargs: + self.ra = kwargs.pop('ra') + if isinstance(self.ra, str): + self.ra = parse_sexigesimal_degrees(self.ra, hours=True) + + self.dec = kwargs.pop('dec') + if isinstance(self.dec, str): + self.dec = parse_sexigesimal_degrees(self.dec, hours=False) + + elif 'target' in kwargs and 'section_id' in kwargs: + self.target = kwargs.pop('target') + self.section_id = kwargs.pop('section_id') + else: + raise ValueError('Cannot find ra/dec or target/section_id in any of the inputs! ') + + if 'filter' in kwargs: + self.filter = kwargs.pop('filter') + + else: + raise ValueError('Invalid number of arguments given to RefMaker.parse_arguments()') + + if self.filter is None: + raise ValueError('No filter given to RefMaker.parse_arguments()!') + + return session + + def make_refset(self, session=None): + """Create or load an existing RefSet with the required name. + + Will also make all the required provenances (using the config) and + possibly append the reference provenance to the list of provenances + on the RefSet. + """ + with SmartSession(session) as dbsession: + self.setup_provenances(session=dbsession) + + # first merge the reference provenance + self.ref_prov = self.ref_prov.merge_concurrent(session=dbsession, commit=True) + + # now load or create a RefSet + for i in range(5): # a concurrent merge sort of loop + self.refset = dbsession.scalars(sa.select(RefSet).where(RefSet.name == self.pars.name)).first() + + if self.refset is not None: + break + else: # not found any RefSet with this name + try: + self.refset = RefSet( + name=self.pars.name, + description=self.pars.description, + upstream_hash=self.ref_upstream_hash, + ) + dbsession.add(self.refset) + dbsession.commit() + except IntegrityError as e: + # there was a violation on unique constraint on the "name" column: + if 'duplicate key value violates unique constraint "ix_refsets_name' in str(e): + session.rollback() + time.sleep(0.1 * 2 ** i) # exponential sleep + else: + raise e + else: # if we didn't break out of the loop, there must have been some integrity error + raise e + + if self.refset is None: + raise RuntimeError(f'Failed to find or create a RefSet with the name "{self.pars.name}"!') + + if self.refset.upstream_hash != self.ref_upstream_hash: + raise RuntimeError( + f'Found a RefSet with the name "{self.pars.name}", but it has a different upstream_hash!' + ) + + # If the provenance is not already on the RefSet, add it (or raise, if allow_append=False) + if self.ref_prov.id not in [p.id for p in self.refset.provenances]: + if self.pars.allow_append: + prov_list = self.refset.provenances + prov_list.append(self.ref_prov) + self.refset.provenances = prov_list # not sure if appending directly will trigger an update to DB + dbsession.commit() + else: + raise RuntimeError( + f'Found a RefSet with the name "{self.pars.name}", but it has a different provenance! ' + f'Use "allow_append" parameter to add new provenances to this RefSet. ' + ) + + def run(self, *args, **kwargs): + """Check if a reference exists for the given coordinates/field ID, and filter, and make it if it is missing. + + Will check if a RefSet exists with the same provenance and name, and if it doesn't, will create a new + RefSet with these properties, to keep track of the reference provenances. + + Arguments specifying where in the sky to look for / create the reference are parsed by parse_arguments(). + Same is true for the filter choice. + The remaining policy regarding which images to pick, and what provenance to use to find references, + is defined by the parameters object of self and of self.pipeline. + + If one of the inputs is a session, will use that in the entire process. + Otherwise, will open internal sessions and close them whenever they are not needed. + + Will return a Reference, or None in case it doesn't exist and cannot be created + (e.g., because there are not enough images that pass the criteria). + """ + session = self.parse_arguments(*args, **kwargs) + + with SmartSession(session) as dbsession: + self.make_refset(session=dbsession) + + # look for the reference at the given location in the sky (via ra/dec or target/section_id) + ref = Reference.get_references( + ra=self.ra, + dec=self.dec, + target=self.target, + section_id=self.section_id, + filter=self.filter, + provenance_ids=self.ref_prov.id, + session=dbsession, + ) + + if ref: # found a reference, can skip the next part of the code! + if len(ref) == 0: + return None + elif len(ref) == 1: + return ref[0] + else: + raise RuntimeError( + f'Found multiple references with the same provenance {self.ref_prov.id} and location!' + ) + ############### no reference found, need to build one! ################ + + # first get all the images that could be used to build the reference + images = [] # can get images from different instruments + for inst in self.pars.instrument: + prov = [p for p in self.im_provs if p.upstreams[0].parameters['instrument'] == inst] + if len(prov) == 0: + raise RuntimeError(f'Cannot find a provenance for instrument "{inst}" in im_provs!') + if len(prov) > 1: + raise RuntimeError(f'Found multiple provenances for instrument "{inst}" in im_provs!') + prov = prov[0] + + query_pars = dict( + instrument=inst, + ra=self.ra, # can be None! + dec=self.dec, # can be None! + target=self.target, # can be None! + section_id=self.section_id, # can be None! + filter=self.pars.filters, # can be None! + project=self.pars.project, # can be None! + min_dateobs=self.pars.start_time, + max_dateobs=self.pars.end_time, + seeing_quality_factor=self.pars.seeing_quality_factor, + provenance_ids=prov.id, + ) + + for key in self.pars.__image_query_pars__: + for min_max in ['min', 'max']: + query_pars[f'{min_max}_{key}'] = getattr(self.pars, f'{min_max}_{key}') # can be None! + + # get the actual images that match the query + + images += dbsession.scalars(Image.query_images(**query_pars).limit(self.pars.max_number)).all() + + if len(images) < self.pars.min_number: + SCLogger.info(f'Found {len(images)} images, need at least {self.pars.min_number} to make a reference!') + return None + + # note that if there are multiple instruments, each query may load the max number of images, + # that's why we must also limit the number of images after all queries have returned. + if len(images) > self.pars.max_number: + coeff = abs(self.pars.seeing_quality_factor) # abs is used to make sure the coefficient is negative + for im in images: + im.quality = im.lim_mag_estimate - coeff * im.fwhm_estimate + + # sort the images by the quality + images = sorted(images, key=lambda x: x.quality, reverse=True) + images = images[:self.pars.max_number] + + # make the reference (note that we are out of the session block, to release it while we coadd) + images = sorted(images, key=lambda x: x.mjd) # sort the images in chronological order for coaddition + + # load the extraction products of these images using the ex_provs + for im in images: + im.load_products(self.ex_provs, session=dbsession) + prods = {p: getattr(im, p) for p in ['sources', 'psf', 'bg', 'wcs', 'zp']} + if any([p is None for p in prods.values()]): + raise RuntimeError( + f'Image {im} is missing products {prods} for coaddition! ' + f'Make sure to produce products using the provenances in ex_provs: ' + f'{self.ex_provs}' + ) + + # release the session when making the coadd image + + coadd_image = self.coadd_pipeline.run(images) + + ref = Reference(image=coadd_image, provenance=self.ref_prov) + + if self.pars.save_new_refs: + with SmartSession(session) as dbsession: + ref.image.save(overwrite=True) + ref.image.sources.save(overwrite=True) + ref.image.psf.save(overwrite=True) + ref.image.bg.save(overwrite=True) + ref.image.wcs.save(overwrite=True) + # zp is not a FileOnDiskMixin! + + ref = ref.merge_all(dbsession) + dbsession.commit() + + return ref diff --git a/pipeline/subtraction.py b/pipeline/subtraction.py index dd854486..2a63e530 100644 --- a/pipeline/subtraction.py +++ b/pipeline/subtraction.py @@ -1,19 +1,21 @@ -import os import time import numpy as np +import sqlalchemy as sa + from pipeline.parameters import Parameters from pipeline.data_store import DataStore from models.base import SmartSession from models.image import Image +from models.refset import RefSet from improc.zogy import zogy_subtract, zogy_add_weights_flags from improc.inpainting import Inpainter from improc.alignment import ImageAligner from improc.tools import sigma_clipping -from util.util import parse_bool +from util.util import env_as_bool class ParsSubtractor(Parameters): @@ -26,9 +28,16 @@ def __init__(self, **kwargs): 'Which subtraction method to use. Possible values are: "hotpants", "zogy". ' ) + self.refset = self.add_par( + 'refset', + None, + (None, str), + 'The name of the reference set to use for getting a reference image. ' + ) + self.alignment = self.add_par( 'alignment', - {'method:': 'swarp', 'to_index': 'new'}, + {'method': 'swarp', 'to_index': 'new'}, dict, 'How to align the reference image to the new image. This will be ingested by ImageAligner. ' ) @@ -190,6 +199,9 @@ def _subtract_zogy(self, new_image, ref_image): output['outwt'] = outwt output['outfl'] = outfl + # convert flux based into magnitude based zero point + output['zero_point'] = 2.5 * np.log10(output['zero_point']) + return output def _subtract_hotpants(self, new_image, ref_image): @@ -234,7 +246,7 @@ def run(self, *args, **kwargs): try: t_start = time.perf_counter() - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc tracemalloc.reset_peak() # start accounting for the peak memory usage from here @@ -242,8 +254,16 @@ def run(self, *args, **kwargs): # get the provenance for this step: with SmartSession(session) as session: - # look for a reference that has to do with the current image - ref = ds.get_reference(session=session) + # look for a reference that has to do with the current image and refset + if self.pars.refset is None: + raise ValueError('No reference set given for subtraction') + refset = session.scalars(sa.select(RefSet).where(RefSet.name == self.pars.refset)).first() + if refset is None: + raise ValueError(f'Cannot find a reference set with name {self.pars.refset}') + + # TODO: we can add additional parameters of get_reference() that come from + # the subtraction config, such as skip_bad, match_filter, ignore_target_and_section, min_overlap + ref = ds.get_reference(refset.provenances, session=session) if ref is None: raise ValueError( f'Cannot find a reference image corresponding to the datastore inputs: {ds.get_inputs()}' @@ -312,19 +332,33 @@ def run(self, *args, **kwargs): sub_image.subtraction_output = outdict # save the full output for debugging - if sub_image._upstream_bitflag is None: - sub_image._upstream_bitflag = 0 - sub_image._upstream_bitflag |= ds.sources.bitflag + # TODO: can we get better estimates from our subtraction outdict? Issue #312 + sub_image.fwhm_estimate = new_image.fwhm_estimate + # if the subtraction does not provide an estimate of the ZP, use the one from the new image + sub_image.zero_point_estimate = outdict.get('zero_point', new_image.zp.zp) + sub_image.lim_mag_estimate = new_image.lim_mag_estimate + + # if the subtraction does not provide an estimate of the background, use sigma clipping + if 'bkg_mean' not in outdict or 'bkg_rms' not in outdict: + mu, sig = sigma_clipping(sub_image.data) + sub_image.bkg_mean_estimate = outdict.get('bkg_mean', mu) + sub_image.bkg_rms_estimate = outdict.get('bkg_rms', sig) + + sub_image._upstream_bitflag = 0 sub_image._upstream_bitflag |= ds.image.bitflag + sub_image._upstream_bitflag |= ds.sources.bitflag + sub_image._upstream_bitflag |= ds.psf.bitflag + sub_image._upstream_bitflag |= ds.bg.bitflag sub_image._upstream_bitflag |= ds.wcs.bitflag sub_image._upstream_bitflag |= ds.zp.bitflag + if 'ref_image' in locals(): sub_image._upstream_bitflag |= ref_image.bitflag ds.sub_image = sub_image ds.runtimes['subtraction'] = time.perf_counter() - t_start - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): import tracemalloc ds.memory_usages['subtraction'] = tracemalloc.get_traced_memory()[1] / 1024 ** 2 # in MB diff --git a/pipeline/top_level.py b/pipeline/top_level.py index 9ed8ba2a..2455cffa 100644 --- a/pipeline/top_level.py +++ b/pipeline/top_level.py @@ -1,4 +1,3 @@ -import os import datetime import time import warnings @@ -18,13 +17,13 @@ from models.base import SmartSession from models.provenance import Provenance -from models.reference import Reference +from models.refset import RefSet from models.exposure import Exposure from models.report import Report from util.config import Config from util.logger import SCLogger -from util.util import parse_bool +from util.util import env_as_bool # describes the pipeline objects that are used to produce each step of the pipeline # if multiple objects are used in one step, replace the string with a sub-dictionary, @@ -68,6 +67,14 @@ def __init__(self, **kwargs): critical=False, ) + self.save_at_finish = self.add_par( + 'save_at_finish', + True, + bool, + 'Save the final products to the database and disk', + critical=False, + ) + self._enforce_no_new_attrs = True # lock against new parameters self.override(kwargs) @@ -75,39 +82,39 @@ def __init__(self, **kwargs): class Pipeline: def __init__(self, **kwargs): - self.config = Config.get() + config = Config.get() # top level parameters - self.pars = ParsPipeline(**(self.config.value('pipeline', {}))) + self.pars = ParsPipeline(**(config.value('pipeline', {}))) self.pars.augment(kwargs.get('pipeline', {})) # dark/flat and sky subtraction tools - preprocessing_config = self.config.value('preprocessing', {}) + preprocessing_config = config.value('preprocessing', {}) preprocessing_config.update(kwargs.get('preprocessing', {})) self.pars.add_defaults_to_dict(preprocessing_config) self.preprocessor = Preprocessor(**preprocessing_config) # source detection ("extraction" for the regular image!) - extraction_config = self.config.value('extraction.sources', {}) + extraction_config = config.value('extraction.sources', {}) extraction_config.update(kwargs.get('extraction', {}).get('sources', {})) extraction_config.update({'measure_psf': True}) self.pars.add_defaults_to_dict(extraction_config) self.extractor = Detector(**extraction_config) # background estimation using either sep or other methods - background_config = self.config.value('extraction.bg', {}) + background_config = config.value('extraction.bg', {}) background_config.update(kwargs.get('extraction', {}).get('bg', {})) self.pars.add_defaults_to_dict(background_config) self.backgrounder = Backgrounder(**background_config) # astrometric fit using a first pass of sextractor and then astrometric fit to Gaia - astrometor_config = self.config.value('extraction.wcs', {}) + astrometor_config = config.value('extraction.wcs', {}) astrometor_config.update(kwargs.get('extraction', {}).get('wcs', {})) self.pars.add_defaults_to_dict(astrometor_config) self.astrometor = AstroCalibrator(**astrometor_config) # photometric calibration: - photometor_config = self.config.value('extraction.zp', {}) + photometor_config = config.value('extraction.zp', {}) photometor_config.update(kwargs.get('extraction', {}).get('zp', {})) self.pars.add_defaults_to_dict(photometor_config) self.photometor = PhotCalibrator(**photometor_config) @@ -125,26 +132,26 @@ def __init__(self, **kwargs): self.photometor.pars.add_siblings(siblings) # reference fetching and image subtraction - subtraction_config = self.config.value('subtraction', {}) + subtraction_config = config.value('subtraction', {}) subtraction_config.update(kwargs.get('subtraction', {})) self.pars.add_defaults_to_dict(subtraction_config) self.subtractor = Subtractor(**subtraction_config) # source detection ("detection" for the subtracted image!) - detection_config = self.config.value('detection', {}) + detection_config = config.value('detection', {}) detection_config.update(kwargs.get('detection', {})) self.pars.add_defaults_to_dict(detection_config) self.detector = Detector(**detection_config) self.detector.pars.subtraction = True # produce cutouts for detected sources: - cutting_config = self.config.value('cutting', {}) + cutting_config = config.value('cutting', {}) cutting_config.update(kwargs.get('cutting', {})) self.pars.add_defaults_to_dict(cutting_config) self.cutter = Cutter(**cutting_config) # measure photometry, analytical cuts, and deep learning models on the Cutouts: - measuring_config = self.config.value('measuring', {}) + measuring_config = config.value('measuring', {}) measuring_config.update(kwargs.get('measuring', {})) self.pars.add_defaults_to_dict(measuring_config) self.measurer = Measurer(**measuring_config) @@ -156,7 +163,7 @@ def override_parameters(self, **kwargs): if isinstance(PROCESS_OBJECTS[key], dict): for sub_key, sub_value in PROCESS_OBJECTS[key].items(): if sub_key in value: - getattr(self, PROCESS_OBJECTS[key][sub_value]).pars.override(value[sub_key]) + getattr(self, sub_value).pars.override(value[sub_key]) elif isinstance(PROCESS_OBJECTS[key], str): getattr(self, PROCESS_OBJECTS[key]).pars.override(value) else: @@ -198,7 +205,7 @@ def setup_datastore(self, *args, **kwargs): ds, session = DataStore.from_args(*args, **kwargs) if ds.exposure is None: - raise RuntimeError('Not sure if there is a way to run this pipeline method without an exposure!') + raise RuntimeError('Cannot run this pipeline method without an exposure!') try: # must make sure the exposure is on the DB ds.exposure = ds.exposure.merge_concurrent(session=session) @@ -263,88 +270,110 @@ def run(self, *args, **kwargs): ds : DataStore The DataStore object that includes all the data products. """ - ds, session = self.setup_datastore(*args, **kwargs) - if ds.image is not None: - SCLogger.info(f"Pipeline starting for image {ds.image.id} ({ds.image.filepath})") - elif ds.exposure is not None: - SCLogger.info(f"Pipeline starting for exposure {ds.exposure.id} ({ds.exposure}) section {ds.section_id}") - else: - SCLogger.info(f"Pipeline starting with args {args}, kwargs {kwargs}") - - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): - # ref: https://docs.python.org/3/library/tracemalloc.html#record-the-current-and-peak-size-of-all-traced-memory-blocks - import tracemalloc - tracemalloc.start() # trace the size of memory that is being used - - with warnings.catch_warnings(record=True) as w: - ds.warnings_list = w # appends warning to this list as it goes along - # run dark/flat preprocessing, cut out a specific section of the sensor - - SCLogger.info(f"preprocessor") - ds = self.preprocessor.run(ds, session) - ds.update_report('preprocessing', session) - SCLogger.info(f"preprocessing complete: image id = {ds.image.id}, filepath={ds.image.filepath}") - - # extract sources and make a SourceList and PSF from the image - SCLogger.info(f"extractor for image id {ds.image.id}") - ds = self.extractor.run(ds, session) - ds.update_report('extraction', session) - - # find the background for this image - SCLogger.info(f"backgrounder for image id {ds.image.id}") - ds = self.backgrounder.run(ds, session) - ds.update_report('extraction', session) - - # find astrometric solution, save WCS into Image object and FITS headers - SCLogger.info(f"astrometor for image id {ds.image.id}") - ds = self.astrometor.run(ds, session) - ds.update_report('extraction', session) - - # cross-match against photometric catalogs and get zero point, save into Image object and FITS headers - SCLogger.info(f"photometor for image id {ds.image.id}") - ds = self.photometor.run(ds, session) - ds.update_report('extraction', session) - - if self.pars.save_before_subtraction: - t_start = time.perf_counter() - try: - SCLogger.info(f"Saving intermediate image for image id {ds.image.id}") - ds.save_and_commit(session=session) - except Exception as e: - ds.update_report('save intermediate', session) - SCLogger.error(f"Failed to save intermediate image for image id {ds.image.id}") - SCLogger.error(e) - raise e - - ds.runtimes['save_intermediate'] = time.perf_counter() - t_start - - # fetch reference images and subtract them, save subtracted Image objects to DB and disk - SCLogger.info(f"subtractor for image id {ds.image.id}") - ds = self.subtractor.run(ds, session) - ds.update_report('subtraction', session) - - # find sources, generate a source list for detections - SCLogger.info(f"detector for image id {ds.image.id}") - ds = self.detector.run(ds, session) - ds.update_report('detection', session) - - # make cutouts of all the sources in the "detections" source list - SCLogger.info(f"cutter for image id {ds.image.id}") - ds = self.cutter.run(ds, session) - ds.update_report('cutting', session) - - # extract photometry and analytical cuts - SCLogger.info(f"measurer for image id {ds.image.id}") - ds = self.measurer.run(ds, session) - ds.update_report('measuring', session) - - # measure deep learning models on the cutouts/measurements - # TODO: add this... - - # TODO: add a saving step at the end too? - - ds.finalize_report(session) + try: # first make sure we get back a datastore, even an empty one + ds, session = self.setup_datastore(*args, **kwargs) + except Exception as e: + return DataStore.catch_failure_to_parse(e, *args) + try: + if ds.image is not None: + SCLogger.info(f"Pipeline starting for image {ds.image.id} ({ds.image.filepath})") + elif ds.exposure is not None: + SCLogger.info(f"Pipeline starting for exposure {ds.exposure.id} ({ds.exposure}) section {ds.section_id}") + else: + SCLogger.info(f"Pipeline starting with args {args}, kwargs {kwargs}") + + if env_as_bool('SEECHANGE_TRACEMALLOC'): + # ref: https://docs.python.org/3/library/tracemalloc.html#record-the-current-and-peak-size-of-all-traced-memory-blocks + import tracemalloc + tracemalloc.start() # trace the size of memory that is being used + + with warnings.catch_warnings(record=True) as w: + ds.warnings_list = w # appends warning to this list as it goes along + # run dark/flat preprocessing, cut out a specific section of the sensor + + SCLogger.info(f"preprocessor") + ds = self.preprocessor.run(ds, session) + ds.update_report('preprocessing', session=None) + SCLogger.info(f"preprocessing complete: image id = {ds.image.id}, filepath={ds.image.filepath}") + + # extract sources and make a SourceList and PSF from the image + SCLogger.info(f"extractor for image id {ds.image.id}") + ds = self.extractor.run(ds, session) + ds.update_report('extraction', session=None) + + # find the background for this image + SCLogger.info(f"backgrounder for image id {ds.image.id}") + ds = self.backgrounder.run(ds, session) + ds.update_report('extraction', session=None) + + # find astrometric solution, save WCS into Image object and FITS headers + SCLogger.info(f"astrometor for image id {ds.image.id}") + ds = self.astrometor.run(ds, session) + ds.update_report('extraction', session=None) + + # cross-match against photometric catalogs and get zero point, save into Image object and FITS headers + SCLogger.info(f"photometor for image id {ds.image.id}") + ds = self.photometor.run(ds, session) + ds.update_report('extraction', session=None) + + if self.pars.save_before_subtraction: + t_start = time.perf_counter() + try: + SCLogger.info(f"Saving intermediate image for image id {ds.image.id}") + ds.save_and_commit(session=session) + except Exception as e: + ds.update_report('save intermediate', session=None) + SCLogger.error(f"Failed to save intermediate image for image id {ds.image.id}") + SCLogger.error(e) + raise e + + ds.runtimes['save_intermediate'] = time.perf_counter() - t_start + + # fetch reference images and subtract them, save subtracted Image objects to DB and disk + SCLogger.info(f"subtractor for image id {ds.image.id}") + ds = self.subtractor.run(ds, session) + ds.update_report('subtraction', session=None) + + # find sources, generate a source list for detections + SCLogger.info(f"detector for image id {ds.image.id}") + ds = self.detector.run(ds, session) + ds.update_report('detection', session=None) + + # make cutouts of all the sources in the "detections" source list + SCLogger.info(f"cutter for image id {ds.image.id}") + ds = self.cutter.run(ds, session) + ds.update_report('cutting', session=None) + + # extract photometry and analytical cuts + SCLogger.info(f"measurer for image id {ds.image.id}") + ds = self.measurer.run(ds, session) + ds.update_report('measuring', session=None) + + # measure deep learning models on the cutouts/measurements + # TODO: add this... + + if self.pars.save_at_finish: + t_start = time.perf_counter() + try: + SCLogger.info(f"Saving final products for image id {ds.image.id}") + ds.save_and_commit(session=session) + except Exception as e: + ds.update_report('save final', session) + SCLogger.error(f"Failed to save final products for image id {ds.image.id}") + SCLogger.error(e) + raise e + + ds.runtimes['save_final'] = time.perf_counter() - t_start + + ds.finalize_report(session) + + return ds + + except Exception as e: + ds.catch_exception(e) + finally: + # make sure the DataStore is returned in case the calling scope want to debug the pipeline run return ds def run_with_session(self): @@ -356,7 +385,7 @@ def run_with_session(self): with SmartSession() as session: self.run(session=session) - def make_provenance_tree(self, exposure, reference=None, overrides=None, session=None, commit=True): + def make_provenance_tree(self, exposure, overrides=None, session=None, commit=True): """Use the current configuration of the pipeline and all the objects it has to generate the provenances for all the processing steps. This will conclude with the reporting step, which simply has an upstreams @@ -368,15 +397,6 @@ def make_provenance_tree(self, exposure, reference=None, overrides=None, session exposure : Exposure The exposure to use to get the initial provenance. This provenance should be automatically created by the exposure. - reference: str, Provenance object or None - Can be a string matching a valid reference set. This tells the pipeline which - provenance to load for the reference. - Instead, can provide either a Reference object with a Provenance - or the Provenance object of a reference directly. - If not given, will simply load the most recently created reference provenance. - # TODO: when we implement reference sets, we will probably not allow this input directly to - # this function anymore. Instead, you will need to define the reference set in the config, - # under the subtraction parameters. overrides: dict, optional A dictionary of provenances to override any of the steps in the pipeline. For example, set overrides={'preprocessing': prov} to use a specific provenance @@ -402,71 +422,61 @@ def make_provenance_tree(self, exposure, reference=None, overrides=None, session with SmartSession(session) as session: # start by getting the exposure and reference - # TODO: need a better way to find the relevant reference PROVENANCE for this exposure - # i.e., we do not look for a valid reference and get its provenance, instead, - # we look for a provenance based on our policy (that can be defined in the subtraction parameters) - # and find a specific provenance id that matches our policy. - # If we later find that no reference with that provenance exists that overlaps our images, - # that will be recorded as an error in the report. - # One way to do this would be to add a RefSet table that has a name (e.g., "standard") and - # a validity time range (which will be removed from Reference), maybe also the instrument. - # That would allow us to use a combination of name+obs_time to find a specific RefSet, - # which has a single reference provenance ID. If you want a custom reference, - # add a new RefSet with a new name. - # This also means that the reference making pipeline MUST use a single set of policies - # to create all the references for a given RefSet... we need to make sure we can actually - # make that happen consistently (e.g., if you change parameters or start mixing instruments - # when you make the references it will create multiple provenances for the same RefSet). - if isinstance(reference, str): - raise NotImplementedError('See issue #287') - elif isinstance(reference, Reference): - ref_prov = reference.provenance - elif isinstance(reference, Provenance): - ref_prov = reference - elif reference is None: # use the latest provenance that has to do with references - ref_prov = session.scalars( - sa.select(Provenance).where( - Provenance.process == 'reference' - ).order_by(Provenance.created_at.desc()) - ).first() - exp_prov = session.merge(exposure.provenance) # also merges the code_version provs = {'exposure': exp_prov} code_version = exp_prov.code_version is_testing = exp_prov.is_testing - for step in PROCESS_OBJECTS: - if step in overrides: + ref_provs = None # allow multiple reference provenances for each refset + refset_name = self.subtractor.pars.refset + # If refset is None, we will just fail to produce a subtraction, but everything else works... + # Note that the upstreams for the subtraction provenance will be wrong, because we don't have + # any reference provenances to link to. But this is what you get when putting refset=None. + # Just know that the "output provenance" (e.g., of the Measurements) will never actually exist, + # even though you can use it to make the Report provenance (just so you have something to refer to). + if refset_name is not None: + + refset = session.scalars(sa.select(RefSet).where(RefSet.name == refset_name)).first() + if refset is None: + raise ValueError(f'No reference set with name {refset_name} found in the database!') + + ref_provs = refset.provenances + if ref_provs is None or len(ref_provs) == 0: + raise ValueError(f'No provenances found for reference set {refset_name}!') + + provs['referencing'] = ref_provs # notice that this is a list, not a single provenance! + for step in PROCESS_OBJECTS: # produce the provenance for this step + if step in overrides: # accept override from user input provs[step] = overrides[step] - else: - obj_name = PROCESS_OBJECTS[step] - if isinstance(obj_name, dict): + else: # load the parameters from the objects on the pipeline + obj_name = PROCESS_OBJECTS[step] # translate the step to the object name + if isinstance(obj_name, dict): # sub-objects, e.g., extraction.sources, extraction.wcs, etc. # get the first item of the dictionary and hope its pars object has siblings defined correctly: obj_name = obj_name.get(list(obj_name.keys())[0]) parameters = getattr(self, obj_name).pars.get_critical_pars() - # some preprocessing parameters (the "preprocessing_steps") don't come from the - # config file, but instead come from the preprocessing itself. - # TODO: fix this as part of issue #147 - # if step == 'preprocessing': - # parameters['preprocessing_steps'] = ['overscan', 'linearity', 'flat', 'fringe'] - # figure out which provenances go into the upstreams for this step up_steps = UPSTREAM_STEPS[step] if isinstance(up_steps, str): up_steps = [up_steps] - upstreams = [] + upstream_provs = [] for upstream in up_steps: - if upstream == 'reference': - upstreams += ref_prov.upstreams - else: - upstreams.append(provs[upstream]) + if upstream == 'referencing': # this is an externally supplied provenance upstream + if ref_provs is not None: + # we never put the Reference object's provenance into the upstreams of the subtraction + # instead, put the provenances of the coadd image and its extraction products + # this is so the subtraction provenance has the (preprocessing+extraction) provenance + # for each one of its upstream_images (in this case, ref+new). + # by construction all references on the refset SHOULD have the same upstreams + upstream_provs += ref_provs[0].upstreams + else: # just grab the provenance of what is upstream of this step from the existing tree + upstream_provs.append(provs[upstream]) provs[step] = Provenance( code_version=code_version, process=step, parameters=parameters, - upstreams=upstreams, + upstreams=upstream_provs, is_testing=is_testing, ) @@ -476,3 +486,5 @@ def make_provenance_tree(self, exposure, reference=None, overrides=None, session session.commit() return provs + + diff --git a/tests/conftest.py b/tests/conftest.py index 2a63d44d..a3cd2bff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from models.object import Object from util.archive import Archive -from util.util import remove_empty_folders +from util.util import remove_empty_folders, env_as_bool from util.retrydownload import retry_download from util.logger import SCLogger @@ -34,10 +34,13 @@ 'tests.fixtures.ztf', 'tests.fixtures.ptf', 'tests.fixtures.pipeline_objects', + 'tests.fixtures.datastore_factory', ] ARCHIVE_PATH = None +SKIP_WARNING_TESTS = False + # We may want to turn this on only for tests, as it may add a lot of runtime/memory overhead # ref: https://www.mail-archive.com/python-list@python.org/msg443129.html # os.environ["SEECHANGE_TRACEMALLOC"] = "1" @@ -47,9 +50,11 @@ # (session is the pytest session, not the SQLAlchemy session) def pytest_sessionstart(session): # Will be executed before the first test + global SKIP_WARNING_TESTS - # this is only to make the warnings into errors, so it is easier to track them down... - # warnings.filterwarnings('error', append=True) # comment this out in regular usage + if False: # this is only to make the warnings into errors, so it is easier to track them down... + warnings.filterwarnings('error', append=True) # comment this out in regular usage + SKIP_WARNING_TESTS = True setup_warning_filters() # load the list of warnings that are to be ignored (not just in tests) # below are additional warnings that are ignored only during tests: @@ -63,6 +68,15 @@ def pytest_sessionstart(session): FileOnDiskMixin.configure_paths() # SCLogger.setLevel( logging.INFO ) + # get rid of any catalog excerpts from previous runs: + with SmartSession() as session: + catexps = session.scalars(sa.select(CatalogExcerpt)).all() + for catexp in catexps: + if os.path.isfile(catexp.get_fullpath()): + os.remove(catexp.get_fullpath()) + session.delete(catexp) + session.commit() + # This will be executed after the last test (session is the pytest session, not the SQLAlchemy session) def pytest_sessionfinish(session, exitstatus): @@ -83,12 +97,10 @@ def pytest_sessionfinish(session, exitstatus): if Class.__name__ in ['CodeVersion', 'CodeHash', 'SensorSection', 'CatalogExcerpt', 'Provenance', 'Object']: SCLogger.debug(f'There are {len(ids)} {Class.__name__} objects in the database. These are OK to stay.') elif len(ids) > 0: - SCLogger.info( - f'There are {len(ids)} {Class.__name__} objects in the database. Please make sure to cleanup!' - ) + print(f'There are {len(ids)} {Class.__name__} objects in the database. Please make sure to cleanup!') for id in ids: obj = dbsession.scalars(sa.select(Class).where(Class.id == id)).first() - SCLogger.info(f' {obj}') + print(f' {obj}') any_objects = True # delete the CodeVersion object (this should remove all provenances as well) @@ -179,13 +191,13 @@ def blocking_plots(): - It is set to a True value: make the plots, but stop the test execution until the figure is closed. If a test only makes plots and does not test functionality, it should be marked with - @pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) + @pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) If a test makes a diagnostic plot, that is only ever used to visually inspect the results, then it should be surrounded by an if blocking_plots: statement. It will only run in interactive mode. If a test makes a plot that should be saved to disk, it should either have the skipif mentioned above, - or have an if os.getenv('INTERACTIVE'): statement surrounding the plot itself. + or have an if env_as_bool('INTERACTIVE'): statement surrounding the plot itself. You may want to add plt.show(block=blocking_plots) to allow the figure to stick around in interactive mode, on top of saving the figure at the end of the test. """ @@ -196,7 +208,7 @@ def blocking_plots(): if not os.path.isdir(os.path.join(CODE_ROOT, 'tests/plots')): os.makedirs(os.path.join(CODE_ROOT, 'tests/plots')) - inter = os.getenv('INTERACTIVE', False) + inter = env_as_bool('INTERACTIVE') if isinstance(inter, str): inter = inter.lower() in ('true', '1', 'on', 'yes') diff --git a/tests/fixtures/datastore_factory.py b/tests/fixtures/datastore_factory.py new file mode 100644 index 00000000..fef12f45 --- /dev/null +++ b/tests/fixtures/datastore_factory.py @@ -0,0 +1,605 @@ +import os +import warnings +import shutil +import pytest + +import numpy as np + +import sqlalchemy as sa + +from models.base import SmartSession +from models.provenance import Provenance +from models.enums_and_bitflags import BitFlagConverter +from models.image import Image +from models.source_list import SourceList +from models.psf import PSF +from models.background import Background +from models.world_coordinates import WorldCoordinates +from models.zero_point import ZeroPoint +from models.cutouts import Cutouts +from models.measurements import Measurements +from models.refset import RefSet +from pipeline.data_store import DataStore + +from util.logger import SCLogger +from util.cache import copy_to_cache, copy_list_to_cache, copy_from_cache, copy_list_from_cache +from util.util import env_as_bool + +from improc.bitmask_tools import make_saturated_flag + + +@pytest.fixture(scope='session') +def datastore_factory(data_dir, pipeline_factory, request): + """Provide a function that returns a datastore with all the products based on the given exposure and section ID. + + To use this data store in a test where new data is to be generated, + simply change the pipeline object's "test_parameter" value to a unique + new value, so the provenance will not match and the data will be regenerated. + + If "save_original_image" is True, then a copy of the image before + going through source extraction, WCS, etc. will be saved alongside + the image, with ".image.fits.original" appended to the filename; + this path will be in ds.path_to_original_image. In this case, the + thing that calls this factory must delete that file when done. + + EXAMPLE + ------- + extractor.pars.test_parameter = uuid.uuid().hex + extractor.run(datastore) + assert extractor.has_recalculated is True + + """ + def make_datastore( + *args, + cache_dir=None, + cache_base_name=None, + session=None, + overrides={}, + augments={}, + bad_pixel_map=None, + save_original_image=False + ): + code_version = args[0].provenance.code_version + ds = DataStore(*args) # make a new datastore + use_cache = cache_dir is not None and cache_base_name is not None and not env_as_bool( "LIMIT_CACHE_USAGE" ) + + if cache_base_name is not None: + cache_name = cache_base_name + '.image.fits.json' + image_cache_path = os.path.join(cache_dir, cache_name) + else: + image_cache_path = None + + if use_cache: + ds.cache_base_name = os.path.join(cache_dir, cache_base_name) # save this for testing purposes + + p = pipeline_factory() + + # allow calling scope to override/augment parameters for any of the processing steps + p.override_parameters(**overrides) + p.augment_parameters(**augments) + + with SmartSession(session) as session: + code_version = session.merge(code_version) + + if ds.image is not None: # if starting from an externally provided Image, must merge it first + ds.image = ds.image.merge_all(session) + + ############ load the reference set ############ + + inst_name = ds.image.instrument.lower() if ds.image else ds.exposure.instrument.lower() + refset_name = f'test_refset_{inst_name}' + if inst_name == 'ptf': # request the ptf_refset fixture dynamically: + request.getfixturevalue('ptf_refset') + if inst_name == 'decam': # request the decam_refset fixture dynamically: + request.getfixturevalue('decam_refset') + + refset = session.scalars(sa.select(RefSet).where(RefSet.name == refset_name)).first() + + if refset is None: + raise ValueError(f'No reference set found with name {refset_name}') + + ref_prov = refset.provenances[0] + + ############ preprocessing to create image ############ + if ds.image is None and use_cache: # check if preprocessed image is in cache + if os.path.isfile(image_cache_path): + SCLogger.debug('loading image from cache. ') + ds.image = copy_from_cache(Image, cache_dir, cache_name) + # assign the correct exposure to the object loaded from cache + if ds.exposure_id is not None: + ds.image.exposure_id = ds.exposure_id + if ds.exposure is not None: + ds.image.exposure = ds.exposure + ds.image.exposure_id = ds.exposure.id + + # Copy the original image from the cache if requested + if save_original_image: + ds.path_to_original_image = ds.image.get_fullpath()[0] + '.image.fits.original' + image_cache_path_original = os.path.join(cache_dir, ds.image.filepath + '.image.fits.original') + shutil.copy2( image_cache_path_original, ds.path_to_original_image ) + + upstreams = [ds.exposure.provenance] if ds.exposure is not None else [] # images without exposure + prov = Provenance( + code_version=code_version, + process='preprocessing', + upstreams=upstreams, + parameters=p.preprocessor.pars.get_critical_pars(), + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + # if Image already exists on the database, use that instead of this one + existing = session.scalars(sa.select(Image).where(Image.filepath == ds.image.filepath)).first() + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.image).mapper.columns.keys(): + value = getattr(ds.image, key) + if ( + key not in ['id', 'image_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.image = existing # replace with the existing row + + ds.image.provenance = prov + + # make sure this is saved to the archive as well + ds.image.save(verify_md5=False) + + if ds.image is None: # make the preprocessed image + SCLogger.debug('making preprocessed image. ') + ds = p.preprocessor.run(ds, session) + ds.image.provenance.is_testing = True + if bad_pixel_map is not None: + ds.image.flags |= bad_pixel_map + if ds.image.weight is not None: + ds.image.weight[ds.image.flags.astype(bool)] = 0.0 + + # flag saturated pixels, too (TODO: is there a better way to get the saturation limit? ) + mask = make_saturated_flag(ds.image.data, ds.image.instrument_object.saturation_limit, iterations=2) + ds.image.flags |= (mask * 2 ** BitFlagConverter.convert('saturated')).astype(np.uint16) + + ds.image.save() + # even if cache_base_name is None, we still need to make the manifest file, so we will get it next time! + if not env_as_bool( "LIMIT_CACHE_USAGE" ) and os.path.isdir(cache_dir): + output_path = copy_to_cache(ds.image, cache_dir) + + if image_cache_path is not None and output_path != image_cache_path: + warnings.warn(f'cache path {image_cache_path} does not match output path {output_path}') + else: + cache_base_name = output_path[:-10] # remove the '.image.fits' part + ds.cache_base_name = output_path + SCLogger.debug(f'Saving image to cache at: {output_path}') + use_cache = True # the two other conditions are true to even get to this part... + + # In test_astro_cal, there's a routine that needs the original + # image before being processed through the rest of what this + # factory function does, so save it if requested + if save_original_image: + ds.path_to_original_image = ds.image.get_fullpath()[0] + '.image.fits.original' + shutil.copy2( ds.image.get_fullpath()[0], ds.path_to_original_image ) + if use_cache: + shutil.copy2( + ds.image.get_fullpath()[0], + os.path.join(cache_dir, ds.image.filepath + '.image.fits.original') + ) + + ############# extraction to create sources / PSF / BG / WCS / ZP ############# + if use_cache: # try to get the SourceList, PSF, BG, WCS and ZP from cache + prov = Provenance( + code_version=code_version, + process='extraction', + upstreams=[ds.image.provenance], + parameters=p.extractor.pars.get_critical_pars(), # the siblings will be loaded automatically + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + # try to get the source list from cache + cache_name = f'{cache_base_name}.sources_{prov.id[:6]}.fits.json' + sources_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(sources_cache_path): + SCLogger.debug('loading source list from cache. ') + ds.sources = copy_from_cache(SourceList, cache_dir, cache_name) + + # if SourceList already exists on the database, use that instead of this one + existing = session.scalars( + sa.select(SourceList).where(SourceList.filepath == ds.sources.filepath) + ).first() + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.sources).mapper.columns.keys(): + value = getattr(ds.sources, key) + if ( + key not in ['id', 'image_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.sources = existing # replace with the existing row + + ds.sources.provenance = prov + ds.sources.image = ds.image + + # make sure this is saved to the archive as well + ds.sources.save(verify_md5=False) + + # try to get the PSF from cache + cache_name = f'{cache_base_name}.psf_{prov.id[:6]}.fits.json' + psf_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(psf_cache_path): + SCLogger.debug('loading PSF from cache. ') + ds.psf = copy_from_cache(PSF, cache_dir, cache_name) + + # if PSF already exists on the database, use that instead of this one + existing = session.scalars( + sa.select(PSF).where(PSF.filepath == ds.psf.filepath) + ).first() + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.psf).mapper.columns.keys(): + value = getattr(ds.psf, key) + if ( + key not in ['id', 'image_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.psf = existing # replace with the existing row + + ds.psf.provenance = prov + ds.psf.image = ds.image + + # make sure this is saved to the archive as well + ds.psf.save(verify_md5=False, overwrite=True) + + # try to get the background from cache + cache_name = f'{cache_base_name}.bg_{prov.id[:6]}.h5.json' + bg_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(bg_cache_path): + SCLogger.debug('loading background from cache. ') + ds.bg = copy_from_cache(Background, cache_dir, cache_name) + + # if BG already exists on the database, use that instead of this one + existing = session.scalars( + sa.select(Background).where(Background.filepath == ds.bg.filepath) + ).first() + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.bg).mapper.columns.keys(): + value = getattr(ds.bg, key) + if ( + key not in ['id', 'image_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.bg = existing + + ds.bg.provenance = prov + ds.bg.image = ds.image + + # make sure this is saved to the archive as well + ds.bg.save(verify_md5=False, overwrite=True) + + # try to get the WCS from cache + cache_name = f'{cache_base_name}.wcs_{prov.id[:6]}.txt.json' + wcs_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(wcs_cache_path): + SCLogger.debug('loading WCS from cache. ') + ds.wcs = copy_from_cache(WorldCoordinates, cache_dir, cache_name) + prov = session.merge(prov) + + # check if WCS already exists on the database + if ds.sources is not None: + existing = session.scalars( + sa.select(WorldCoordinates).where( + WorldCoordinates.sources_id == ds.sources.id, + WorldCoordinates.provenance_id == prov.id + ) + ).first() + else: + existing = None + + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.wcs).mapper.columns.keys(): + value = getattr(ds.wcs, key) + if ( + key not in ['id', 'sources_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.wcs = existing # replace with the existing row + + ds.wcs.provenance = prov + ds.wcs.sources = ds.sources + # make sure this is saved to the archive as well + ds.wcs.save(verify_md5=False, overwrite=True) + + # try to get the ZP from cache + cache_name = cache_base_name + '.zp.json' + zp_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(zp_cache_path): + SCLogger.debug('loading zero point from cache. ') + ds.zp = copy_from_cache(ZeroPoint, cache_dir, cache_name) + + # check if ZP already exists on the database + if ds.sources is not None: + existing = session.scalars( + sa.select(ZeroPoint).where( + ZeroPoint.sources_id == ds.sources.id, + ZeroPoint.provenance_id == prov.id + ) + ).first() + else: + existing = None + + if existing is not None: + # overwrite the existing row data using the JSON cache file + for key in sa.inspect(ds.zp).mapper.columns.keys(): + value = getattr(ds.zp, key) + if ( + key not in ['id', 'sources_id', 'created_at', 'modified'] and + value is not None + ): + setattr(existing, key, value) + ds.zp = existing # replace with the existing row + + ds.zp.provenance = prov + ds.zp.sources = ds.sources + + # if any data product is missing, must redo the extraction step + if ds.sources is None or ds.psf is None or ds.bg is None or ds.wcs is None or ds.zp is None: + SCLogger.debug('extracting sources. ') + ds = p.extractor.run(ds, session) + + ds.sources.save(overwrite=True) + if use_cache: + output_path = copy_to_cache(ds.sources, cache_dir) + if output_path != sources_cache_path: + warnings.warn(f'cache path {sources_cache_path} does not match output path {output_path}') + + ds.psf.save(overwrite=True) + if use_cache: + output_path = copy_to_cache(ds.psf, cache_dir) + if output_path != psf_cache_path: + warnings.warn(f'cache path {psf_cache_path} does not match output path {output_path}') + + SCLogger.debug('Running background estimation') + ds = p.backgrounder.run(ds, session) + + ds.bg.save(overwrite=True) + if use_cache: + output_path = copy_to_cache(ds.bg, cache_dir) + if output_path != bg_cache_path: + warnings.warn(f'cache path {bg_cache_path} does not match output path {output_path}') + + SCLogger.debug('Running astrometric calibration') + ds = p.astrometor.run(ds, session) + ds.wcs.save(overwrite=True) + if use_cache: + output_path = copy_to_cache(ds.wcs, cache_dir) + if output_path != wcs_cache_path: + warnings.warn(f'cache path {wcs_cache_path} does not match output path {output_path}') + + SCLogger.debug('Running photometric calibration') + ds = p.photometor.run(ds, session) + if use_cache: + cache_name = cache_base_name + '.zp.json' + output_path = copy_to_cache(ds.zp, cache_dir, cache_name) + if output_path != zp_cache_path: + warnings.warn(f'cache path {zp_cache_path} does not match output path {output_path}') + + ds.save_and_commit(session=session) + + # make a new copy of the image to cache, including the estimates for lim_mag, fwhm, etc. + if not env_as_bool("LIMIT_CACHE_USAGE"): + output_path = copy_to_cache(ds.image, cache_dir) + + # must provide the reference provenance explicitly since we didn't build a prov_tree + ref = ds.get_reference(ref_prov, session=session) + if ref is None: + return ds # if no reference is found, simply return the datastore without the rest of the products + + if use_cache: # try to find the subtraction image in the cache + prov = Provenance( + code_version=code_version, + process='subtraction', + upstreams=[ + ds.image.provenance, + ds.sources.provenance, + ref.image.provenance, + ref.sources.provenance, + ], + parameters=p.subtractor.pars.get_critical_pars(), + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + sub_im = Image.from_new_and_ref(ds.image, ref.image) + sub_im.provenance = prov + cache_sub_name = sub_im.invent_filepath() + cache_name = cache_sub_name + '.image.fits.json' + sub_cache_path = os.path.join(cache_dir, cache_name) + if os.path.isfile(sub_cache_path): + SCLogger.debug('loading subtraction image from cache. ') + ds.sub_image = copy_from_cache(Image, cache_dir, cache_name) + + ds.sub_image.provenance = prov + ds.sub_image.upstream_images.append(ref.image) + ds.sub_image.ref_image_id = ref.image_id + ds.sub_image.ref_image = ref.image + ds.sub_image.new_image = ds.image + ds.sub_image.save(verify_md5=False) # make sure it is also saved to archive + + # try to load the aligned images from cache + prov_aligned_ref = Provenance( + code_version=code_version, + parameters=prov.parameters['alignment'], + upstreams=[ + ds.image.provenance, + ds.sources.provenance, # this also includes the PSF's provenance + ds.wcs.provenance, + ds.ref_image.provenance, + ds.ref_image.sources.provenance, + ds.ref_image.wcs.provenance, + ds.ref_image.zp.provenance, + ], + process='alignment', + is_testing=True, + ) + # TODO: can we find a less "hacky" way to do this? + f = ref.image.invent_filepath() + f = f.replace('ComSci', 'Warped') # not sure if this or 'Sci' will be in the filename + f = f.replace('Sci', 'Warped') # in any case, replace it with 'Warped' + f = f[:-6] + prov_aligned_ref.id[:6] # replace the provenance ID + filename_aligned_ref = f + + prov_aligned_new = Provenance( + code_version=code_version, + parameters=prov.parameters['alignment'], + upstreams=[ + ds.image.provenance, + ds.sources.provenance, # this also includes provs for PSF, BG, WCS, ZP + ], + process='alignment', + is_testing=True, + ) + f = ds.sub_image.new_image.invent_filepath() + f = f.replace('ComSci', 'Warped') + f = f.replace('Sci', 'Warped') + f = f[:-6] + prov_aligned_new.id[:6] + filename_aligned_new = f + + cache_name_ref = filename_aligned_ref + '.image.fits.json' + cache_name_new = filename_aligned_new + '.image.fits.json' + if ( + os.path.isfile(os.path.join(cache_dir, cache_name_ref)) and + os.path.isfile(os.path.join(cache_dir, cache_name_new)) + ): + SCLogger.debug('loading aligned reference image from cache. ') + image_aligned_ref = copy_from_cache(Image, cache_dir, cache_name) + image_aligned_ref.provenance = prov_aligned_ref + image_aligned_ref.info['original_image_id'] = ds.ref_image.id + image_aligned_ref.info['original_image_filepath'] = ds.ref_image.filepath + image_aligned_ref.info['alignment_parameters'] = prov.parameters['alignment'] + image_aligned_ref.save(verify_md5=False, no_archive=True) + # TODO: should we also load the aligned image's sources, PSF, and ZP? + + SCLogger.debug('loading aligned new image from cache. ') + image_aligned_new = copy_from_cache(Image, cache_dir, cache_name) + image_aligned_new.provenance = prov_aligned_new + image_aligned_new.info['original_image_id'] = ds.image_id + image_aligned_new.info['original_image_filepath'] = ds.image.filepath + image_aligned_new.info['alignment_parameters'] = prov.parameters['alignment'] + image_aligned_new.save(verify_md5=False, no_archive=True) + # TODO: should we also load the aligned image's sources, PSF, and ZP? + + if image_aligned_ref.mjd < image_aligned_new.mjd: + ds.sub_image._aligned_images = [image_aligned_ref, image_aligned_new] + else: + ds.sub_image._aligned_images = [image_aligned_new, image_aligned_ref] + + if ds.sub_image is None: # no hit in the cache + ds = p.subtractor.run(ds, session) + ds.sub_image.save(verify_md5=False) # make sure it is also saved to archive + if use_cache: + output_path = copy_to_cache(ds.sub_image, cache_dir) + if output_path != sub_cache_path: + warnings.warn(f'cache path {sub_cache_path} does not match output path {output_path}') + + if use_cache: # save the aligned images to cache + for im in ds.sub_image.aligned_images: + im.save(no_archive=True) + copy_to_cache(im, cache_dir) + + ############ detecting to create a source list ############ + prov = Provenance( + code_version=code_version, + process='detection', + upstreams=[ds.sub_image.provenance], + parameters=p.detector.pars.get_critical_pars(), + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + cache_name = os.path.join(cache_dir, cache_sub_name + f'.sources_{prov.id[:6]}.npy.json') + if use_cache and os.path.isfile(cache_name): + SCLogger.debug('loading detections from cache. ') + ds.detections = copy_from_cache(SourceList, cache_dir, cache_name) + ds.detections.provenance = prov + ds.detections.image = ds.sub_image + ds.sub_image.sources = ds.detections + ds.detections.save(verify_md5=False) + else: # cannot find detections on cache + ds = p.detector.run(ds, session) + ds.detections.save(verify_md5=False) + if use_cache: + copy_to_cache(ds.detections, cache_dir, cache_name) + + ############ cutting to create cutouts ############ + prov = Provenance( + code_version=code_version, + process='cutting', + upstreams=[ds.detections.provenance], + parameters=p.cutter.pars.get_critical_pars(), + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + cache_name = os.path.join(cache_dir, cache_sub_name + f'.cutouts_{prov.id[:6]}.h5') + if use_cache and ( os.path.isfile(cache_name) ): + SCLogger.debug('loading cutouts from cache. ') + ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name) + ds.cutouts.provenance = prov + ds.cutouts.sources = ds.detections + ds.cutouts.load_all_co_data() # sources must be set first + ds.cutouts.save() # make sure to save to archive as well + else: # cannot find cutouts on cache + ds = p.cutter.run(ds, session) + ds.cutouts.save() + if use_cache: + copy_to_cache(ds.cutouts, cache_dir) + + ############ measuring to create measurements ############ + prov = Provenance( + code_version=code_version, + process='measuring', + upstreams=[ds.cutouts.provenance], + parameters=p.measurer.pars.get_critical_pars(), + is_testing=True, + ) + prov = session.merge(prov) + session.commit() + + cache_name = os.path.join(cache_dir, cache_sub_name + f'.measurements_{prov.id[:6]}.json') + + if use_cache and ( os.path.isfile(cache_name) ): + # note that the cache contains ALL the measurements, not only the good ones + SCLogger.debug('loading measurements from cache. ') + ds.all_measurements = copy_list_from_cache(Measurements, cache_dir, cache_name) + [setattr(m, 'provenance', prov) for m in ds.all_measurements] + [setattr(m, 'cutouts', ds.cutouts) for m in ds.all_measurements] + + ds.measurements = [] + for m in ds.all_measurements: + threshold_comparison = p.measurer.compare_measurement_to_thresholds(m) + if threshold_comparison != "delete": # all disqualifiers are below threshold + m.is_bad = threshold_comparison == "bad" + ds.measurements.append(m) + + [m.associate_object(session) for m in ds.measurements] # create or find an object for each measurement + # no need to save list because Measurements is not a FileOnDiskMixin! + else: # cannot find measurements on cache + ds = p.measurer.run(ds, session) + if use_cache: + copy_list_to_cache(ds.all_measurements, cache_dir, cache_name) # must provide filepath! + + ds.save_and_commit(session=session) + + return ds + + return make_datastore diff --git a/tests/fixtures/decam.py b/tests/fixtures/decam.py index c9925e73..89418fbc 100644 --- a/tests/fixtures/decam.py +++ b/tests/fixtures/decam.py @@ -3,7 +3,6 @@ import wget import yaml import shutil -import warnings import sqlalchemy as sa import numpy as np @@ -17,7 +16,6 @@ from models.provenance import Provenance from models.exposure import Exposure from models.image import Image -from models.source_list import SourceList from models.datafile import DataFile from models.reference import Reference @@ -25,16 +23,8 @@ from util.retrydownload import retry_download from util.logger import SCLogger -from util.cache import copy_to_cache, copy_list_to_cache, copy_from_cache, copy_list_from_cache - - -@pytest.fixture(scope='session') -def decam_cache_dir(cache_dir): - output = os.path.join(cache_dir, 'DECam') - if not os.path.isdir(output): - os.makedirs(output) - - yield output +from util.cache import copy_to_cache, copy_from_cache +from util.util import env_as_bool @pytest.fixture(scope='session') @@ -55,7 +45,7 @@ def decam_cache_dir(cache_dir): def decam_default_calibrators(cache_dir, data_dir): try: # try to get the calibrators from the cache folder - if not os.getenv( "LIMIT_CACHE_USAGE" ): + if not env_as_bool( "LIMIT_CACHE_USAGE" ): if os.path.isdir(os.path.join(cache_dir, 'DECam_default_calibrators')): shutil.copytree( os.path.join(cache_dir, 'DECam_default_calibrators'), @@ -73,7 +63,7 @@ def decam_default_calibrators(cache_dir, data_dir): decam._get_default_calibrator( 60000, sec, calibtype='linearity' ) # store the calibration files in the cache folder - if not os.getenv( "LIMIT_CACHE_USAGE" ): + if not env_as_bool( "LIMIT_CACHE_USAGE" ): if not os.path.isdir(os.path.join(cache_dir, 'DECam_default_calibrators')): os.makedirs(os.path.join(cache_dir, 'DECam_default_calibrators'), exist_ok=True) for folder in os.listdir(os.path.join(data_dir, 'DECam_default_calibrators')): @@ -188,7 +178,7 @@ def decam_filename(download_url, data_dir, decam_cache_dir): url = os.path.join(download_url, 'DECAM', base_name) if not os.path.isfile(filename): - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): wget.download( url=url, out=filename ) else: cachedfilename = os.path.join(decam_cache_dir, base_name) @@ -270,7 +260,8 @@ def decam_datastore( 'N1', cache_dir=decam_cache_dir, cache_base_name='115/c4d_20221104_074232_N1_g_Sci_NBXRIO', - save_original_image=True + overrides={'subtraction': {'refset': 'test_refset_decam'}}, + save_original_image=True, ) # This save is redundant, as the datastore_factory calls save_and_commit # However, I leave this here because it is a good test that calling it twice @@ -322,7 +313,7 @@ def decam_fits_image_filename(download_url, decam_cache_dir): yield filename - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): try: os.unlink( filepath ) except FileNotFoundError: @@ -341,7 +332,7 @@ def decam_fits_image_filename2(download_url, decam_cache_dir): yield filename - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): try: os.unlink( filepath ) except FileNotFoundError: @@ -349,8 +340,9 @@ def decam_fits_image_filename2(download_url, decam_cache_dir): @pytest.fixture -def decam_ref_datastore( code_version, download_url, decam_cache_dir, data_dir, datastore_factory ): +def decam_ref_datastore(code_version, download_url, decam_cache_dir, data_dir, datastore_factory, refmaker_factory): filebase = 'DECaPS-West_20220112.g.32' + maker = refmaker_factory('test_refset_decam', 'DECam') # I added this mirror so the tests will pass, and we should remove it once the decam image goes back up to NERSC # TODO: should we leave these as a mirror in case NERSC is down? @@ -362,7 +354,6 @@ def decam_ref_datastore( code_version, download_url, decam_cache_dir, data_dir, for ext in [ '.image.fits', '.weight.fits', '.flags.fits', '.image.yaml' ]: cache_path = os.path.join(decam_cache_dir, f'115/{filebase}{ext}') - fzpath = cache_path + '.fz' if os.path.isfile(cache_path): SCLogger.info( f"{cache_path} exists, not redownloading." ) else: # need to download! @@ -374,55 +365,51 @@ def decam_ref_datastore( code_version, download_url, decam_cache_dir, data_dir, if not ext.endswith('.yaml'): destination = os.path.join(data_dir, f'115/{filebase}{ext}') os.makedirs(os.path.dirname(destination), exist_ok=True) - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): # move it out of cache into the data directory shutil.move( cache_path, destination ) - else: + else: # copy but leave it in the cache for re-use shutil.copy2( cache_path, destination ) - yaml_path = os.path.join(decam_cache_dir, f'115/{filebase}.image.yaml') - - with open( yaml_path ) as ifp: - refyaml = yaml.safe_load( ifp ) - with SmartSession() as session: + maker.make_refset(session=session) code_version = session.merge(code_version) - prov = Provenance( - process='preprocessing', - code_version=code_version, - parameters={}, - upstreams=[], - is_testing=True, - ) - # check if this Image is already in the DB - existing = session.scalars( - sa.select(Image).where(Image.filepath == f'115/{filebase}') - ).first() - if existing is None: - image = Image(**refyaml) - else: - # overwrite the existing row data using the YAML - for key, value in refyaml.items(): - if ( - key not in ['id', 'image_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - image = existing # replace with the existing object + # prov = Provenance( + # process='preprocessing', + # code_version=code_version, + # parameters={}, + # upstreams=[], + # is_testing=True, + # ) + prov = maker.coadd_im_prov + + # the JSON file is generated by our cache system, not downloaded from the NERSC archive + json_path = os.path.join( decam_cache_dir, f'115/{filebase}.image.fits.json' ) + if not env_as_bool( "LIMIT_CACHE_USAGE" ) and os.path.isfile( json_path ): + image = copy_from_cache(Image, decam_cache_dir, json_path) + image.provenance = prov + image.save(verify_md5=False) # make sure to upload to archive as well + else: # no cache, must create a new image object + yaml_path = os.path.join(decam_cache_dir, f'115/{filebase}.image.yaml') + + with open( yaml_path ) as ifp: + refyaml = yaml.safe_load( ifp ) - image.provenance = prov - image.filepath = f'115/{filebase}' - image.is_coadd = True - image.save(verify_md5=False) # make sure to upload to archive as well + image = Image(**refyaml) + image.provenance = prov + image.filepath = f'115/{filebase}' + image.is_coadd = True + image.save() # make sure to upload to archive as well - if not os.getenv( "LIMIT_CACHE_USAGE" ): - copy_to_cache( image, decam_cache_dir ) + if not env_as_bool( "LIMIT_CACHE_USAGE" ): # save a copy of the image in the cache + copy_to_cache( image, decam_cache_dir ) - ds = datastore_factory(image, cache_dir=decam_cache_dir, cache_base_name=f'115/{filebase}') + # the datastore factory will load from cache or recreate all the other products + ds = datastore_factory(image, cache_dir=decam_cache_dir, cache_base_name=f'115/{filebase}') - for filename in image.get_fullpath(as_list=True): - assert os.path.isfile(filename) + for filename in image.get_fullpath(as_list=True): + assert os.path.isfile(filename) - ds.save_and_commit(session) + ds.save_and_commit(session) delete_list = [ ds.image, ds.sources, ds.psf, ds.wcs, ds.zp, ds.sub_image, ds.detections, ds.cutouts, ds.measurements @@ -442,26 +429,27 @@ def decam_ref_datastore( code_version, download_url, decam_cache_dir, data_dir, @pytest.fixture -def decam_reference(decam_ref_datastore): +def decam_reference(decam_ref_datastore, refmaker_factory): + maker = refmaker_factory('test_refset_decam', 'DECam') ds = decam_ref_datastore with SmartSession() as session: - prov = Provenance( - code_version=ds.image.provenance.code_version, - process='reference', - parameters={'test_parameter': 'test_value'}, - upstreams=[ - ds.image.provenance, - ds.sources.provenance, - ], - is_testing=True, - ) + maker.make_refset(session=session) + # prov = Provenance( + # code_version=ds.image.provenance.code_version, + # process='referencing', + # parameters=maker.pars.get_critical_pars(), + # upstreams=[ + # ds.image.provenance, + # ds.sources.provenance, + # ], + # is_testing=True, + # ) + prov = maker.refset.provenances[0] prov = session.merge(prov) ref = Reference() ref.image = ds.image ref.provenance = prov - ref.validity_start = Time(55000, format='mjd', scale='tai').isot - ref.validity_end = Time(65000, format='mjd', scale='tai').isot ref.section_id = ds.image.section_id ref.filter = ds.image.filter ref.target = ds.image.target @@ -478,10 +466,32 @@ def decam_reference(decam_ref_datastore): with SmartSession() as session: ref = session.merge(ref) if sa.inspect(ref).persistent: - session.delete(ref.provenance) # should also delete the reference image + session.delete(ref) session.commit() +@pytest.fixture(scope='session') +def decam_refset(refmaker_factory): + refmaker = refmaker_factory('test_refset_decam', 'DECam') + refmaker.pars.save_new_refs = True + + refmaker.make_refset() + + yield refmaker.refset + + # delete all the references and the refset + with SmartSession() as session: + refmaker.refset = session.merge(refmaker.refset) + for prov in refmaker.refset.provenances: + refs = session.scalars(sa.select(Reference).where(Reference.provenance_id == prov.id)).all() + for ref in refs: + session.delete(ref) + + session.delete(refmaker.refset) + + session.commit() + + @pytest.fixture def decam_subtraction(decam_datastore): return decam_datastore.sub_image diff --git a/tests/fixtures/pipeline_objects.py b/tests/fixtures/pipeline_objects.py index ef12a7a7..461bf150 100644 --- a/tests/fixtures/pipeline_objects.py +++ b/tests/fixtures/pipeline_objects.py @@ -1,27 +1,5 @@ -import os -import warnings -import shutil import pytest -import numpy as np - -import sqlalchemy as sa - -import sep - -from models.base import SmartSession, FileOnDiskMixin -from models.provenance import Provenance -from models.enums_and_bitflags import BitFlagConverter -from models.image import Image -from models.source_list import SourceList -from models.psf import PSF -from models.background import Background -from models.world_coordinates import WorldCoordinates -from models.zero_point import ZeroPoint -from models.cutouts import Cutouts -from models.measurements import Measurements - -from pipeline.data_store import DataStore from pipeline.preprocessing import Preprocessor from pipeline.detection import Detector from pipeline.backgrounding import Backgrounder @@ -32,11 +10,7 @@ from pipeline.cutting import Cutter from pipeline.measuring import Measurer from pipeline.top_level import Pipeline - -from util.logger import SCLogger -from util.cache import copy_to_cache, copy_list_to_cache, copy_from_cache, copy_list_from_cache - -from improc.bitmask_tools import make_saturated_flag +from pipeline.ref_maker import RefMaker @pytest.fixture(scope='session') @@ -265,6 +239,8 @@ def pipeline_factory( ): def make_pipeline(): p = Pipeline(**test_config.value('pipeline')) + p.pars.save_before_subtraction = False + p.pars.save_at_finish = False p.preprocessor = preprocessor_factory() p.extractor = extractor_factory() p.backgrounder = backgrounder_factory() @@ -302,6 +278,7 @@ def pipeline_for_tests(pipeline_factory): def coadd_pipeline_factory( coadder_factory, extractor_factory, + backgrounder_factory, astrometor_factory, photometor_factory, test_config, @@ -310,6 +287,7 @@ def make_pipeline(): p = CoaddPipeline(**test_config.value('pipeline')) p.coadder = coadder_factory() p.extractor = extractor_factory() + p.backgrounder = backgrounder_factory() p.astrometor = astrometor_factory() p.photometor = photometor_factory() @@ -336,616 +314,20 @@ def coadd_pipeline_for_tests(coadd_pipeline_factory): @pytest.fixture(scope='session') -def datastore_factory(data_dir, pipeline_factory): - """Provide a function that returns a datastore with all the products based on the given exposure and section ID. - - To use this data store in a test where new data is to be generated, - simply change the pipeline object's "test_parameter" value to a unique - new value, so the provenance will not match and the data will be regenerated. - - If "save_original_image" is True, then a copy of the image before - going through source extraction, WCS, etc. will be saved along side - the image, with ".image.fits.original" appended to the filename; - this path will be in ds.path_to_original_image. In this case, the - thing that calls this factory must delete that file when done. - - EXAMPLE - ------- - extractor.pars.test_parameter = uuid.uuid().hex - extractor.run(datastore) - assert extractor.has_recalculated is True - - """ - def make_datastore( - *args, - cache_dir=None, - cache_base_name=None, - session=None, - overrides={}, - augments={}, - bad_pixel_map=None, - save_original_image=False - ): - code_version = args[0].provenance.code_version - ds = DataStore(*args) # make a new datastore - - if ( cache_dir is not None ) and ( cache_base_name is not None ) and ( not os.getenv( "LIMIT_CACHE_USAGE" ) ): - ds.cache_base_name = os.path.join(cache_dir, cache_base_name) # save this for testing purposes - - p = pipeline_factory() - - # allow calling scope to override/augment parameters for any of the processing steps - p.override_parameters(**overrides) - p.augment_parameters(**augments) - - with SmartSession(session) as session: - code_version = session.merge(code_version) - - if ds.image is not None: # if starting from an externally provided Image, must merge it first - ds.image = ds.image.merge_all(session) - - ############ preprocessing to create image ############ - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( ds.image is None ) and ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - # check if preprocessed image is in cache - cache_name = cache_base_name + '.image.fits.json' - cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(cache_path): - SCLogger.debug('loading image from cache. ') - ds.image = copy_from_cache(Image, cache_dir, cache_name) - # assign the correct exposure to the object loaded from cache - if ds.exposure_id is not None: - ds.image.exposure_id = ds.exposure_id - if ds.exposure is not None: - ds.image.exposure = ds.exposure - - # Copy the original image from the cache if requested - if save_original_image: - ds.path_to_original_image = ds.image.get_fullpath()[0] + '.image.fits.original' - cache_path = os.path.join(cache_dir, ds.image.filepath + '.image.fits.original') - shutil.copy2( cache_path, ds.path_to_original_image ) - - # add the preprocessing steps from instrument (TODO: remove this as part of Issue #142) - # preprocessing_steps = ds.image.instrument_object.preprocessing_steps - # prep_pars = p.preprocessor.pars.get_critical_pars() - # prep_pars['preprocessing_steps'] = preprocessing_steps - - upstreams = [ds.exposure.provenance] if ds.exposure is not None else [] # images without exposure - prov = Provenance( - code_version=code_version, - process='preprocessing', - upstreams=upstreams, - parameters=p.preprocessor.pars.get_critical_pars(), - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - # if Image already exists on the database, use that instead of this one - existing = session.scalars(sa.select(Image).where(Image.filepath == ds.image.filepath)).first() - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.image).mapper.columns.keys(): - value = getattr(ds.image, key) - if ( - key not in ['id', 'image_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.image = existing # replace with the existing row - - ds.image.provenance = prov - - # make sure this is saved to the archive as well - ds.image.save(verify_md5=False) - - if ds.image is None: # make the preprocessed image - SCLogger.debug('making preprocessed image. ') - ds = p.preprocessor.run(ds, session) - ds.image.provenance.is_testing = True - if bad_pixel_map is not None: - ds.image.flags |= bad_pixel_map - if ds.image.weight is not None: - ds.image.weight[ds.image.flags.astype(bool)] = 0.0 - - # flag saturated pixels, too (TODO: is there a better way to get the saturation limit? ) - mask = make_saturated_flag(ds.image.data, ds.image.instrument_object.saturation_limit, iterations=2) - ds.image.flags |= (mask * 2 ** BitFlagConverter.convert('saturated')).astype(np.uint16) - - ds.image.save() - if not os.getenv( "LIMIT_CACHE_USAGE" ): - output_path = copy_to_cache(ds.image, cache_dir) - - if cache_dir is not None and cache_base_name is not None and output_path != cache_path: - warnings.warn(f'cache path {cache_path} does not match output path {output_path}') - elif cache_dir is not None and cache_base_name is None: - ds.cache_base_name = output_path - SCLogger.debug(f'Saving image to cache at: {output_path}') - - # In test_astro_cal, there's a routine that needs the original - # image before being processed through the rest of what this - # factory function does, so save it if requested - if save_original_image: - ds.path_to_original_image = ds.image.get_fullpath()[0] + '.image.fits.original' - shutil.copy2( ds.image.get_fullpath()[0], ds.path_to_original_image ) - if not os.getenv( "LIMIT_CACHE_USAGE" ): - shutil.copy2( ds.image.get_fullpath()[0], - os.path.join(cache_dir, ds.image.filepath + '.image.fits.original') ) - - # check if background was calculated - if ds.image.bkg_mean_estimate is None or ds.image.bkg_rms_estimate is None: - # Estimate the background rms with sep - boxsize = ds.image.instrument_object.background_box_size - filtsize = ds.image.instrument_object.background_filt_size - - # Dysfunctionality alert: sep requires a *float* image for the mask - # IEEE 32-bit floats have 23 bits in the mantissa, so they should - # be able to precisely represent a 16-bit integer mask image - # In any event, sep.Background uses >0 as "bad" - fmask = np.array(ds.image.flags, dtype=np.float32) - backgrounder = sep.Background(ds.image.data, mask=fmask, - bw=boxsize, bh=boxsize, fw=filtsize, fh=filtsize) - - ds.image.bkg_mean_estimate = backgrounder.globalback - ds.image.bkg_rms_estimate = backgrounder.globalrms - - # TODO: move the code below here up to above preprocessing, once we have reference sets - try: # check if this datastore can load a reference - # this is a hack to tell the datastore that the given image's provenance is the right one to use - ref = ds.get_reference(session=session) - ref_prov = ref.provenance - except ValueError as e: - if 'No reference image found' in str(e): - ref = None - # make a placeholder reference just to be able to make a provenance tree - # this doesn't matter in this case, because if there is no reference - # then the datastore is returned without a subtraction, so all the - # provenances that have the reference provenances as upstream will - # not even exist. - - # TODO: we really should be working in a state where there is a reference set - # that has one provenance attached to it, that exists before we start up - # the pipeline. Here we are doing the opposite: we first check if a specific - # reference exists, and only then chose the provenance based on the available ref. - # TODO: once we have a reference that is independent of the image, we can move this - # code that makes the prov_tree up to before preprocessing - ref_prov = Provenance( - process='reference', - code_version=code_version, - parameters={}, - upstreams=[], - is_testing=True, - ) - else: - raise e # if any other error comes up, raise it - - ############# extraction to create sources / PSF / BG / WCS / ZP ############# - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - # try to get the SourceList, PSF, BG, WCS and ZP from cache - prov = Provenance( - code_version=code_version, - process='extraction', - upstreams=[ds.image.provenance], - parameters=p.extractor.pars.get_critical_pars(), # the siblings will be loaded automatically - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - cache_name = f'{cache_base_name}.sources_{prov.id[:6]}.fits.json' - sources_cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(sources_cache_path): - SCLogger.debug('loading source list from cache. ') - ds.sources = copy_from_cache(SourceList, cache_dir, cache_name) - - # if SourceList already exists on the database, use that instead of this one - existing = session.scalars( - sa.select(SourceList).where(SourceList.filepath == ds.sources.filepath) - ).first() - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.sources).mapper.columns.keys(): - value = getattr(ds.sources, key) - if ( - key not in ['id', 'image_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.sources = existing # replace with the existing row - - ds.sources.provenance = prov - ds.sources.image = ds.image - - # make sure this is saved to the archive as well - ds.sources.save(verify_md5=False) - - # try to get the PSF from cache - cache_name = f'{cache_base_name}.psf_{prov.id[:6]}.fits.json' - psf_cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(psf_cache_path): - SCLogger.debug('loading PSF from cache. ') - ds.psf = copy_from_cache(PSF, cache_dir, cache_name) - - # if PSF already exists on the database, use that instead of this one - existing = session.scalars( - sa.select(PSF).where(PSF.filepath == ds.psf.filepath) - ).first() - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.psf).mapper.columns.keys(): - value = getattr(ds.psf, key) - if ( - key not in ['id', 'image_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.psf = existing # replace with the existing row - - ds.psf.provenance = prov - ds.psf.image = ds.image - - # make sure this is saved to the archive as well - ds.psf.save(verify_md5=False, overwrite=True) - - # try to get the background from cache - cache_name = f'{cache_base_name}.bg_{prov.id[:6]}.h5.json' - bg_cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(bg_cache_path): - SCLogger.debug('loading background from cache. ') - ds.bg = copy_from_cache(Background, cache_dir, cache_name) - - # if BG already exists on the database, use that instead of this one - existing = session.scalars( - sa.select(Background).where(Background.filepath == ds.bg.filepath) - ).first() - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.bg).mapper.columns.keys(): - value = getattr(ds.bg, key) - if ( - key not in ['id', 'image_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.bg = existing - - ds.bg.provenance = prov - ds.bg.image = ds.image - - # make sure this is saved to the archive as well - ds.bg.save(verify_md5=False, overwrite=True) - - # try to get the WCS from cache - cache_name = f'{cache_base_name}.wcs_{prov.id[:6]}.txt.json' - wcs_cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(wcs_cache_path): - SCLogger.debug('loading WCS from cache. ') - ds.wcs = copy_from_cache(WorldCoordinates, cache_dir, cache_name) - prov = session.merge(prov) - - # check if WCS already exists on the database - if ds.sources is not None: - existing = session.scalars( - sa.select(WorldCoordinates).where( - WorldCoordinates.sources_id == ds.sources.id, - WorldCoordinates.provenance_id == prov.id - ) - ).first() - else: - existing = None - - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.wcs).mapper.columns.keys(): - value = getattr(ds.wcs, key) - if ( - key not in ['id', 'sources_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.wcs = existing # replace with the existing row - - ds.wcs.provenance = prov - ds.wcs.sources = ds.sources - # make sure this is saved to the archive as well - ds.wcs.save(verify_md5=False, overwrite=True) - - # try to get the ZP from cache - cache_name = cache_base_name + '.zp.json' - zp_cache_path = os.path.join(cache_dir, cache_name) - if os.path.isfile(zp_cache_path): - SCLogger.debug('loading zero point from cache. ') - ds.zp = copy_from_cache(ZeroPoint, cache_dir, cache_name) - - # check if ZP already exists on the database - if ds.sources is not None: - existing = session.scalars( - sa.select(ZeroPoint).where( - ZeroPoint.sources_id == ds.sources.id, - ZeroPoint.provenance_id == prov.id - ) - ).first() - else: - existing = None - - if existing is not None: - # overwrite the existing row data using the JSON cache file - for key in sa.inspect(ds.zp).mapper.columns.keys(): - value = getattr(ds.zp, key) - if ( - key not in ['id', 'sources_id', 'created_at', 'modified'] and - value is not None - ): - setattr(existing, key, value) - ds.zp = existing # replace with the existing row - - ds.zp.provenance = prov - ds.zp.sources = ds.sources - - # if any data product is missing, must redo the extraction step - if ds.sources is None or ds.psf is None or ds.bg is None or ds.wcs is None or ds.zp is None: - SCLogger.debug('extracting sources. ') - ds = p.extractor.run(ds, session) - - ds.sources.save(overwrite=True) - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - output_path = copy_to_cache(ds.sources, cache_dir) - if cache_dir is not None and cache_base_name is not None and output_path != sources_cache_path: - warnings.warn(f'cache path {sources_cache_path} does not match output path {output_path}') - - ds.psf.save(overwrite=True) - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - output_path = copy_to_cache(ds.psf, cache_dir) - if cache_dir is not None and cache_base_name is not None and output_path != psf_cache_path: - warnings.warn(f'cache path {psf_cache_path} does not match output path {output_path}') - - SCLogger.debug('Running background estimation') - ds = p.backgrounder.run(ds, session) - - ds.bg.save(overwrite=True) - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - output_path = copy_to_cache(ds.bg, cache_dir) - if cache_dir is not None and cache_base_name is not None and output_path != bg_cache_path: - warnings.warn(f'cache path {bg_cache_path} does not match output path {output_path}') - - SCLogger.debug('Running astrometric calibration') - ds = p.astrometor.run(ds, session) - ds.wcs.save(overwrite=True) - if ((cache_dir is not None) and (cache_base_name is not None) and - (not os.getenv("LIMIT_CACHE_USAGE"))): - output_path = copy_to_cache(ds.wcs, cache_dir) - if output_path != wcs_cache_path: - warnings.warn(f'cache path {wcs_cache_path} does not match output path {output_path}') - - SCLogger.debug('Running photometric calibration') - ds = p.photometor.run(ds, session) - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( cache_dir is not None ) and ( cache_base_name is not None ) - ): - output_path = copy_to_cache(ds.zp, cache_dir, cache_name) - if output_path != zp_cache_path: - warnings.warn(f'cache path {zp_cache_path} does not match output path {output_path}') - - ds.save_and_commit(session=session) - if ref is None: - return ds # if no reference is found, simply return the datastore without the rest of the products - - # try to find the subtraction image in the cache - if cache_dir is not None: - prov = Provenance( - code_version=code_version, - process='subtraction', - upstreams=[ - ds.image.provenance, - ds.sources.provenance, - ref.image.provenance, - ref.sources.provenance, - ], - parameters=p.subtractor.pars.get_critical_pars(), - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - sub_im = Image.from_new_and_ref(ds.image, ref.image) - sub_im.provenance = prov - cache_sub_name = sub_im.invent_filepath() - cache_name = cache_sub_name + '.image.fits.json' - if os.path.isfile(os.path.join(cache_dir, cache_name)): - SCLogger.debug('loading subtraction image from cache. ') - ds.sub_image = copy_from_cache(Image, cache_dir, cache_name) - - ds.sub_image.provenance = prov - ds.sub_image.upstream_images.append(ref.image) - ds.sub_image.ref_image_id = ref.image_id - ds.sub_image.new_image = ds.image - ds.sub_image.save(verify_md5=False) # make sure it is also saved to archive - - # try to load the aligned images from cache - prov_aligned_ref = Provenance( - code_version=code_version, - parameters={ - 'method': 'swarp', - 'to_index': 'new', - 'max_arcsec_residual': 0.2, - 'crossid_radius': 2.0, - 'max_sources_to_use': 2000, - 'min_frac_matched': 0.1, - 'min_matched': 10, - }, - upstreams=[ - ds.image.provenance, - ds.sources.provenance, # this also includes the PSF's provenance - ds.wcs.provenance, - ds.ref_image.provenance, - ds.ref_image.sources.provenance, - ds.ref_image.wcs.provenance, - ds.ref_image.zp.provenance, - ], - process='alignment', - is_testing=True, - ) - # TODO: can we find a less "hacky" way to do this? - f = ref.image.invent_filepath() - f = f.replace('ComSci', 'Warped') # not sure if this or 'Sci' will be in the filename - f = f.replace('Sci', 'Warped') # in any case, replace it with 'Warped' - f = f[:-6] + prov_aligned_ref.id[:6] # replace the provenance ID - filename_aligned_ref = f - - prov_aligned_new = Provenance( - code_version=code_version, - parameters=prov_aligned_ref.parameters, - upstreams=[ - ds.image.provenance, - ds.sources.provenance, # this also includes provs for PSF, BG, WCS, ZP - ], - process='alignment', - is_testing=True, - ) - f = ds.sub_image.new_image.invent_filepath() - f = f.replace('ComSci', 'Warped') - f = f.replace('Sci', 'Warped') - f = f[:-6] + prov_aligned_new.id[:6] - filename_aligned_new = f - - cache_name_ref = filename_aligned_ref + '.fits.json' - cache_name_new = filename_aligned_new + '.fits.json' - if ( - os.path.isfile(os.path.join(cache_dir, cache_name_ref)) and - os.path.isfile(os.path.join(cache_dir, cache_name_new)) - ): - SCLogger.debug('loading aligned reference image from cache. ') - image_aligned_ref = copy_from_cache(Image, cache_dir, cache_name) - image_aligned_ref.provenance = prov_aligned_ref - image_aligned_ref.info['original_image_id'] = ds.ref_image_id - image_aligned_ref.info['original_image_filepath'] = ds.ref_image.filepath - image_aligned_ref.save(verify_md5=False, no_archive=True) - # TODO: should we also load the aligned image's sources, PSF, and ZP? - - SCLogger.debug('loading aligned new image from cache. ') - image_aligned_new = copy_from_cache(Image, cache_dir, cache_name) - image_aligned_new.provenance = prov_aligned_new - image_aligned_new.info['original_image_id'] = ds.image_id - image_aligned_new.info['original_image_filepath'] = ds.image.filepath - image_aligned_new.save(verify_md5=False, no_archive=True) - # TODO: should we also load the aligned image's sources, PSF, and ZP? - - if image_aligned_ref.mjd < image_aligned_new.mjd: - ds.sub_image._aligned_images = [image_aligned_ref, image_aligned_new] - else: - ds.sub_image._aligned_images = [image_aligned_new, image_aligned_ref] - - if ds.sub_image is None: # no hit in the cache - ds = p.subtractor.run(ds, session) - ds.sub_image.save(verify_md5=False) # make sure it is also saved to archive - if not os.getenv( "LIMIT_CACHE_USAGE" ): - copy_to_cache(ds.sub_image, cache_dir) - - # make sure that the aligned images get into the cache, too - if ( - ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - 'cache_name_ref' in locals() and - os.path.isfile(os.path.join(cache_dir, cache_name_ref)) and - 'cache_name_new' in locals() and - os.path.isfile(os.path.join(cache_dir, cache_name_new)) - ): - for im in ds.sub_image.aligned_images: - copy_to_cache(im, cache_dir) - - ############ detecting to create a source list ############ - prov = Provenance( - code_version=code_version, - process='detection', - upstreams=[ds.sub_image.provenance], - parameters=p.detector.pars.get_critical_pars(), - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - cache_name = os.path.join(cache_dir, cache_sub_name + f'.sources_{prov.id[:6]}.npy.json') - if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ): - SCLogger.debug('loading detections from cache. ') - ds.detections = copy_from_cache(SourceList, cache_dir, cache_name) - ds.detections.provenance = prov - ds.detections.image = ds.sub_image - ds.sub_image.sources = ds.detections - ds.detections.save(verify_md5=False) - else: # cannot find detections on cache - ds = p.detector.run(ds, session) - ds.detections.save(verify_md5=False) - if not os.getenv( "LIMIT_CACHE_USAGE" ): - copy_to_cache(ds.detections, cache_dir, cache_name) - - ############ cutting to create cutouts ############ - prov = Provenance( - code_version=code_version, - process='cutting', - upstreams=[ds.detections.provenance], - parameters=p.cutter.pars.get_critical_pars(), - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - cache_name = os.path.join(cache_dir, cache_sub_name + f'.cutouts_{prov.id[:6]}.h5') - if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ): - SCLogger.debug('loading cutouts from cache. ') - ds.cutouts = copy_from_cache(Cutouts, cache_dir, cache_name) - setattr(ds.cutouts, 'provenance', prov) - setattr(ds.cutouts, 'sources', ds.detections) - ds.cutouts.load_all_co_data() # sources must be set first - ds.cutouts.save() - else: # cannot find cutouts on cache - ds = p.cutter.run(ds, session) - ds.cutouts.save() - if not os.getenv( "LIMIT_CACHE_USAGE" ): - copy_to_cache(ds.cutouts, cache_dir) - - ############ measuring to create measurements ############ - prov = Provenance( - code_version=code_version, - process='measuring', - upstreams=[ds.cutouts.provenance], - parameters=p.measurer.pars.get_critical_pars(), - is_testing=True, - ) - prov = session.merge(prov) - session.commit() - - cache_name = os.path.join(cache_dir, cache_sub_name + f'.measurements_{prov.id[:6]}.json') - - if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_name) ): - # note that the cache contains ALL the measurements, not only the good ones - SCLogger.debug('loading measurements from cache. ') - ds.all_measurements = copy_list_from_cache(Measurements, cache_dir, cache_name) - [setattr(m, 'provenance', prov) for m in ds.all_measurements] - [setattr(m, 'cutouts', ds.cutouts) for m in ds.all_measurements] - - ds.measurements = [] - for m in ds.all_measurements: - threshold_comparison = p.measurer.compare_measurement_to_thresholds(m) - if threshold_comparison != "delete": # all disqualifiers are below threshold - m.is_bad = threshold_comparison == "bad" - ds.measurements.append(m) - - [m.associate_object(session) for m in ds.measurements] # create or find an object for each measurement - # no need to save list because Measurements is not a FileOnDiskMixin! - else: # cannot find measurements on cache - ds = p.measurer.run(ds, session) - copy_list_to_cache(ds.all_measurements, cache_dir, cache_name) # must provide filepath! - - ds.save_and_commit(session=session) - - return ds - - return make_datastore +def refmaker_factory(test_config, pipeline_factory, coadd_pipeline_factory): + + def make_refmaker(name, instrument): + maker = RefMaker(maker={'name': name, 'instruments': [instrument]}) + maker.pars._enforce_no_new_attrs = False + maker.pars.test_parameter = maker.pars.add_par( + 'test_parameter', 'test_value', str, 'parameter to define unique tests', critical=True + ) + maker.pars._enforce_no_new_attrs = True + maker.pipeline = pipeline_factory() + maker.pipeline.override_parameters(**test_config.value('referencing.pipeline')) + maker.coadd_pipeline = coadd_pipeline_factory() + maker.coadd_pipeline.override_parameters(**test_config.value('referencing.coaddition')) + + return maker + + return make_refmaker diff --git a/tests/fixtures/ptf.py b/tests/fixtures/ptf.py index 5374c04d..3c777447 100644 --- a/tests/fixtures/ptf.py +++ b/tests/fixtures/ptf.py @@ -25,13 +25,12 @@ from models.zero_point import ZeroPoint from models.reference import Reference -from pipeline.coaddition import CoaddPipeline - from improc.alignment import ImageAligner from util.retrydownload import retry_download from util.logger import SCLogger from util.cache import copy_to_cache, copy_list_to_cache, copy_from_cache, copy_list_from_cache +from util.util import env_as_bool @pytest.fixture(scope='session') @@ -51,7 +50,7 @@ def ptf_bad_pixel_map(download_url, data_dir, ptf_cache_dir): data_dir = os.path.join(data_dir, 'PTF_calibrators') data_path = os.path.join(data_dir, filename) - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): if not os.path.isfile( data_path ): os.makedirs( os.path.dirname( data_path ), exist_ok=True ) retry_download( url + filename, data_path ) @@ -110,7 +109,7 @@ def download_ptf_function(filename='PTF201104291667_2_o_45737_11.w.fits'): # url = f'https://portal.nersc.gov/project/m2218/pipeline/test_images/{filename}' url = os.path.join(download_url, 'PTF/10cwm', filename) - if os.getenv( "LIMIT_CACHE_USAGE" ): + if env_as_bool( "LIMIT_CACHE_USAGE" ): retry_download( url, destination ) if not os.path.isfile( destination ): raise FileNotFoundError( f"Can't read {destination}. It should have been downloaded!" ) @@ -174,7 +173,7 @@ def ptf_datastore(datastore_factory, ptf_exposure, ptf_ref, ptf_cache_dir, ptf_b 11, cache_dir=ptf_cache_dir, cache_base_name='187/PTF_20110429_040004_11_R_Sci_BNKEKA', - overrides={'extraction': {'threshold': 5}}, + overrides={'extraction': {'threshold': 5}, 'subtraction': {'refset': 'test_refset_ptf'}}, bad_pixel_map=ptf_bad_pixel_map, ) yield ds @@ -211,7 +210,7 @@ def ptf_images_factory(ptf_urls, ptf_downloader, datastore_factory, ptf_cache_di def factory(start_date='2009-04-04', end_date='2013-03-03', max_images=None): # see if any of the cache names were saved to a manifest file cache_names = {} - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and + if ( ( not env_as_bool( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(os.path.join(ptf_cache_dir, 'manifest.txt')) ) ): with open(os.path.join(ptf_cache_dir, 'manifest.txt')) as f: @@ -250,9 +249,10 @@ def factory(start_date='2009-04-04', end_date='2013-03-03', max_images=None): bad_pixel_map=ptf_bad_pixel_map, ) - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and - ( hasattr(ds, 'cache_base_name') ) and ( ds.cache_base_name is not None ) - ): + if ( + not env_as_bool( "LIMIT_CACHE_USAGE" ) and + hasattr(ds, 'cache_base_name') and ds.cache_base_name is not None + ): cache_name = ds.cache_base_name if cache_name.startswith(ptf_cache_dir): cache_name = cache_name[len(ptf_cache_dir) + 1:] @@ -323,7 +323,7 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): cache_dir = os.path.join(ptf_cache_dir, 'aligned_images') # try to load from cache - if ( ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and + if ( ( not env_as_bool( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(os.path.join(cache_dir, 'manifest.txt')) ) ): with open(os.path.join(cache_dir, 'manifest.txt')) as f: @@ -359,10 +359,12 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): image.save() filepath = copy_to_cache(image, cache_dir) if image.psf.filepath is None: # save only PSF objects that haven't been saved yet + image.psf.provenance = coadd_image.upstream_images[0].psf.provenance image.psf.save(overwrite=True) if image.bg.filepath is None: # save only Background objects that haven't been saved yet + image.bg.provenance = coadd_image.upstream_images[0].bg.provenance image.bg.save(overwrite=True) - if not os.getenv( "LIMIT_CACHE_USAGE" ): + if not env_as_bool( "LIMIT_CACHE_USAGE" ): copy_to_cache(image.psf, cache_dir) copy_to_cache(image.bg, cache_dir) copy_to_cache(image.zp, cache_dir, filepath=filepath[:-len('.image.fits.json')]+'.zp.json') @@ -370,7 +372,7 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): psf_paths.append(image.psf.filepath) bg_paths.append(image.bg.filepath) - if not os.getenv( "LIMIT_CACHE_USAGE" ): + if not env_as_bool( "LIMIT_CACHE_USAGE" ): os.makedirs(cache_dir, exist_ok=True) with open(os.path.join(cache_dir, 'manifest.txt'), 'w') as f: for filename, psf_path, bg_path in zip(filenames, psf_paths, bg_paths): @@ -383,7 +385,8 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): if 'output_images' in locals(): for image in output_images: image.psf.delete_from_disk_and_database() - image.delete_from_disk_and_database() + image.bg.delete_from_disk_and_database() + image.delete_from_disk_and_database(remove_downstreams=True) if 'coadd_image' in locals(): coadd_image.delete_from_disk_and_database() @@ -395,10 +398,6 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): action='ignore', message=r'.*DELETE statement on table .* expected to delete \d* row\(s\).*', ) - # warnings.filterwarnings( - # 'ignore', - # message=r".*Object of type .* not in session, .* operation along .* won't proceed.*" - # ) for image in ptf_reference_images: image = session.merge(image) image.exposure.delete_from_disk_and_database(commit=False, session=session, remove_downstreams=True) @@ -408,14 +407,15 @@ def ptf_aligned_images(request, ptf_cache_dir, data_dir, code_version): @pytest.fixture def ptf_ref( + refmaker_factory, ptf_reference_images, ptf_aligned_images, - coadd_pipeline_for_tests, ptf_cache_dir, data_dir, code_version ): - pipe = coadd_pipeline_for_tests + refmaker = refmaker_factory('test_ref_ptf', 'PTF') + pipe = refmaker.coadd_pipeline # build up the provenance tree with SmartSession() as session: @@ -450,7 +450,7 @@ def ptf_ref( ] filenames = [os.path.join(ptf_cache_dir, cache_base_name) + f'.{ext}.json' for ext in extensions] - if ( not os.getenv( "LIMIT_CACHE_USAGE" ) and + if ( not env_as_bool( "LIMIT_CACHE_USAGE" ) and all([os.path.isfile(filename) for filename in filenames]) ): # can load from cache # get the image: @@ -500,7 +500,7 @@ def ptf_ref( pipe.datastore.save_and_commit() coadd_image = pipe.datastore.image - if not os.getenv( "LIMIT_CACHE_USAGE" ): + if not env_as_bool( "LIMIT_CACHE_USAGE" ): # save all products into cache: copy_to_cache(pipe.datastore.image, ptf_cache_dir) copy_to_cache(pipe.datastore.sources, ptf_cache_dir) @@ -513,7 +513,7 @@ def ptf_ref( coadd_image = coadd_image.merge_all(session) ref = Reference(image=coadd_image) - ref.make_provenance() + ref.make_provenance(parameters=refmaker.pars.get_critical_pars()) ref.provenance.parameters['test_parameter'] = 'test_value' ref.provenance.is_testing = True ref.provenance.update_id() @@ -531,28 +531,91 @@ def ptf_ref( @pytest.fixture -def ptf_subtraction1(ptf_ref, ptf_supernova_images, subtractor, ptf_cache_dir): +def ptf_ref_offset(ptf_ref): + with SmartSession() as session: + offset_image = Image.copy_image(ptf_ref.image) + offset_image.ra_corner_00 -= 0.5 + offset_image.ra_corner_01 -= 0.5 + offset_image.ra_corner_10 -= 0.5 + offset_image.ra_corner_11 -= 0.5 + offset_image.filepath = ptf_ref.image.filepath + '_offset' + offset_image.provenance = ptf_ref.image.provenance + offset_image.md5sum = uuid.uuid4() # spoof this so we don't have to save to archive + + new_ref = Reference() + new_ref.image = offset_image + pars = ptf_ref.provenance.parameters.copy() + pars['test_parameter'] = uuid.uuid4().hex + prov = Provenance( + process='referencing', + parameters=pars, + upstreams=ptf_ref.provenance.upstreams, + code_version=ptf_ref.provenance.code_version, + is_testing=True, + ) + new_ref.provenance = prov + new_ref = session.merge(new_ref) + session.commit() + + yield new_ref + + new_ref.image.delete_from_disk_and_database() + + +@pytest.fixture(scope='session') +def ptf_refset(refmaker_factory): + refmaker = refmaker_factory('test_refset_ptf', 'PTF') + refmaker.pars.save_new_refs = True + + refmaker.make_refset() # this makes a refset without making any references + + yield refmaker.refset - cache_path = os.path.join(ptf_cache_dir, '187/PTF_20100216_075004_11_R_Diff_VXUBFA_u-7ogkop.image.fits.json') + # delete all the references and the refset + with SmartSession() as session: + for prov in refmaker.refset.provenances: + refs = session.scalars(sa.select(Reference).where(Reference.provenance_id == prov.id)).all() + for ref in refs: + session.delete(ref) + + session.delete(refmaker.refset) + + session.commit() + + +@pytest.fixture +def ptf_subtraction1(ptf_ref, ptf_supernova_images, subtractor, ptf_cache_dir): + subtractor.pars.refset = 'test_refset_ptf' + upstreams = [ + ptf_ref.image.provenance, + ptf_ref.image.sources.provenance, + ptf_supernova_images[0].provenance, + ptf_supernova_images[0].sources.provenance, + ] + prov = Provenance( + process='subtraction', + parameters=subtractor.pars.get_critical_pars(), + upstreams=upstreams, + code_version=ptf_ref.image.provenance.code_version, + is_testing=True, + ) + cache_path = os.path.join( + ptf_cache_dir, + f'187/PTF_20100216_075004_11_R_Diff_{prov.id[:6]}_u-iig7a2.image.fits.json' + ) - if ( not os.getenv( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_path) ): # try to load this from cache + if ( not env_as_bool( "LIMIT_CACHE_USAGE" ) ) and ( os.path.isfile(cache_path) ): # try to load this from cache im = copy_from_cache(Image, ptf_cache_dir, cache_path) im.upstream_images = [ptf_ref.image, ptf_supernova_images[0]] im.ref_image_id = ptf_ref.image.id - prov = Provenance( - process='subtraction', - parameters=subtractor.pars.get_critical_pars(), - upstreams=im.get_upstream_provenances(), - code_version=ptf_ref.image.provenance.code_version, - is_testing=True, - ) + im.provenance = prov else: # cannot find it on cache, need to produce it, using other fixtures ds = subtractor.run(ptf_supernova_images[0]) ds.sub_image.save() - if not os.getenv( "LIMIT_CACHE_USAGE" ) : + if not env_as_bool( "LIMIT_CACHE_USAGE" ) : copy_to_cache(ds.sub_image, ptf_cache_dir) im = ds.sub_image diff --git a/tests/fixtures/simulated.py b/tests/fixtures/simulated.py index 228a1b24..02641392 100644 --- a/tests/fixtures/simulated.py +++ b/tests/fixtures/simulated.py @@ -280,7 +280,7 @@ def sim_reference(provenance_preprocessing, provenance_extra): ref.image = ref_image ref.provenance = Provenance( code_version=provenance_extra.code_version, - process='reference', + process='referencing', parameters={'test_parameter': 'test_value'}, upstreams=[provenance_extra], is_testing=True, diff --git a/tests/improc/test_simulator.py b/tests/improc/test_simulator.py index 5aa13416..3759106e 100644 --- a/tests/improc/test_simulator.py +++ b/tests/improc/test_simulator.py @@ -10,12 +10,13 @@ from improc.simulator import Simulator, SimGalaxies, SimStreaks, SimCosmicRays from improc.sky_flat import sigma_clipping from util.logger import SCLogger +from util.util import env_as_bool # uncomment this to run the plotting tests interactively # os.environ['INTERACTIVE'] = '1' -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_make_star_field(blocking_plots): s = Simulator( image_size_x=256, star_number=1000, galaxy_number=0) s.make_image() @@ -30,7 +31,7 @@ def test_make_star_field(blocking_plots): plt.savefig(filename+'.pdf') -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_make_galaxy_field(blocking_plots): s = Simulator( image_size_x=256, star_number=0, galaxy_number=1000, galaxy_min_width=1, galaxy_min_flux=1000 ) t0 = time.time() @@ -159,7 +160,7 @@ def test_bleeding_pixels(blocking_plots): plt.show(block=True) -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_streak_images(blocking_plots): im = SimStreaks.make_streak_image(center_x=50.3, length=25) @@ -175,7 +176,7 @@ def test_streak_images(blocking_plots): plt.show(block=True) -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_track_images(blocking_plots): im = SimCosmicRays.make_track_image(center_x=50.3, length=25, energy=10) diff --git a/tests/improc/test_sky_flat.py b/tests/improc/test_sky_flat.py index 4dd9ece8..863c8d56 100644 --- a/tests/improc/test_sky_flat.py +++ b/tests/improc/test_sky_flat.py @@ -11,7 +11,7 @@ from util.logger import SCLogger -@pytest.mark.flaky(max_runs=3) +@pytest.mark.flaky(max_runs=6) @pytest.mark.parametrize("num_images", [10, 300]) def test_simple_sky_flat(num_images): clear_cache = False # cache the images from the simulator diff --git a/tests/improc/test_zogy.py b/tests/improc/test_zogy.py index 29ce4006..56b7e9db 100644 --- a/tests/improc/test_zogy.py +++ b/tests/improc/test_zogy.py @@ -10,7 +10,7 @@ from models.base import CODE_ROOT, safe_mkdir from improc.simulator import Simulator from improc.zogy import zogy_subtract -from util.logger import SCLogger +from util.util import env_as_bool imsize = 256 @@ -158,7 +158,7 @@ def test_subtraction_no_new_sources(): assert zogy_failures == 0 -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_subtraction_snr_histograms(blocking_plots): background = 5.0 seeing = 3.0 diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 46b23d1d..24da1e1c 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -5,8 +5,6 @@ import uuid import json -import sqlalchemy as sa - import numpy as np import pytest @@ -479,3 +477,50 @@ def test_fourcorners_sort_radec(): assert ras == [ -1.366, -0.366, 0.366, 1.366 ] assert decs == [ -0.366, 1.366, -1.366, 0.366 ] + +def test_four_corners_overlap_frac(): + dra = 0.75 + ddec = 0.375 + radec1 = [(10., -3.), (10., -45.), (10., -80.)] + + # TODO : add tests where things aren't perfectly square + for ra, dec in radec1: + cd = np.cos(dec * np.pi / 180.) + i1 = Image(ra=ra, dec=dec, + ra_corner_00=ra - dra / 2. / cd, + ra_corner_01=ra - dra / 2. / cd, + ra_corner_10=ra + dra / 2. / cd, + ra_corner_11=ra + dra / 2. / cd, + dec_corner_00=dec - ddec / 2., + dec_corner_10=dec - ddec / 2., + dec_corner_01=dec + ddec / 2., + dec_corner_11=dec + ddec / 2.) + for frac, offx, offy in [(1., 0., 0.), + (0.5, 0.5, 0.), + (0.5, -0.5, 0.), + (0.5, 0., 0.5), + (0.5, 0., -0.5), + (0.25, 0.5, 0.5), + (0.25, -0.5, 0.5), + (0.25, 0.5, -0.5), + (0.25, -0.5, -0.5), + (0., 1., 0.), + (0., -1., 0.), + (0., 1., 0.), + (0., -1., 0.), + (0., -1., -1.), + (0., 1., -1.)]: + ra2 = ra + offx * dra / cd + dec2 = dec + offy * ddec + i2 = Image(ra=ra2, dec=dec2, + ra_corner_00=ra2 - dra / 2. / cd, + ra_corner_01=ra2 - dra / 2. / cd, + ra_corner_10=ra2 + dra / 2. / cd, + ra_corner_11=ra2 + dra / 2. / cd, + dec_corner_00=dec2 - ddec / 2., + dec_corner_10=dec2 - ddec / 2., + dec_corner_01=dec2 + ddec / 2., + dec_corner_11=dec2 + ddec / 2.) + assert FourCorners.get_overlap_frac(i1, i2) == pytest.approx(frac, abs=0.01) + + diff --git a/tests/models/test_cutouts.py b/tests/models/test_cutouts.py index 2f2d41b3..afd48c3d 100644 --- a/tests/models/test_cutouts.py +++ b/tests/models/test_cutouts.py @@ -10,6 +10,7 @@ from models.base import SmartSession from models.cutouts import Cutouts + def test_make_save_load_cutouts(decam_detection_list, cutter): try: cutter.pars.test_parameter = uuid.uuid4().hex @@ -51,12 +52,11 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): assert np.array_equal(co_subdict.get(f'{im}_{att}'), file[subdict_key][f'{im}_{att}']) - # load a cutouts from file and compare c2 = Cutouts() c2.filepath = ds.cutouts.filepath c2.sources = ds.cutouts.sources # necessary for co_dict - c2.load_all_co_data() # explicitly load co_dict + c2.load_all_co_data() # explicitly load co_dict co_subdict2 = c2.co_dict[subdict_key] @@ -65,11 +65,11 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): assert np.array_equal(co_subdict.get(f'{im}_{att}'), co_subdict2.get(f'{im}_{att}')) - assert c2.bitflag == 0 # should not load all column data from file + assert c2.bitflag == 0 # should not load all column data from file # change the value of one of the arrays ds.cutouts.co_dict[subdict_key]['sub_data'][0, 0] = 100 - co_subdict2['sub_data'][0, 0] = 100 # for comparison later + co_subdict2['sub_data'][0, 0] = 100 # for comparison later # make sure we can re-save ds.cutouts.save() @@ -77,7 +77,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): with h5py.File(ds.cutouts.get_fullpath(), 'r') as file: assert np.array_equal(ds.cutouts.co_dict[subdict_key]['sub_data'], file[subdict_key]['sub_data']) - assert file[subdict_key]['sub_data'][0, 0] == 100 # change has propagated + assert file[subdict_key]['sub_data'][0, 0] == 100 # change has propagated # check that we can add the cutouts to the database with SmartSession() as session: @@ -103,7 +103,6 @@ def test_make_save_load_cutouts(decam_detection_list, cutter): assert np.array_equal(co_subdict.get(f'{im}_{att}'), co_subdict2.get(f'{im}_{att}')) - finally: if 'ds' in locals() and ds.cutouts is not None: ds.cutouts.delete_from_disk_and_database() diff --git a/tests/models/test_decam.py b/tests/models/test_decam.py index 1c0c2ed8..2f4b7017 100644 --- a/tests/models/test_decam.py +++ b/tests/models/test_decam.py @@ -6,7 +6,6 @@ import pytest import numpy as np -import sqlalchemy as sa from astropy.io import fits @@ -18,8 +17,10 @@ from models.image import Image from models.instrument import Instrument from models.decam import DECam + import util.radec from util.logger import SCLogger +from util.util import env_as_bool def test_decam_exposure(decam_filename): @@ -118,7 +119,7 @@ def test_image_from_decam_exposure(decam_filename, provenance_base, data_dir): # guidance for how to do things, do *not* write code that mucks about # with the _frame member of one of those objects; that's internal state # not intended for external consumption. -@pytest.mark.skipif( os.getenv('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) +@pytest.mark.skipif( env_as_bool('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) def test_decam_search_noirlab( decam_reduced_origin_exposures ): origloglevel = SCLogger.get().getEffectiveLevel() try: @@ -161,7 +162,7 @@ def test_decam_search_noirlab( decam_reduced_origin_exposures ): SCLogger.setLevel( origloglevel ) -@pytest.mark.skipif( os.getenv('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) +@pytest.mark.skipif( env_as_bool('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) def test_decam_download_origin_exposure( decam_reduced_origin_exposures, cache_dir ): assert all( [ row.proc_type == 'instcal' for i, row in decam_reduced_origin_exposures._frame.iterrows() ] ) try: @@ -201,12 +202,14 @@ def test_decam_download_origin_exposure( decam_reduced_origin_exposures, cache_d md5.update( ifp.read() ) assert md5.hexdigest() == decam_reduced_origin_exposures._frame.loc[ dex, extname ].md5sum - finally: - # Don't clean up for efficiency of rerunning tests. - pass + finally: # cleanup + for d in downloaded: + for path in d.values(): + if os.path.isfile( path ): + os.unlink( path ) -@pytest.mark.skipif( os.getenv('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) +@pytest.mark.skipif( env_as_bool('SKIP_NOIRLAB_DOWNLOADS'), reason="SKIP_NOIRLAB_DOWNLOADS is set" ) def test_decam_download_and_commit_exposure( code_version, decam_raw_origin_exposures, cache_dir, data_dir, test_config, archive ): @@ -271,15 +274,11 @@ def test_decam_download_and_commit_exposure( path = os.path.join(data_dir, d['exposure'].name) if os.path.isfile(path): os.unlink(path) - - if 'downloaded' in locals(): - for d in downloaded: - path = os.path.join(data_dir, d['exposure'].name) - if os.path.isfile(path): - os.unlink(path) + if os.path.isfile(d['exposure']): + os.unlink(d['exposure']) -@pytest.mark.skipif( os.getenv('RUN_SLOW_TESTS') is None, reason="Set RUN_SLOW_TESTS to run this test" ) +@pytest.mark.skipif( not env_as_bool('RUN_SLOW_TESTS'), reason="Set RUN_SLOW_TESTS to run this test" ) def test_get_default_calibrators( decam_default_calibrators ): sections, filters = decam_default_calibrators decam = get_instrument_instance( 'DECam' ) diff --git a/tests/models/test_image.py b/tests/models/test_image.py index d2d2cfca..5c76b99c 100644 --- a/tests/models/test_image.py +++ b/tests/models/test_image.py @@ -429,97 +429,6 @@ def test_image_enum_values(sim_image1): os.rmdir(folder) -def test_image_upstreams_downstreams(sim_image1, sim_reference, provenance_extra, data_dir): - with SmartSession() as session: - sim_image1 = sim_image1.merge_all(session) - sim_reference = sim_reference.merge_all(session) - - # make sure the new image matches the reference in all these attributes - sim_image1.filter = sim_reference.filter - sim_image1.target = sim_reference.target - sim_image1.section_id = sim_reference.section_id - - new = Image.from_new_and_ref(sim_image1, sim_reference.image) - new.provenance = session.merge(provenance_extra) - - # save and delete at the end - cleanup = ImageCleanup.save_image(new) - - session.add(new) - session.commit() - - # new make sure a new session can find all the upstreams/downstreams - with SmartSession() as session: - sim_image1 = sim_image1.merge_all(session) - sim_reference = sim_reference.merge_all(session) - new = new.merge_all(session) - - # check the upstreams/downstreams for the new image - upstream_ids = [u.id for u in new.get_upstreams(session=session)] - assert sim_image1.id in upstream_ids - assert sim_reference.image_id in upstream_ids - downstream_ids = [d.id for d in new.get_downstreams(session=session)] - assert len(downstream_ids) == 0 - - upstream_ids = [u.id for u in sim_image1.get_upstreams(session=session)] - assert [sim_image1.exposure_id] == upstream_ids - downstream_ids = [d.id for d in sim_image1.get_downstreams(session=session)] - assert [new.id] == downstream_ids # should be the only downstream - - # check the upstreams/downstreams for the reference image - upstreams = sim_reference.image.get_upstreams(session=session) - assert len(upstreams) == 5 # was made of five images - assert all([isinstance(u, Image) for u in upstreams]) - source_images_ids = [im.id for im in sim_reference.image.upstream_images] - upstream_ids = [u.id for u in upstreams] - assert set(upstream_ids) == set(source_images_ids) - downstream_ids = [d.id for d in sim_reference.image.get_downstreams(session=session)] - assert [new.id] == downstream_ids # should be the only downstream - - # test for the Image.downstream relationship - assert len(upstreams[0].downstream_images) == 1 - assert upstreams[0].downstream_images == [sim_reference.image] - assert len(upstreams[1].downstream_images) == 1 - assert upstreams[1].downstream_images == [sim_reference.image] - - assert len(sim_image1.downstream_images) == 1 - assert sim_image1.downstream_images == [new] - - assert len(sim_reference.image.downstream_images) == 1 - assert sim_reference.image.downstream_images == [new] - - assert len(new.downstream_images) == 0 - - # add a second "new" image using one of the reference's upstreams instead of the reference - new2 = Image.from_new_and_ref(sim_image1, upstreams[0]) - new2.provenance = session.merge(provenance_extra) - new2.mjd += 1 # make sure this image has a later MJD, so it comes out later on the downstream list! - - # save and delete at the end - cleanup2 = ImageCleanup.save_image(new2) - - session.add(new2) - session.commit() - - session.refresh(upstreams[0]) - assert len(upstreams[0].downstream_images) == 2 - assert set(upstreams[0].downstream_images) == set([sim_reference.image, new2]) - - session.refresh(upstreams[1]) - assert len(upstreams[1].downstream_images) == 1 - assert upstreams[1].downstream_images == [sim_reference.image] - - session.refresh(sim_image1) - assert len(sim_image1.downstream_images) == 2 - assert set(sim_image1.downstream_images) == set([new, new2]) - - session.refresh(sim_reference.image) - assert len(sim_reference.image.downstream_images) == 1 - assert sim_reference.image.downstream_images == [new] - - assert len(new2.downstream_images) == 0 - - def test_image_preproc_bitflag( sim_image1 ): with SmartSession() as session: @@ -533,479 +442,35 @@ def test_image_preproc_bitflag( sim_image1 ): im.preproc_bitflag |= string_to_bitflag( 'flat, overscan', image_preprocessing_inverse ) assert im.preproc_bitflag == string_to_bitflag( 'overscan, zero, flat', image_preprocessing_inverse ) - q = ( session.query( Image.filepath ) - .filter( Image.preproc_bitflag.op('&')(string_to_bitflag('zero', image_preprocessing_inverse) ) - != 0 ) ) - assert (im.filepath,) in q.all() - q = ( session.query( Image.filepath ) - .filter( Image.preproc_bitflag.op('&')(string_to_bitflag('zero,flat', image_preprocessing_inverse) ) - !=0 ) ) - assert (im.filepath,) in q.all() - q = ( session.query( Image.filepath ) - .filter( Image.preproc_bitflag.op('&')(string_to_bitflag('zero, flat', image_preprocessing_inverse ) ) - == string_to_bitflag( 'flat, zero', image_preprocessing_inverse ) ) ) - assert (im.filepath,) in q.all() - q = ( session.query( Image.filepath ) - .filter( Image.preproc_bitflag.op('&')(string_to_bitflag('fringe', image_preprocessing_inverse) ) - !=0 ) ) - assert (im.filepath,) not in q.all() - q = ( session.query( Image.filepath ) - .filter( Image.preproc_bitflag.op('&')(string_to_bitflag('fringe, overscan', - image_preprocessing_inverse) ) - == string_to_bitflag( 'overscan, fringe', image_preprocessing_inverse ) ) ) - assert q.count() == 0 - - -def test_image_badness(sim_image1): - - with SmartSession() as session: - sim_image1 = session.merge(sim_image1) - - # this is not a legit "badness" keyword... - with pytest.raises(ValueError, match='Keyword "foo" not recognized'): - sim_image1.badness = 'foo' - - # this is a legit keyword, but for cutouts, not for images - with pytest.raises(ValueError, match='Keyword "cosmic ray" not recognized'): - sim_image1.badness = 'cosmic ray' - - # this is a legit keyword, but for images, using no space and no capitalization - sim_image1.badness = 'brightsky' - - # retrieving this keyword, we do get it capitalized and with a space: - assert sim_image1.badness == 'bright sky' - assert sim_image1.bitflag == 2 ** 5 # the bright sky bit is number 5 - - # what happens when we add a second keyword? - sim_image1.badness = 'Bright_sky, Banding' # try this with capitalization and underscores - assert sim_image1.bitflag == 2 ** 5 + 2 ** 1 # the bright sky bit is number 5, banding is number 1 - assert sim_image1.badness == 'banding, bright sky' - - # now add a third keyword, but on the Exposure - sim_image1.exposure.badness = 'saturation' - session.add(sim_image1) - session.commit() - - # a manual way to propagate bitflags downstream - sim_image1.exposure.update_downstream_badness(session=session) # make sure the downstreams get the new badness - session.commit() - assert sim_image1.bitflag == 2 ** 5 + 2 ** 3 + 2 ** 1 # saturation bit is 3 - assert sim_image1.badness == 'banding, saturation, bright sky' - - # adding the same keyword on the exposure and the image makes no difference - sim_image1.exposure.badness = 'Banding' - sim_image1.exposure.update_downstream_badness(session=session) # make sure the downstreams get the new badness - session.commit() - assert sim_image1.bitflag == 2 ** 5 + 2 ** 1 - assert sim_image1.badness == 'banding, bright sky' - - # try appending keywords to the image - sim_image1.append_badness('shaking') - assert sim_image1.bitflag == 2 ** 5 + 2 ** 2 + 2 ** 1 # shaking bit is 2 - assert sim_image1.badness == 'banding, shaking, bright sky' - - -def test_multiple_images_badness( - sim_image1, - sim_image2, - sim_image3, - sim_image5, - sim_image6, - provenance_extra -): - try: - with SmartSession() as session: - sim_image1 = session.merge(sim_image1) - sim_image2 = session.merge(sim_image2) - sim_image3 = session.merge(sim_image3) - sim_image5 = session.merge(sim_image5) - sim_image6 = session.merge(sim_image6) - - images = [sim_image1, sim_image2, sim_image3, sim_image5, sim_image6] - cleanups = [] - filter = 'g' - target = str(uuid.uuid4()) - project = 'test project' - for im in images: - im.filter = filter - im.target = target - im.project = project - session.add(im) - - session.commit() - - # the image itself is marked bad because of bright sky - sim_image2.badness = 'BrightSky' - assert sim_image2.badness == 'bright sky' - assert sim_image2.bitflag == 2 ** 5 - session.commit() - - # note that this image is not directly bad, but the exposure has banding - sim_image3.exposure.badness = 'banding' - sim_image3.exposure.update_downstream_badness(session=session) - session.commit() - - assert sim_image3.badness == 'banding' - assert sim_image1._bitflag == 0 # the exposure is bad! - assert sim_image3.bitflag == 2 ** 1 - session.commit() - - # find the images that are good vs bad - good_images = session.scalars(sa.select(Image).where(Image.bitflag == 0)).all() - assert sim_image1.id in [i.id for i in good_images] - - bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() - assert sim_image2.id in [i.id for i in bad_images] - assert sim_image3.id in [i.id for i in bad_images] - - # make an image from the two bad exposures using subtraction - - sim_image4 = Image.from_new_and_ref(sim_image3, sim_image2) - sim_image4.provenance = provenance_extra - sim_image4.provenance.upstreams = sim_image4.get_upstream_provenances() - cleanups.append(ImageCleanup.save_image(sim_image4)) - sim_image4 = session.merge(sim_image4) - images.append(sim_image4) - session.commit() - - assert sim_image4.id is not None - assert sim_image4.ref_image == sim_image2 - assert sim_image4.new_image == sim_image3 - - # check that badness is loaded correctly from both parents - assert sim_image4.badness == 'banding, bright sky' - assert sim_image4._bitflag == 0 # the image itself is not flagged - assert sim_image4.bitflag == 2 ** 1 + 2 ** 5 - - # check that filtering on this value gives the right bitflag - bad_images = session.scalars(sa.select(Image).where(Image.bitflag == 2 ** 1 + 2 ** 5)).all() - assert sim_image4.id in [i.id for i in bad_images] - assert sim_image3.id not in [i.id for i in bad_images] - assert sim_image2.id not in [i.id for i in bad_images] - - # check that adding a badness on the image itself is added to the total badness - sim_image4.badness = 'saturation' - session.add(sim_image4) - session.commit() - assert sim_image4.badness == 'banding, saturation, bright sky' - assert sim_image4._bitflag == 2 ** 3 # only this bit is from the image itself - - # make a new subtraction: - sim_image7 = Image.from_ref_and_new(sim_image6, sim_image5) - sim_image7.provenance = provenance_extra - cleanups.append(ImageCleanup.save_image(sim_image7)) - sim_image7 = session.merge(sim_image7) - images.append(sim_image7) - session.commit() - - # check that the new subtraction is not flagged - assert sim_image7.badness == '' - assert sim_image7._bitflag == 0 - assert sim_image7.bitflag == 0 - - good_images = session.scalars(sa.select(Image).where(Image.bitflag == 0)).all() - assert sim_image5.id in [i.id for i in good_images] - assert sim_image5.id in [i.id for i in good_images] - assert sim_image7.id in [i.id for i in good_images] - - bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() - assert sim_image5.id not in [i.id for i in bad_images] - assert sim_image6.id not in [i.id for i in bad_images] - assert sim_image7.id not in [i.id for i in bad_images] - - # let's try to coadd an image based on some good and bad images - # as a reminder, sim_image2 has bright sky (5), - # sim_image3's exposure has banding (1), while - # sim_image4 has saturation (3). - - # make a coadded image (without including the subtraction sim_image4): - sim_image8 = Image.from_images([sim_image1, sim_image2, sim_image3, sim_image5, sim_image6]) - sim_image8.provenance = provenance_extra - cleanups.append(ImageCleanup.save_image(sim_image8)) - images.append(sim_image8) - sim_image8 = session.merge(sim_image8) - session.commit() - - assert sim_image8.badness == 'banding, bright sky' - assert sim_image8.bitflag == 2 ** 1 + 2 ** 5 - - # does this work in queries (i.e., using the bitflag hybrid expression)? - bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() - assert sim_image8.id in [i.id for i in bad_images] - bad_coadd = session.scalars(sa.select(Image).where(Image.bitflag == 2 ** 1 + 2 ** 5)).all() - assert sim_image8.id in [i.id for i in bad_coadd] - - # get rid of this coadd to make a new one - sim_image8.delete_from_disk_and_database(session=session) - cleanups.pop() - images.pop() - - # now let's add the subtraction image to the coadd: - # make a coadded image (now including the subtraction sim_image4): - sim_image8 = Image.from_images([sim_image1, sim_image2, sim_image3, sim_image4, sim_image5, sim_image6]) - sim_image8.provenance = provenance_extra - cleanups.append(ImageCleanup.save_image(sim_image8)) - sim_image8 = session.merge(sim_image8) - images.append(sim_image8) - session.commit() - - assert sim_image8.badness == 'banding, saturation, bright sky' - assert sim_image8.bitflag == 2 ** 1 + 2 ** 3 + 2 ** 5 # this should be 42 - - # does this work in queries (i.e., using the bitflag hybrid expression)? - bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() - assert sim_image8.id in [i.id for i in bad_images] - bad_coadd = session.scalars(sa.select(Image).where(Image.bitflag == 42)).all() - assert sim_image8.id in [i.id for i in bad_coadd] - - # try to add some badness to one of the underlying exposures - sim_image1.exposure.badness = 'shaking' - session.add(sim_image1) - sim_image1.exposure.update_downstream_badness(session=session) - session.commit() - - assert 'shaking' in sim_image1.badness - assert 'shaking' in sim_image8.badness - - finally: # cleanup - with SmartSession() as session: - session.autoflush = False - for im in images: - im = im.merge_all(session) - exp = im.exposure - im.delete_from_disk_and_database(session=session, commit=False) - - if exp is not None and sa.inspect(exp).persistent: - session.delete(exp) - - session.commit() - - -def test_image_coordinates(): - image = Image('coordinates.fits', ra=None, dec=None, nofile=True) - assert image.ecllat is None - assert image.ecllon is None - assert image.gallat is None - assert image.gallon is None - - image = Image('coordinates.fits', ra=123.4, dec=None, nofile=True) - assert image.ecllat is None - assert image.ecllon is None - assert image.gallat is None - assert image.gallon is None - - image = Image('coordinates.fits', ra=123.4, dec=56.78, nofile=True) - assert abs(image.ecllat - 35.846) < 0.01 - assert abs(image.ecllon - 111.838) < 0.01 - assert abs(image.gallat - 33.542) < 0.01 - assert abs(image.gallon - 160.922) < 0.01 - - -def test_image_cone_search( provenance_base ): - with SmartSession() as session: - image1 = None - image2 = None - image3 = None - image4 = None - try: - kwargs = { 'format': 'fits', - 'exp_time': 60.48, - 'section_id': 'x', - 'project': 'x', - 'target': 'x', - 'instrument': 'DemoInstrument', - 'telescope': 'x', - 'filter': 'r', - 'ra_corner_00': 0, - 'ra_corner_01': 0, - 'ra_corner_10': 0, - 'ra_corner_11': 0, - 'dec_corner_00': 0, - 'dec_corner_01': 0, - 'dec_corner_10': 0, - 'dec_corner_11': 0, - } - image1 = Image(ra=120., dec=10., provenance=provenance_base, **kwargs ) - image1.mjd = np.random.uniform(0, 1) + 60000 - image1.end_mjd = image1.mjd + 0.007 - clean1 = ImageCleanup.save_image( image1 ) - - image2 = Image(ra=120.0002, dec=9.9998, provenance=provenance_base, **kwargs ) - image2.mjd = np.random.uniform(0, 1) + 60000 - image2.end_mjd = image2.mjd + 0.007 - clean2 = ImageCleanup.save_image( image2 ) - - image3 = Image(ra=120.0005, dec=10., provenance=provenance_base, **kwargs ) - image3.mjd = np.random.uniform(0, 1) + 60000 - image3.end_mjd = image3.mjd + 0.007 - clean3 = ImageCleanup.save_image( image3 ) - - image4 = Image(ra=60., dec=0., provenance=provenance_base, **kwargs ) - image4.mjd = np.random.uniform(0, 1) + 60000 - image4.end_mjd = image4.mjd + 0.007 - clean4 = ImageCleanup.save_image( image4 ) - - session.add( image1 ) - session.add( image2 ) - session.add( image3 ) - session.add( image4 ) - - sought = session.query( Image ).filter( Image.cone_search(120., 10., rad=1.02) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id }.issubset( soughtids ) - assert len( { image3.id, image4.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.cone_search(120., 10., rad=2.) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id, image3.id }.issubset( soughtids ) - assert len( { image4.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.cone_search(120., 10., 0.017, radunit='arcmin') ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id }.issubset( soughtids ) - assert len( { image3.id, image4.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.cone_search(120., 10., 0.0002833, radunit='degrees') ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id }.issubset( soughtids ) - assert len( { image3.id, image4.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.cone_search(120., 10., 4.9451e-6, radunit='radians') ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id }.issubset( soughtids ) - assert len( { image3.id, image4.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.cone_search(60, -10, 1.) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert len( { image1.id, image2.id, image3.id, image4.id } & soughtids ) == 0 - - with pytest.raises( ValueError, match='.*unknown radius unit' ): - sought = Image.cone_search( 0., 0., 1., 'undefined_unit' ) - finally: - for i in [ image1, image2, image3, image4 ]: - if ( i is not None ) and sa.inspect( i ).persistent: - session.delete( i ) - session.commit() - - -# Really, we should also do some speed tests, but that -# is outside the scope of the always-run tests. -def test_four_corners( provenance_base ): - - with SmartSession() as session: - image1 = None - image2 = None - image3 = None - image4 = None - try: - kwargs = { 'format': 'fits', - 'exp_time': 60.48, - 'section_id': 'x', - 'project': 'x', - 'target': 'x', - 'instrument': 'DemoInstrument', - 'telescope': 'x', - 'filter': 'r', - } - # RA numbers are made ugly from cos(dec). - # image1: centered on 120, 40, square to the sky - image1 = Image( ra=120, dec=40., - ra_corner_00=119.86945927, ra_corner_01=119.86945927, - ra_corner_10=120.13054073, ra_corner_11=120.13054073, - dec_corner_00=39.9, dec_corner_01=40.1, dec_corner_10=39.9, dec_corner_11=40.1, - provenance=provenance_base, nofile=True, **kwargs ) - image1.mjd = np.random.uniform(0, 1) + 60000 - image1.end_mjd = image1.mjd + 0.007 - clean1 = ImageCleanup.save_image( image1 ) - - # image2: centered on 120, 40, at a 45° angle - image2 = Image( ra=120, dec=40., - ra_corner_00=119.81538753, ra_corner_01=120, ra_corner_11=120.18461247, ra_corner_10=120, - dec_corner_00=40, dec_corner_01=40.14142136, dec_corner_11=40, dec_corner_10=39.85857864, - provenance=provenance_base, nofile=True, **kwargs ) - image2.mjd = np.random.uniform(0, 1) + 60000 - image2.end_mjd = image2.mjd + 0.007 - clean2 = ImageCleanup.save_image( image2 ) - - # image3: centered offset by (0.025, 0.025) linear arcsec from 120, 40, square on sky - image3 = Image( ra=120.03264714, dec=40.025, - ra_corner_00=119.90210641, ra_corner_01=119.90210641, - ra_corner_10=120.16318787, ra_corner_11=120.16318787, - dec_corner_00=39.975, dec_corner_01=40.125, dec_corner_10=39.975, dec_corner_11=40.125, - provenance=provenance_base, nofile=True, **kwargs ) - image3.mjd = np.random.uniform(0, 1) + 60000 - image3.end_mjd = image3.mjd + 0.007 - clean3 = ImageCleanup.save_image( image3 ) - - # imagepoint and imagefar are used to test Image.containing and Image.find_containing, - # as Image is the only example of a SpatiallyIndexed thing we have so far. - # The corners don't matter for these given how they'll be used. - imagepoint = Image( ra=119.88, dec=39.95, - ra_corner_00=-.001, ra_corner_01=0.001, ra_corner_10=-0.001, - ra_corner_11=0.001, dec_corner_00=0, dec_corner_01=0, dec_corner_10=0, dec_corner_11=0, - provenance=provenance_base, nofile=True, **kwargs ) - imagepoint.mjd = np.random.uniform(0, 1) + 60000 - imagepoint.end_mjd = imagepoint.mjd + 0.007 - clearpoint = ImageCleanup.save_image( imagepoint ) - - imagefar = Image( ra=30, dec=-10, - ra_corner_00=0, ra_corner_01=0, ra_corner_10=0, - ra_corner_11=0, dec_corner_00=0, dec_corner_01=0, dec_corner_10=0, dec_corner_11=0, - provenance=provenance_base, nofile=True, **kwargs ) - imagefar.mjd = np.random.uniform(0, 1) + 60000 - imagefar.end_mjd = imagefar.mjd + 0.007 - clearfar = ImageCleanup.save_image( imagefar ) - - session.add( image1 ) - session.add( image2 ) - session.add( image3 ) - session.add( imagepoint ) - session.add( imagefar ) - - sought = session.query( Image ).filter( Image.containing( 120, 40 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id, image3.id }.issubset( soughtids ) - assert len( { imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.containing( 119.88, 39.95 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id }.issubset( soughtids ) - assert len( { image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.containing( 120, 40.12 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image2.id, image3.id }.issubset( soughtids ) - assert len( { image1.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.containing( 120, 39.88 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image2.id }.issubset( soughtids ) - assert len( { image1.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = Image.find_containing( imagepoint, session=session ) - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id }.issubset( soughtids ) - assert len( { image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.containing( 0, 0 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = Image.find_containing( imagefar, session=session ) - soughtids = set( [ s.id for s in sought ] ) - assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.within( image1 ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert { image1.id, image2.id, image3.id, imagepoint.id }.issubset( soughtids ) - assert len( { imagefar.id } & soughtids ) == 0 - - sought = session.query( Image ).filter( Image.within( imagefar ) ).all() - soughtids = set( [ s.id for s in sought ] ) - assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 - - finally: - session.rollback() + images = session.scalars(sa.select(Image).where( + Image.preproc_bitflag.op('&')(string_to_bitflag('zero', image_preprocessing_inverse)) != 0 + )).all() + assert im.id in [i.id for i in images] + + images = session.scalars(sa.select(Image).where( + Image.preproc_bitflag.op('&')(string_to_bitflag('zero,flat', image_preprocessing_inverse)) !=0 + )).all() + assert im.id in [i.id for i in images] + + images = session.scalars(sa.select(Image).where( + Image.preproc_bitflag.op('&')( + string_to_bitflag('zero, flat', image_preprocessing_inverse) + ) == string_to_bitflag('flat, zero', image_preprocessing_inverse) + )).all() + assert im.id in [i.id for i in images] + + images = session.scalars(sa.select(Image).where( + Image.preproc_bitflag.op('&')(string_to_bitflag('fringe', image_preprocessing_inverse) ) !=0 + )).all() + assert im.id not in [i.id for i in images] + + images = session.scalars(sa.select(Image.filepath).where( + Image.id == im.id, # only find the original image, if any + Image.preproc_bitflag.op('&')( + string_to_bitflag('fringe, overscan', image_preprocessing_inverse) + ) == string_to_bitflag( 'overscan, fringe', image_preprocessing_inverse ) + )).all() + assert len(images) == 0 def test_image_from_exposure(sim_exposure1, provenance_base): @@ -1386,8 +851,7 @@ def test_image_products_are_deleted(ptf_datastore, data_dir, archive): assert not os.path.isfile(file) -# @pytest.mark.flaky(max_runs=3) -@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") +@pytest.mark.skip(reason="This test regularly fails, even when flaky is used. See Issue #263") def test_free( decam_exposure, decam_raw_image, ptf_ref ): proc = psutil.Process() origmem = proc.memory_info() diff --git a/tests/models/test_image_propagation.py b/tests/models/test_image_propagation.py new file mode 100644 index 00000000..ed4a68f5 --- /dev/null +++ b/tests/models/test_image_propagation.py @@ -0,0 +1,323 @@ +import pytest +import uuid +import sqlalchemy as sa +from models.base import SmartSession +from models.image import Image +from tests.fixtures.simulated import ImageCleanup + + +def test_image_upstreams_downstreams(sim_image1, sim_reference, provenance_extra, data_dir): + with SmartSession() as session: + sim_image1 = sim_image1.merge_all(session) + sim_reference = sim_reference.merge_all(session) + + # make sure the new image matches the reference in all these attributes + sim_image1.filter = sim_reference.filter + sim_image1.target = sim_reference.target + sim_image1.section_id = sim_reference.section_id + + new = Image.from_new_and_ref(sim_image1, sim_reference.image) + new.provenance = session.merge(provenance_extra) + + # save and delete at the end + cleanup = ImageCleanup.save_image(new) + + session.add(new) + session.commit() + + # new make sure a new session can find all the upstreams/downstreams + with SmartSession() as session: + sim_image1 = sim_image1.merge_all(session) + sim_reference = sim_reference.merge_all(session) + new = new.merge_all(session) + + # check the upstreams/downstreams for the new image + upstream_ids = [u.id for u in new.get_upstreams(session=session)] + assert sim_image1.id in upstream_ids + assert sim_reference.image_id in upstream_ids + downstream_ids = [d.id for d in new.get_downstreams(session=session)] + assert len(downstream_ids) == 0 + + upstream_ids = [u.id for u in sim_image1.get_upstreams(session=session)] + assert [sim_image1.exposure_id] == upstream_ids + downstream_ids = [d.id for d in sim_image1.get_downstreams(session=session)] + assert [new.id] == downstream_ids # should be the only downstream + + # check the upstreams/downstreams for the reference image + upstreams = sim_reference.image.get_upstreams(session=session) + assert len(upstreams) == 5 # was made of five images + assert all([isinstance(u, Image) for u in upstreams]) + source_images_ids = [im.id for im in sim_reference.image.upstream_images] + upstream_ids = [u.id for u in upstreams] + assert set(upstream_ids) == set(source_images_ids) + downstream_ids = [d.id for d in sim_reference.image.get_downstreams(session=session)] + assert [new.id] == downstream_ids # should be the only downstream + + # test for the Image.downstream relationship + assert len(upstreams[0].downstream_images) == 1 + assert upstreams[0].downstream_images == [sim_reference.image] + assert len(upstreams[1].downstream_images) == 1 + assert upstreams[1].downstream_images == [sim_reference.image] + + assert len(sim_image1.downstream_images) == 1 + assert sim_image1.downstream_images == [new] + + assert len(sim_reference.image.downstream_images) == 1 + assert sim_reference.image.downstream_images == [new] + + assert len(new.downstream_images) == 0 + + # add a second "new" image using one of the reference's upstreams instead of the reference + new2 = Image.from_new_and_ref(sim_image1, upstreams[0]) + new2.provenance = session.merge(provenance_extra) + new2.mjd += 1 # make sure this image has a later MJD, so it comes out later on the downstream list! + + # save and delete at the end + cleanup2 = ImageCleanup.save_image(new2) + + session.add(new2) + session.commit() + + session.refresh(upstreams[0]) + assert len(upstreams[0].downstream_images) == 2 + assert set(upstreams[0].downstream_images) == set([sim_reference.image, new2]) + + session.refresh(upstreams[1]) + assert len(upstreams[1].downstream_images) == 1 + assert upstreams[1].downstream_images == [sim_reference.image] + + session.refresh(sim_image1) + assert len(sim_image1.downstream_images) == 2 + assert set(sim_image1.downstream_images) == set([new, new2]) + + session.refresh(sim_reference.image) + assert len(sim_reference.image.downstream_images) == 1 + assert sim_reference.image.downstream_images == [new] + + assert len(new2.downstream_images) == 0 + + +def test_image_badness(sim_image1): + + with SmartSession() as session: + sim_image1 = session.merge(sim_image1) + + # this is not a legit "badness" keyword... + with pytest.raises(ValueError, match='Keyword "foo" not recognized'): + sim_image1.badness = 'foo' + + # this is a legit keyword, but for cutouts, not for images + with pytest.raises(ValueError, match='Keyword "cosmic ray" not recognized'): + sim_image1.badness = 'cosmic ray' + + # this is a legit keyword, but for images, using no space and no capitalization + sim_image1.badness = 'brightsky' + + # retrieving this keyword, we do get it capitalized and with a space: + assert sim_image1.badness == 'bright sky' + assert sim_image1.bitflag == 2 ** 5 # the bright sky bit is number 5 + + # what happens when we add a second keyword? + sim_image1.badness = 'Bright_sky, Banding' # try this with capitalization and underscores + assert sim_image1.bitflag == 2 ** 5 + 2 ** 1 # the bright sky bit is number 5, banding is number 1 + assert sim_image1.badness == 'banding, bright sky' + + # now add a third keyword, but on the Exposure + sim_image1.exposure.badness = 'saturation' + session.add(sim_image1) + session.commit() + + # a manual way to propagate bitflags downstream + sim_image1.exposure.update_downstream_badness(session=session) # make sure the downstreams get the new badness + session.commit() + assert sim_image1.bitflag == 2 ** 5 + 2 ** 3 + 2 ** 1 # saturation bit is 3 + assert sim_image1.badness == 'banding, saturation, bright sky' + + # adding the same keyword on the exposure and the image makes no difference + sim_image1.exposure.badness = 'Banding' + sim_image1.exposure.update_downstream_badness(session=session) # make sure the downstreams get the new badness + session.commit() + assert sim_image1.bitflag == 2 ** 5 + 2 ** 1 + assert sim_image1.badness == 'banding, bright sky' + + # try appending keywords to the image + sim_image1.append_badness('shaking') + assert sim_image1.bitflag == 2 ** 5 + 2 ** 2 + 2 ** 1 # shaking bit is 2 + assert sim_image1.badness == 'banding, shaking, bright sky' + + +def test_multiple_images_badness( + sim_image1, + sim_image2, + sim_image3, + sim_image5, + sim_image6, + provenance_extra +): + try: + with SmartSession() as session: + sim_image1 = session.merge(sim_image1) + sim_image2 = session.merge(sim_image2) + sim_image3 = session.merge(sim_image3) + sim_image5 = session.merge(sim_image5) + sim_image6 = session.merge(sim_image6) + + images = [sim_image1, sim_image2, sim_image3, sim_image5, sim_image6] + cleanups = [] + filter = 'g' + target = str(uuid.uuid4()) + project = 'test project' + for im in images: + im.filter = filter + im.target = target + im.project = project + session.add(im) + + session.commit() + + # the image itself is marked bad because of bright sky + sim_image2.badness = 'BrightSky' + assert sim_image2.badness == 'bright sky' + assert sim_image2.bitflag == 2 ** 5 + session.commit() + + # note that this image is not directly bad, but the exposure has banding + sim_image3.exposure.badness = 'banding' + sim_image3.exposure.update_downstream_badness(session=session) + session.commit() + + assert sim_image3.badness == 'banding' + assert sim_image1._bitflag == 0 # the exposure is bad! + assert sim_image3.bitflag == 2 ** 1 + session.commit() + + # find the images that are good vs bad + good_images = session.scalars(sa.select(Image).where(Image.bitflag == 0)).all() + assert sim_image1.id in [i.id for i in good_images] + + bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() + assert sim_image2.id in [i.id for i in bad_images] + assert sim_image3.id in [i.id for i in bad_images] + + # make an image from the two bad exposures using subtraction + + sim_image4 = Image.from_new_and_ref(sim_image3, sim_image2) + sim_image4.provenance = provenance_extra + sim_image4.provenance.upstreams = sim_image4.get_upstream_provenances() + cleanups.append(ImageCleanup.save_image(sim_image4)) + sim_image4 = session.merge(sim_image4) + images.append(sim_image4) + session.commit() + + assert sim_image4.id is not None + assert sim_image4.ref_image == sim_image2 + assert sim_image4.new_image == sim_image3 + + # check that badness is loaded correctly from both parents + assert sim_image4.badness == 'banding, bright sky' + assert sim_image4._bitflag == 0 # the image itself is not flagged + assert sim_image4.bitflag == 2 ** 1 + 2 ** 5 + + # check that filtering on this value gives the right bitflag + bad_images = session.scalars(sa.select(Image).where(Image.bitflag == 2 ** 1 + 2 ** 5)).all() + assert sim_image4.id in [i.id for i in bad_images] + assert sim_image3.id not in [i.id for i in bad_images] + assert sim_image2.id not in [i.id for i in bad_images] + + # check that adding a badness on the image itself is added to the total badness + sim_image4.badness = 'saturation' + session.add(sim_image4) + session.commit() + assert sim_image4.badness == 'banding, saturation, bright sky' + assert sim_image4._bitflag == 2 ** 3 # only this bit is from the image itself + + # make a new subtraction: + sim_image7 = Image.from_ref_and_new(sim_image6, sim_image5) + sim_image7.provenance = provenance_extra + cleanups.append(ImageCleanup.save_image(sim_image7)) + sim_image7 = session.merge(sim_image7) + images.append(sim_image7) + session.commit() + + # check that the new subtraction is not flagged + assert sim_image7.badness == '' + assert sim_image7._bitflag == 0 + assert sim_image7.bitflag == 0 + + good_images = session.scalars(sa.select(Image).where(Image.bitflag == 0)).all() + assert sim_image5.id in [i.id for i in good_images] + assert sim_image5.id in [i.id for i in good_images] + assert sim_image7.id in [i.id for i in good_images] + + bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() + assert sim_image5.id not in [i.id for i in bad_images] + assert sim_image6.id not in [i.id for i in bad_images] + assert sim_image7.id not in [i.id for i in bad_images] + + # let's try to coadd an image based on some good and bad images + # as a reminder, sim_image2 has bright sky (5), + # sim_image3's exposure has banding (1), while + # sim_image4 has saturation (3). + + # make a coadded image (without including the subtraction sim_image4): + sim_image8 = Image.from_images([sim_image1, sim_image2, sim_image3, sim_image5, sim_image6]) + sim_image8.provenance = provenance_extra + cleanups.append(ImageCleanup.save_image(sim_image8)) + images.append(sim_image8) + sim_image8 = session.merge(sim_image8) + session.commit() + + assert sim_image8.badness == 'banding, bright sky' + assert sim_image8.bitflag == 2 ** 1 + 2 ** 5 + + # does this work in queries (i.e., using the bitflag hybrid expression)? + bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() + assert sim_image8.id in [i.id for i in bad_images] + bad_coadd = session.scalars(sa.select(Image).where(Image.bitflag == 2 ** 1 + 2 ** 5)).all() + assert sim_image8.id in [i.id for i in bad_coadd] + + # get rid of this coadd to make a new one + sim_image8.delete_from_disk_and_database(session=session) + cleanups.pop() + images.pop() + + # now let's add the subtraction image to the coadd: + # make a coadded image (now including the subtraction sim_image4): + sim_image8 = Image.from_images([sim_image1, sim_image2, sim_image3, sim_image4, sim_image5, sim_image6]) + sim_image8.provenance = provenance_extra + cleanups.append(ImageCleanup.save_image(sim_image8)) + sim_image8 = session.merge(sim_image8) + images.append(sim_image8) + session.commit() + + assert sim_image8.badness == 'banding, saturation, bright sky' + assert sim_image8.bitflag == 2 ** 1 + 2 ** 3 + 2 ** 5 # this should be 42 + + # does this work in queries (i.e., using the bitflag hybrid expression)? + bad_images = session.scalars(sa.select(Image).where(Image.bitflag != 0)).all() + assert sim_image8.id in [i.id for i in bad_images] + bad_coadd = session.scalars(sa.select(Image).where(Image.bitflag == 42)).all() + assert sim_image8.id in [i.id for i in bad_coadd] + + # try to add some badness to one of the underlying exposures + sim_image1.exposure.badness = 'shaking' + session.add(sim_image1) + sim_image1.exposure.update_downstream_badness(session=session) + session.commit() + + assert 'shaking' in sim_image1.badness + assert 'shaking' in sim_image8.badness + + finally: # cleanup + with SmartSession() as session: + session.autoflush = False + for im in images: + im = im.merge_all(session) + exp = im.exposure + im.delete_from_disk_and_database(session=session, commit=False) + + if exp is not None and sa.inspect(exp).persistent: + session.delete(exp) + + session.commit() + diff --git a/tests/models/test_image_querying.py b/tests/models/test_image_querying.py new file mode 100644 index 00000000..6b034ead --- /dev/null +++ b/tests/models/test_image_querying.py @@ -0,0 +1,685 @@ +import pytest + +import numpy as np +import sqlalchemy as sa + +from astropy.time import Time + +from models.base import SmartSession +from models.image import Image, image_upstreams_association_table + +from tests.fixtures.simulated import ImageCleanup + + +def test_image_coordinates(): + image = Image('coordinates.fits', ra=None, dec=None, nofile=True) + assert image.ecllat is None + assert image.ecllon is None + assert image.gallat is None + assert image.gallon is None + + image = Image('coordinates.fits', ra=123.4, dec=None, nofile=True) + assert image.ecllat is None + assert image.ecllon is None + assert image.gallat is None + assert image.gallon is None + + image = Image('coordinates.fits', ra=123.4, dec=56.78, nofile=True) + assert abs(image.ecllat - 35.846) < 0.01 + assert abs(image.ecllon - 111.838) < 0.01 + assert abs(image.gallat - 33.542) < 0.01 + assert abs(image.gallon - 160.922) < 0.01 + + +def test_image_cone_search( provenance_base ): + with SmartSession() as session: + image1 = None + image2 = None + image3 = None + image4 = None + try: + kwargs = { 'format': 'fits', + 'exp_time': 60.48, + 'section_id': 'x', + 'project': 'x', + 'target': 'x', + 'instrument': 'DemoInstrument', + 'telescope': 'x', + 'filter': 'r', + 'ra_corner_00': 0, + 'ra_corner_01': 0, + 'ra_corner_10': 0, + 'ra_corner_11': 0, + 'dec_corner_00': 0, + 'dec_corner_01': 0, + 'dec_corner_10': 0, + 'dec_corner_11': 0, + } + image1 = Image(ra=120., dec=10., provenance=provenance_base, **kwargs ) + image1.mjd = np.random.uniform(0, 1) + 60000 + image1.end_mjd = image1.mjd + 0.007 + clean1 = ImageCleanup.save_image( image1 ) + + image2 = Image(ra=120.0002, dec=9.9998, provenance=provenance_base, **kwargs ) + image2.mjd = np.random.uniform(0, 1) + 60000 + image2.end_mjd = image2.mjd + 0.007 + clean2 = ImageCleanup.save_image( image2 ) + + image3 = Image(ra=120.0005, dec=10., provenance=provenance_base, **kwargs ) + image3.mjd = np.random.uniform(0, 1) + 60000 + image3.end_mjd = image3.mjd + 0.007 + clean3 = ImageCleanup.save_image( image3 ) + + image4 = Image(ra=60., dec=0., provenance=provenance_base, **kwargs ) + image4.mjd = np.random.uniform(0, 1) + 60000 + image4.end_mjd = image4.mjd + 0.007 + clean4 = ImageCleanup.save_image( image4 ) + + session.add( image1 ) + session.add( image2 ) + session.add( image3 ) + session.add( image4 ) + + sought = session.query( Image ).filter( Image.cone_search(120., 10., rad=1.02) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id }.issubset( soughtids ) + assert len( { image3.id, image4.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.cone_search(120., 10., rad=2.) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id, image3.id }.issubset( soughtids ) + assert len( { image4.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.cone_search(120., 10., 0.017, radunit='arcmin') ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id }.issubset( soughtids ) + assert len( { image3.id, image4.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.cone_search(120., 10., 0.0002833, radunit='degrees') ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id }.issubset( soughtids ) + assert len( { image3.id, image4.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.cone_search(120., 10., 4.9451e-6, radunit='radians') ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id }.issubset( soughtids ) + assert len( { image3.id, image4.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.cone_search(60, -10, 1.) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert len( { image1.id, image2.id, image3.id, image4.id } & soughtids ) == 0 + + with pytest.raises( ValueError, match='.*unknown radius unit' ): + sought = Image.cone_search( 0., 0., 1., 'undefined_unit' ) + finally: + for i in [ image1, image2, image3, image4 ]: + if ( i is not None ) and sa.inspect( i ).persistent: + session.delete( i ) + session.commit() + + +# Really, we should also do some speed tests, but that +# is outside the scope of the always-run tests. +def test_four_corners( provenance_base ): + + with SmartSession() as session: + image1 = None + image2 = None + image3 = None + image4 = None + try: + kwargs = { 'format': 'fits', + 'exp_time': 60.48, + 'section_id': 'x', + 'project': 'x', + 'target': 'x', + 'instrument': 'DemoInstrument', + 'telescope': 'x', + 'filter': 'r', + } + # RA numbers are made ugly from cos(dec). + # image1: centered on 120, 40, square to the sky + image1 = Image( ra=120, dec=40., + ra_corner_00=119.86945927, ra_corner_01=119.86945927, + ra_corner_10=120.13054073, ra_corner_11=120.13054073, + dec_corner_00=39.9, dec_corner_01=40.1, dec_corner_10=39.9, dec_corner_11=40.1, + provenance=provenance_base, nofile=True, **kwargs ) + image1.mjd = np.random.uniform(0, 1) + 60000 + image1.end_mjd = image1.mjd + 0.007 + clean1 = ImageCleanup.save_image( image1 ) + + # image2: centered on 120, 40, at a 45° angle + image2 = Image( ra=120, dec=40., + ra_corner_00=119.81538753, ra_corner_01=120, ra_corner_11=120.18461247, ra_corner_10=120, + dec_corner_00=40, dec_corner_01=40.14142136, dec_corner_11=40, dec_corner_10=39.85857864, + provenance=provenance_base, nofile=True, **kwargs ) + image2.mjd = np.random.uniform(0, 1) + 60000 + image2.end_mjd = image2.mjd + 0.007 + clean2 = ImageCleanup.save_image( image2 ) + + # image3: centered offset by (0.025, 0.025) linear arcsec from 120, 40, square on sky + image3 = Image( ra=120.03264714, dec=40.025, + ra_corner_00=119.90210641, ra_corner_01=119.90210641, + ra_corner_10=120.16318787, ra_corner_11=120.16318787, + dec_corner_00=39.975, dec_corner_01=40.125, dec_corner_10=39.975, dec_corner_11=40.125, + provenance=provenance_base, nofile=True, **kwargs ) + image3.mjd = np.random.uniform(0, 1) + 60000 + image3.end_mjd = image3.mjd + 0.007 + clean3 = ImageCleanup.save_image( image3 ) + + # imagepoint and imagefar are used to test Image.containing and Image.find_containing, + # as Image is the only example of a SpatiallyIndexed thing we have so far. + # The corners don't matter for these given how they'll be used. + imagepoint = Image( ra=119.88, dec=39.95, + ra_corner_00=-.001, ra_corner_01=0.001, ra_corner_10=-0.001, + ra_corner_11=0.001, dec_corner_00=0, dec_corner_01=0, dec_corner_10=0, dec_corner_11=0, + provenance=provenance_base, nofile=True, **kwargs ) + imagepoint.mjd = np.random.uniform(0, 1) + 60000 + imagepoint.end_mjd = imagepoint.mjd + 0.007 + clearpoint = ImageCleanup.save_image( imagepoint ) + + imagefar = Image( ra=30, dec=-10, + ra_corner_00=0, ra_corner_01=0, ra_corner_10=0, + ra_corner_11=0, dec_corner_00=0, dec_corner_01=0, dec_corner_10=0, dec_corner_11=0, + provenance=provenance_base, nofile=True, **kwargs ) + imagefar.mjd = np.random.uniform(0, 1) + 60000 + imagefar.end_mjd = imagefar.mjd + 0.007 + clearfar = ImageCleanup.save_image( imagefar ) + + session.add( image1 ) + session.add( image2 ) + session.add( image3 ) + session.add( imagepoint ) + session.add( imagefar ) + + sought = session.query( Image ).filter( Image.containing( 120, 40 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id, image3.id }.issubset( soughtids ) + assert len( { imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.containing( 119.88, 39.95 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id }.issubset( soughtids ) + assert len( { image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.containing( 120, 40.12 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image2.id, image3.id }.issubset( soughtids ) + assert len( { image1.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.containing( 120, 39.88 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image2.id }.issubset( soughtids ) + assert len( { image1.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = Image.find_containing( imagepoint, session=session ) + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id }.issubset( soughtids ) + assert len( { image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.containing( 0, 0 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = Image.find_containing( imagefar, session=session ) + soughtids = set( [ s.id for s in sought ] ) + assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.within( image1 ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert { image1.id, image2.id, image3.id, imagepoint.id }.issubset( soughtids ) + assert len( { imagefar.id } & soughtids ) == 0 + + sought = session.query( Image ).filter( Image.within( imagefar ) ).all() + soughtids = set( [ s.id for s in sought ] ) + assert len( { image1.id, image2.id, image3.id, imagepoint.id, imagefar.id } & soughtids ) == 0 + + finally: + session.rollback() + + +def im_qual(im, factor=3.0): + """Helper function to get the "quality" of an image.""" + return im.lim_mag_estimate - factor * im.fwhm_estimate + + +def test_image_query(ptf_ref, decam_reference, decam_datastore, decam_default_calibrators): + # TODO: need to fix some of these values (of lim_mag and quality) once we get actual limiting magnitude measurements + + with SmartSession() as session: + stmt = Image.query_images() + results = session.scalars(stmt).all() + total = len(results) + + # from pprint import pprint + # pprint(results) + # + # print(f'MJD: {[im.mjd for im in results]}') + # print(f'date: {[im.observation_time for im in results]}') + # print(f'RA: {[im.ra for im in results]}') + # print(f'DEC: {[im.dec for im in results]}') + # print(f'target: {[im.target for im in results]}') + # print(f'section_id: {[im.section_id for im in results]}') + # print(f'project: {[im.project for im in results]}') + # print(f'Instrument: {[im.instrument for im in results]}') + # print(f'Filter: {[im.filter for im in results]}') + # print(f'FWHM: {[im.fwhm_estimate for im in results]}') + # print(f'LIMMAG: {[im.lim_mag_estimate for im in results]}') + # print(f'B/G: {[im.bkg_rms_estimate for im in results]}') + # print(f'ZP: {[im.zero_point_estimate for im in results]}') + # print(f'EXPTIME: {[im.exp_time for im in results]}') + # print(f'AIRMASS: {[im.airmass for im in results]}') + # print(f'QUAL: {[im_qual(im) for im in results]}') + + # get only the science images + stmt = Image.query_images(type=1) + results1 = session.scalars(stmt).all() + assert all(im._type == 1 for im in results1) + assert all(im.type == 'Sci' for im in results1) + assert len(results1) < total + + # get the coadd and subtraction images + stmt = Image.query_images(type=[2, 3, 4]) + results2 = session.scalars(stmt).all() + assert all(im._type in [2, 3, 4] for im in results2) + assert all(im.type in ['ComSci', 'Diff', 'ComDiff'] for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + # use the names of the types instead of integers, or a mixture of ints and strings + stmt = Image.query_images(type=['ComSci', 'Diff', 4]) + results3 = session.scalars(stmt).all() + assert results2 == results3 + + # filter by MJD and observation date + value = 58000.0 + stmt = Image.query_images(min_mjd=value) + results1 = session.scalars(stmt).all() + assert all(im.mjd >= value for im in results1) + assert all(im.instrument == 'DECam' for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_mjd=value) + results2 = session.scalars(stmt).all() + assert all(im.mjd <= value for im in results2) + assert all(im.instrument == 'PTF' for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(min_mjd=value, max_mjd=value) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 + + # filter by observation date + t = Time(58000.0, format='mjd').datetime + stmt = Image.query_images(min_dateobs=t) + results4 = session.scalars(stmt).all() + assert all(im.observation_time >= t for im in results4) + assert all(im.instrument == 'DECam' for im in results4) + assert set(results4) == set(results1) + assert len(results4) < total + + stmt = Image.query_images(max_dateobs=t) + results5 = session.scalars(stmt).all() + assert all(im.observation_time <= t for im in results5) + assert all(im.instrument == 'PTF' for im in results5) + assert set(results5) == set(results2) + assert len(results5) < total + assert len(results4) + len(results5) == total + + # filter by images that contain this point (DECaPS-West) + ra = 115.28 + dec = -26.33 + + stmt = Image.query_images(ra=ra, dec=dec) + results1 = session.scalars(stmt).all() + assert all(im.instrument == 'DECam' for im in results1) + assert all(im.target == 'DECaPS-West' for im in results1) + assert len(results1) < total + + # filter by images that contain this point (PTF field number 100014) + ra = 188.0 + dec = 4.5 + stmt = Image.query_images(ra=ra, dec=dec) + results2 = session.scalars(stmt).all() + assert all(im.instrument == 'PTF' for im in results2) + assert all(im.target == '100014' for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + # filter by section ID + stmt = Image.query_images(section_id='N1') + results1 = session.scalars(stmt).all() + assert all(im.section_id == 'N1' for im in results1) + assert all(im.instrument == 'DECam' for im in results1) + assert len(results1) < total + + stmt = Image.query_images(section_id='11') + results2 = session.scalars(stmt).all() + assert all(im.section_id == '11' for im in results2) + assert all(im.instrument == 'PTF' for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + # filter by the PTF project name + stmt = Image.query_images(project='PTF_DyC_survey') + results1 = session.scalars(stmt).all() + assert all(im.project == 'PTF_DyC_survey' for im in results1) + assert all(im.instrument == 'PTF' for im in results1) + assert len(results1) < total + + # filter by the two different project names for DECam: + stmt = Image.query_images(project=['DECaPS', '2022A-724693']) + results2 = session.scalars(stmt).all() + assert all(im.project in ['DECaPS', '2022A-724693'] for im in results2) + assert all(im.instrument == 'DECam' for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + # filter by instrument + stmt = Image.query_images(instrument='PTF') + results1 = session.scalars(stmt).all() + assert all(im.instrument == 'PTF' for im in results1) + assert len(results1) < total + + stmt = Image.query_images(instrument='DECam') + results2 = session.scalars(stmt).all() + assert all(im.instrument == 'DECam' for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(instrument=['PTF', 'DECam']) + results3 = session.scalars(stmt).all() + assert len(results3) == total + + stmt = Image.query_images(instrument=['foobar']) + results4 = session.scalars(stmt).all() + assert len(results4) == 0 + + # filter by filter + stmt = Image.query_images(filter='R') + results6 = session.scalars(stmt).all() + assert all(im.filter == 'R' for im in results6) + assert all(im.instrument == 'PTF' for im in results6) + assert set(results6) == set(results1) + + stmt = Image.query_images(filter='g DECam SDSS c0001 4720.0 1520.0') + results7 = session.scalars(stmt).all() + assert all(im.filter == 'g DECam SDSS c0001 4720.0 1520.0' for im in results7) + assert all(im.instrument == 'DECam' for im in results7) + assert set(results7) == set(results2) + + # filter by seeing FWHM + value = 3.5 + stmt = Image.query_images(min_seeing=value) + results1 = session.scalars(stmt).all() + assert all(im.fwhm_estimate >= value for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_seeing=value) + results2 = session.scalars(stmt).all() + assert all(im.fwhm_estimate <= value for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(min_seeing=value, max_seeing=value) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 # we will never have exactly that number + + # filter by limiting magnitude + value = 25.0 + stmt = Image.query_images(min_lim_mag=value) + results1 = session.scalars(stmt).all() + assert all(im.lim_mag_estimate >= value for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_lim_mag=value) + results2 = session.scalars(stmt).all() + assert all(im.lim_mag_estimate <= value for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(min_lim_mag=value, max_lim_mag=value) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 + + # filter by background + value = 25.0 + stmt = Image.query_images(min_background=value) + results1 = session.scalars(stmt).all() + assert all(im.bkg_rms_estimate >= value for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_background=value) + results2 = session.scalars(stmt).all() + assert all(im.bkg_rms_estimate <= value for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(min_background=value, max_background=value) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 + + # filter by zero point + value = 27.0 + stmt = Image.query_images(min_zero_point=value) + results1 = session.scalars(stmt).all() + assert all(im.zero_point_estimate >= value for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_zero_point=value) + results2 = session.scalars(stmt).all() + assert all(im.zero_point_estimate <= value for im in results2) + assert len(results2) < total + assert len(results1) + len(results2) == total + + stmt = Image.query_images(min_zero_point=value, max_zero_point=value) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 + + # filter by exposure time + value = 60.0 + 1.0 + stmt = Image.query_images(min_exp_time=value) + results1 = session.scalars(stmt).all() + assert all(im.exp_time >= value for im in results1) + assert len(results1) < total + + stmt = Image.query_images(max_exp_time=value) + results2 = session.scalars(stmt).all() + assert all(im.exp_time <= value for im in results2) + assert len(results2) < total + + stmt = Image.query_images(min_exp_time=60.0, max_exp_time=60.0) + results3 = session.scalars(stmt).all() + assert len(results3) == len(results2) # all those under 31s are those with exactly 30s + + # query based on airmass + value = 1.15 + total_with_airmass = len([im for im in results if im.airmass is not None]) + stmt = Image.query_images(max_airmass=value) + results1 = session.scalars(stmt).all() + assert all(im.airmass <= value for im in results1) + assert len(results1) < total_with_airmass + + stmt = Image.query_images(min_airmass=value) + results2 = session.scalars(stmt).all() + assert all(im.airmass >= value for im in results2) + assert len(results2) < total_with_airmass + assert len(results1) + len(results2) == total_with_airmass + + # order the results by quality (lim_mag - 3 * fwhm) + # note that we cannot filter by quality, it is not a meaningful number + # on its own, only as a way to compare images and find which is better. + # sort all the images by quality and get the best one + stmt = Image.query_images(order_by='quality') + best = session.scalars(stmt).first() + + # the best overall quality from all images + assert im_qual(best) == max([im_qual(im) for im in results]) + + # get the two best images from the PTF instrument (exp_time chooses the single images only) + stmt = Image.query_images(max_exp_time=60, order_by='quality') + results1 = session.scalars(stmt.limit(2)).all() + assert len(results1) == 2 + assert all(im_qual(im) > 10.0 for im in results1) + + # change the seeing factor a little: + factor = 2.8 + stmt = Image.query_images(max_exp_time=60, order_by='quality', seeing_quality_factor=factor) + results2 = session.scalars(stmt.limit(2)).all() + + # quality will be a little bit higher, but the images are the same + assert results2 == results1 + assert im_qual(results2[0], factor=factor) > im_qual(results1[0]) + assert im_qual(results2[1], factor=factor) > im_qual(results1[1]) + + # change the seeing factor dramatically: + factor = 0.2 + stmt = Image.query_images(max_exp_time=60, order_by='quality', seeing_quality_factor=factor) + results3 = session.scalars(stmt.limit(2)).all() + + # quality will be a higher, but also a different image will now have the second-best quality + assert results3 != results1 + assert im_qual(results3[0], factor=factor) > im_qual(results1[0]) + + # do a cross filtering of coordinates and background (should only find the PTF coadd) + ra = 188.0 + dec = 4.5 + background = 5 + + stmt = Image.query_images(ra=ra, dec=dec, max_background=background) + results1 = session.scalars(stmt).all() + assert len(results1) == 1 + assert results1[0].instrument == 'PTF' + assert results1[0].type == 'ComSci' + + # cross the DECam target and section ID with long exposure time + target = 'DECaPS-West' + section_id = 'N1' + exp_time = 400.0 + + stmt = Image.query_images(target=target, section_id=section_id, min_exp_time=exp_time) + results2 = session.scalars(stmt).all() + assert len(results2) == 1 + assert results2[0].instrument == 'DECam' + assert results2[0].type == 'Sci' + assert results2[0].exp_time == 576.0 + + # cross filter on MJD and instrument in a way that has no results + mjd = 55000.0 + instrument = 'PTF' + + stmt = Image.query_images(min_mjd=mjd, instrument=instrument) + results3 = session.scalars(stmt).all() + assert len(results3) == 0 + + # cross filter MJD and sort by quality to get the coadd PTF image + mjd = 54926.31913 + + stmt = Image.query_images(max_mjd=mjd, order_by='quality') + results4 = session.scalars(stmt).all() + assert len(results4) == 2 + assert results4[0].mjd == results4[1].mjd # same time, as one is a coadd of the other images + assert results4[0].instrument == 'PTF' + assert results4[0].type == 'ComSci' # the first one out is the high quality coadd + assert results4[1].type == 'Sci' # the second one is the regular image + + # check that the DECam difference and new image it is based on have the same limiting magnitude and quality + stmt = Image.query_images(instrument='DECam', type=3) + diff = session.scalars(stmt).first() + stmt = Image.query_images(instrument='DECam', type=1, min_mjd=diff.mjd, max_mjd=diff.mjd) + new = session.scalars(stmt).first() + assert diff.lim_mag_estimate == new.lim_mag_estimate + assert diff.fwhm_estimate == new.fwhm_estimate + assert im_qual(diff) == im_qual(new) + + +def test_image_get_downstream(ptf_ref, ptf_supernova_images, ptf_subtraction1): + with SmartSession() as session: + # how many image to image associations are on the DB right now? + num_associations = session.execute( + sa.select(sa.func.count()).select_from(image_upstreams_association_table) + ).scalar() + + assert num_associations > len(ptf_ref.image.upstream_images) + + prov = ptf_ref.image.provenance + assert prov.process == 'coaddition' + images = ptf_ref.image.upstream_images + assert len(images) > 1 + + loaded_image = Image.get_image_from_upstreams(images, prov.id) + + assert loaded_image.id == ptf_ref.image.id + assert loaded_image.id != ptf_subtraction1.id + assert loaded_image.id != ptf_subtraction1.new_image.id + + new_image = None + new_image2 = None + new_image3 = None + try: + # make a new image with a new provenance + new_image = Image.copy_image(ptf_ref.image) + prov = ptf_ref.provenance + prov.process = 'copy' + new_image.provenance = prov + new_image.upstream_images = ptf_ref.image.upstream_images + new_image.save() + + with SmartSession() as session: + new_image = session.merge(new_image) + session.commit() + assert new_image.id is not None + assert new_image.id != ptf_ref.image.id + + loaded_image = Image.get_image_from_upstreams(images, prov.id) + assert loaded_image.id == new_image.id + + # use the original provenance but take down an image from the upstreams + prov = ptf_ref.image.provenance + images = ptf_ref.image.upstream_images[1:] + + new_image2 = Image.copy_image(ptf_ref.image) + new_image2.provenance = prov + new_image2.upstream_images = images + new_image2.save() + + with SmartSession() as session: + new_image2 = session.merge(new_image2) + session.commit() + assert new_image2.id is not None + assert new_image2.id != ptf_ref.image.id + assert new_image2.id != new_image.id + + loaded_image = Image.get_image_from_upstreams(images, prov.id) + assert loaded_image.id != ptf_ref.image.id + assert loaded_image.id != new_image.id + + # use the original provenance but add images to the upstreams + prov = ptf_ref.image.provenance + images = ptf_ref.image.upstream_images + ptf_supernova_images + + new_image3 = Image.copy_image(ptf_ref.image) + new_image3.provenance = prov + new_image3.upstream_images = images + new_image3.save() + + with SmartSession() as session: + new_image3 = session.merge(new_image3) + session.commit() + assert new_image3.id is not None + assert new_image3.id != ptf_ref.image.id + assert new_image3.id != new_image.id + assert new_image3.id != new_image2.id + + loaded_image = Image.get_image_from_upstreams(images, prov.id) + assert loaded_image.id == new_image3.id + + finally: + if new_image is not None: + new_image.delete_from_disk_and_database() + if new_image2 is not None: + new_image2.delete_from_disk_and_database() + if new_image3 is not None: + new_image3.delete_from_disk_and_database() + diff --git a/tests/models/test_measurements.py b/tests/models/test_measurements.py index f02f4e4d..6d52b2aa 100644 --- a/tests/models/test_measurements.py +++ b/tests/models/test_measurements.py @@ -52,7 +52,7 @@ def test_measurements_attributes(measurer, ptf_datastore, test_config): # check that background is subtracted from the "flux" and "magnitude" properties if m.best_aperture == -1: assert m.flux == m.flux_psf - m.bkg_mean * m.area_psf - assert m.magnitude > m.mag_psf # the magnitude has background subtracted from it + assert m.magnitude != m.mag_psf # the magnitude has background subtracted from it assert m.magnitude_err > m.mag_psf_err # the magnitude error is larger because of the error in background else: assert m.flux == m.flux_apertures[m.best_aperture] - m.bkg_mean * m.area_apertures[m.best_aperture] @@ -86,39 +86,8 @@ def test_measurements_attributes(measurer, ptf_datastore, test_config): # TODO: add test for limiting magnitude (issue #143) -@pytest.mark.skip(reason="This test fails on GA but not locally, see issue #306") -# @pytest.mark.flaky(max_runs=3) def test_filtering_measurements(ptf_datastore): - # printout the list of relevant environmental variables: - import os - print("SeeChange environment variables:") - for key in [ - 'INTERACTIVE', - 'LIMIT_CACHE_USAGE', - 'SKIP_NOIRLAB_DOWNLOADS', - 'RUN_SLOW_TESTS', - 'SEECHANGE_TRACEMALLOC', - ]: - print(f'{key}: {os.getenv(key)}') - measurements = ptf_datastore.measurements - from pprint import pprint - print('measurements: ') - pprint(measurements) - - if hasattr(ptf_datastore, 'all_measurements'): - idx = [m.cutouts.index_in_sources for m in measurements] - chosen = np.array(ptf_datastore.all_measurements)[idx] - pprint([(m, m.is_bad, m.cutouts.sub_nandata[12, 12]) for m in chosen]) - - print(f'new image values: {ptf_datastore.image.data[250, 240:250]}') - print(f'ref_image values: {ptf_datastore.ref_image.data[250, 240:250]}') - print(f'sub_image values: {ptf_datastore.sub_image.data[250, 240:250]}') - - print(f'number of images in ref image: {len(ptf_datastore.ref_image.upstream_images)}') - for i, im in enumerate(ptf_datastore.ref_image.upstream_images): - print(f'upstream image {i}: {im.data[250, 240:250]}') - m = measurements[0] # grab the first one as an example # test that we can filter on some measurements properties diff --git a/tests/models/test_psf.py b/tests/models/test_psf.py index 7547e264..013d5367 100644 --- a/tests/models/test_psf.py +++ b/tests/models/test_psf.py @@ -1,8 +1,6 @@ import pytest import io -import os import psutil -import gc import time import uuid import random @@ -23,6 +21,8 @@ from models.base import SmartSession, FileOnDiskMixin, CODE_ROOT, get_archive_object from models.psf import PSF +from util.util import env_as_bool + class PSFPaletteMaker: def __init__( self, round=False ): @@ -344,8 +344,7 @@ def test_save_psf( ztf_datastore_uncommitted, provenance_base, provenance_extra im.delete_from_disk_and_database(session=session) -# @pytest.mark.flaky(max_runs=3) -@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") +@pytest.mark.skip(reason="This test regularly fails, even when flaky is used. See Issue #263") def test_free( decam_datastore ): ds = decam_datastore ds.get_psf() @@ -392,7 +391,7 @@ def test_free( decam_datastore ): assert origmem.rss - freemem.rss > 60 * 1024 * 1024 -@pytest.mark.skipif( os.getenv('RUN_SLOW_TESTS') is None, reason="Set RUN_SLOW_TESTS to run this test" ) +@pytest.mark.skipif( env_as_bool('RUN_SLOW_TESTS'), reason="Set RUN_SLOW_TESTS to run this test" ) def test_psfex_rendering( psf_palette ): # round_psf_palette ): # psf_palette = round_psf_palette psf = psf_palette.psf diff --git a/tests/models/test_reports.py b/tests/models/test_reports.py index e1e33ead..4303e971 100644 --- a/tests/models/test_reports.py +++ b/tests/models/test_reports.py @@ -11,7 +11,7 @@ from models.base import SmartSession from models.report import Report -from util.util import parse_bool +from util.util import env_as_bool def test_report_bitflags(decam_exposure, decam_reference, decam_default_calibrators): @@ -89,13 +89,14 @@ def test_report_bitflags(decam_exposure, decam_reference, decam_default_calibrat def test_measure_runtime_memory(decam_exposure, decam_reference, pipeline_for_tests, decam_default_calibrators): # make sure we get a random new provenance, not reuse any of the existing data p = pipeline_for_tests + p.subtractor.pars.refset = 'test_refset_decam' + p.pars.save_before_subtraction = True + p.pars.save_at_finish = False p.preprocessor.pars.test_parameter = uuid.uuid4().hex - t0 = time.perf_counter() - try: + t0 = time.perf_counter() ds = p.run(decam_exposure, 'N1') - total_time = time.perf_counter() - t0 assert p.preprocessor.has_recalculated @@ -112,15 +113,15 @@ def test_measure_runtime_memory(decam_exposure, decam_reference, pipeline_for_te peak_memory = 0 for step in ds.runtimes.keys(): # also make sure all the keys are present in both dictionaries measured_time += ds.runtimes[step] - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): peak_memory = max(peak_memory, ds.memory_usages[step]) print(f'total_time: {total_time:.1f}s') print(f'measured_time: {measured_time:.1f}s') - pprint(ds.runtimes, sort_dicts=False) - assert measured_time > 0.98 * total_time # at least 99% of the time is accounted for + pprint(ds.report.process_runtime, sort_dicts=False) + assert measured_time > 0.99 * total_time # at least 99% of the time is accounted for - if parse_bool(os.getenv('SEECHANGE_TRACEMALLOC')): + if env_as_bool('SEECHANGE_TRACEMALLOC'): print(f'peak_memory: {peak_memory:.1f}MB') pprint(ds.memory_usages, sort_dicts=False) assert 1000.0 < peak_memory < 10000.0 # memory usage is in MB, takes between 1 and 10 GB @@ -129,7 +130,9 @@ def test_measure_runtime_memory(decam_exposure, decam_reference, pipeline_for_te rep = session.scalars(sa.select(Report).where(Report.exposure_id == decam_exposure.id)).one() assert rep is not None assert rep.success - assert rep.process_runtime == ds.runtimes + runtimes = rep.process_runtime.copy() + runtimes.pop('reporting') + assert runtimes == ds.runtimes assert rep.process_memory == ds.memory_usages # should contain: 'preprocessing, extraction, subtraction, detection, cutting, measuring' assert rep.progress_steps == ', '.join(PROCESS_OBJECTS.keys()) diff --git a/tests/models/test_source_list.py b/tests/models/test_source_list.py index 36eef87e..36b7df51 100644 --- a/tests/models/test_source_list.py +++ b/tests/models/test_source_list.py @@ -269,8 +269,7 @@ def test_calc_apercor( decam_datastore ): # assert sources.calc_aper_cor( aper_num=2, inf_aper_num=7 ) == pytest.approx( -0.024, abs=0.001 ) -# @pytest.mark.flaky(max_runs=3) -@pytest.mark.skip(reason="We aren't succeeding at controlling garbage collection") +@pytest.mark.skip(reason="This test regularly fails, even when flaky is used. See Issue #263") def test_free( decam_datastore ): ds = decam_datastore ds.get_sources() diff --git a/tests/models/test_world_coordinates.py b/tests/models/test_world_coordinates.py index 7d6821d8..b5834501 100644 --- a/tests/models/test_world_coordinates.py +++ b/tests/models/test_world_coordinates.py @@ -102,17 +102,11 @@ def test_world_coordinates( ztf_datastore_uncommitted, provenance_base, provenan finally: if 'wcobj' in locals(): - # wcobj.delete_from_disk_and_database(session=session) - if sa.inspect(wcobj).persistent: - session.delete(wcobj) - image.wcs = None - image.sources.wcs = None + wcobj.delete_from_disk_and_database(session=session) + if 'wcobj2' in locals(): - # wcobj2.delete_from_disk_and_database(session=session) - if sa.inspect(wcobj2).persistent: - session.delete(wcobj2) - image.wcs = None - image.sources.wcs = None + wcobj2.delete_from_disk_and_database(session=session) + session.commit() if 'image' in locals(): diff --git a/tests/pipeline/test_astro_cal.py b/tests/pipeline/test_astro_cal.py index 7d527791..81b7e683 100644 --- a/tests/pipeline/test_astro_cal.py +++ b/tests/pipeline/test_astro_cal.py @@ -13,6 +13,10 @@ from models.image import Image from models.world_coordinates import WorldCoordinates +from util.util import env_as_bool + +from tests.conftest import SKIP_WARNING_TESTS + def test_solve_wcs_scamp_failures( ztf_gaia_dr3_excerpt, ztf_datastore_uncommitted, astrometor ): catexp = ztf_gaia_dr3_excerpt @@ -48,7 +52,7 @@ def test_solve_wcs_scamp( ztf_gaia_dr3_excerpt, ztf_datastore_uncommitted, astro ds = ztf_datastore_uncommitted # Make True for visual testing purposes - if os.getenv('INTERACTIVE', False): + if env_as_bool('INTERACTIVE'): basename = os.path.join(CODE_ROOT, 'tests/plots') catexp.ds9_regfile( os.path.join(basename, 'catexp.reg'), radius=4 ) ds.sources.ds9_regfile( os.path.join(basename, 'sources.reg'), radius=3 ) @@ -184,12 +188,13 @@ def test_run_scamp( decam_datastore, astrometor ): def test_warnings_and_exceptions(decam_datastore, astrometor): - astrometor.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - astrometor.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'astro_cal'." in str(w.message) for w in record) + if not SKIP_WARNING_TESTS: + astrometor.pars.inject_warnings = 1 + with pytest.warns(UserWarning) as record: + astrometor.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'astro_cal'." in str(w.message) for w in record) astrometor.pars.inject_warnings = 0 astrometor.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_backgrounding.py b/tests/pipeline/test_backgrounding.py index 0d995ad5..c8a3d1f9 100644 --- a/tests/pipeline/test_backgrounding.py +++ b/tests/pipeline/test_backgrounding.py @@ -5,6 +5,8 @@ from improc.tools import sigma_clipping +from tests.conftest import SKIP_WARNING_TESTS + def test_measuring_background(decam_processed_image, backgrounder): backgrounder.pars.test_parameter = uuid.uuid4().hex # make sure there is no hashed value @@ -36,12 +38,13 @@ def test_measuring_background(decam_processed_image, backgrounder): def test_warnings_and_exceptions(decam_datastore, backgrounder): - backgrounder.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + backgrounder.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - backgrounder.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'backgrounding'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + backgrounder.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'backgrounding'." in str(w.message) for w in record) backgrounder.pars.inject_warnings = 0 backgrounder.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_coaddition.py b/tests/pipeline/test_coaddition.py index 700fddeb..24419b8d 100644 --- a/tests/pipeline/test_coaddition.py +++ b/tests/pipeline/test_coaddition.py @@ -9,6 +9,7 @@ from models.image import Image from models.source_list import SourceList from models.psf import PSF +from models.background import Background from models.world_coordinates import WorldCoordinates from models.zero_point import ZeroPoint @@ -408,7 +409,7 @@ def test_coaddition_pipeline_inputs(ptf_reference_images): ) im_ids = set([im.id for im in pipe.images]) ptf_im_ids = set([im.id for im in ptf_reference_images]) - assert ptf_im_ids == im_ids + assert ptf_im_ids.issubset(im_ids) ptf_ras = [im.ra for im in ptf_reference_images] ptf_decs = [im.dec for im in ptf_reference_images] @@ -429,7 +430,7 @@ def test_coaddition_pipeline_inputs(ptf_reference_images): im_ids = set([im.id for im in pipe.images]) ptf_im_ids = set([im.id for im in ptf_reference_images]) - assert ptf_im_ids == im_ids + assert ptf_im_ids.issubset(im_ids) def test_coaddition_pipeline_outputs(ptf_reference_images, ptf_aligned_images): @@ -493,6 +494,7 @@ def test_coadded_reference(ptf_ref): assert ref_image.type == 'ComSci' assert isinstance(ref_image.sources, SourceList) assert isinstance(ref_image.psf, PSF) + assert isinstance(ref_image.bg, Background) assert isinstance(ref_image.wcs, WorldCoordinates) assert isinstance(ref_image.zp, ZeroPoint) @@ -500,11 +502,8 @@ def test_coadded_reference(ptf_ref): assert ptf_ref.filter == ref_image.filter assert ptf_ref.section_id == ref_image.section_id - assert ptf_ref.validity_start is None - assert ptf_ref.validity_end is None - assert ptf_ref.provenance.upstreams[0].id == ref_image.provenance_id - assert ptf_ref.provenance.process == 'reference' + assert ptf_ref.provenance.process == 'referencing' assert ptf_ref.provenance.parameters['test_parameter'] == 'test_value' diff --git a/tests/pipeline/test_compare_sextractor_to_photutils.py b/tests/pipeline/test_compare_sextractor_to_photutils.py index be4df175..a1720ec2 100644 --- a/tests/pipeline/test_compare_sextractor_to_photutils.py +++ b/tests/pipeline/test_compare_sextractor_to_photutils.py @@ -10,9 +10,10 @@ from improc.sextrsky import sextrsky from util.logger import SCLogger +from util.util import env_as_bool -@pytest.mark.skipif( os.getenv('INTERACTIVE') is None, reason='Set INTERACTIVE to run this test' ) +@pytest.mark.skipif( not env_as_bool('INTERACTIVE'), reason='Set INTERACTIVE to run this test' ) def test_compare_sextr_photutils( decam_datastore ): plot_dir = os.path.join(CODE_ROOT, 'tests/plots/sextractor_comparison') os.makedirs( plot_dir, exist_ok=True) diff --git a/tests/pipeline/test_cutting.py b/tests/pipeline/test_cutting.py index e79b5334..78abd3a7 100644 --- a/tests/pipeline/test_cutting.py +++ b/tests/pipeline/test_cutting.py @@ -1,13 +1,16 @@ import pytest +from tests.conftest import SKIP_WARNING_TESTS + def test_warnings_and_exceptions(decam_datastore, cutter): - cutter.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + cutter.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - cutter.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'cutting'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + cutter.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'cutting'." in str(w.message) for w in record) cutter.pars.inject_warnings = 0 cutter.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_detection.py b/tests/pipeline/test_detection.py index 2965c86f..079a0c17 100644 --- a/tests/pipeline/test_detection.py +++ b/tests/pipeline/test_detection.py @@ -3,12 +3,13 @@ import matplotlib.pyplot as plt import scipy.signal -from astropy.io import fits from astropy.coordinates import SkyCoord import astropy.units as u from improc.tools import sigma_clipping, make_gaussian, make_cutouts +from tests.conftest import SKIP_WARNING_TESTS + # os.environ['INTERACTIVE'] = '1' # for diagnostics only CUTOUT_SIZE = 15 @@ -153,12 +154,13 @@ def test_detection_ptf_supernova(detector, ptf_subtraction1, blocking_plots, cac def test_warnings_and_exceptions(decam_datastore, detector): - detector.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + detector.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - detector.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'detection'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + detector.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'detection'." in str(w.message) for w in record) detector.pars.inject_warnings = 0 detector.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_extraction.py b/tests/pipeline/test_extraction.py index 9bd472c0..f9bbe94a 100644 --- a/tests/pipeline/test_extraction.py +++ b/tests/pipeline/test_extraction.py @@ -8,14 +8,13 @@ import random import numpy as np -import sqlalchemy as sa from astropy.io import votable from models.base import SmartSession, FileOnDiskMixin, get_archive_object, CODE_ROOT from models.provenance import Provenance -from models.image import Image -from models.source_list import SourceList + +from tests.conftest import SKIP_WARNING_TESTS def test_sep_find_sources_in_small_image(decam_small_image, extractor, blocking_plots): @@ -318,12 +317,13 @@ def test_extract_sources_sextractor( decam_datastore, extractor, provenance_base def test_warnings_and_exceptions(decam_datastore, extractor): - extractor.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + extractor.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - extractor.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'detection'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + extractor.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'detection'." in str(w.message) for w in record) extractor.pars.inject_warnings = 0 extractor.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_making_references.py b/tests/pipeline/test_making_references.py new file mode 100644 index 00000000..ea9e0e55 --- /dev/null +++ b/tests/pipeline/test_making_references.py @@ -0,0 +1,258 @@ +import time + +import pytest +import uuid + +import numpy as np + +import sqlalchemy as sa + +from pipeline.ref_maker import RefMaker + +from models.base import SmartSession +from models.provenance import Provenance +from models.reference import Reference +from models.refset import RefSet + + +def add_test_parameters(maker): + """Utility function to add "test_parameter" to all the underlying objects. """ + for name in ['preprocessor', 'extractor', 'backgrounder', 'astrometor', 'photometor', 'coadder']: + for pipe in ['pipeline', 'coadd_pipeline']: + obj = getattr(getattr(maker, pipe), name, None) + if obj is not None: + obj.pars._enforce_no_new_attrs = False + obj.pars.test_parameter = obj.pars.add_par( + 'test_parameter', 'test_value', str, 'A parameter showing this is part of a test', critical=True, + ) + obj.pars._enforce_no_new_attrs = True + + +def test_finding_references(ptf_ref): + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(ra=188) + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(dec=4.5) + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(target='foo') + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(section_id='bar') + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(ra=188, section_id='bar') + with pytest.raises(ValueError, match='Must provide both'): + ref = Reference.get_references(dec=4.5, target='foo') + with pytest.raises(ValueError, match='Must provide either ra and dec, or target and section_id'): + ref = Reference.get_references() + with pytest.raises(ValueError, match='Cannot provide target/section_id and also ra/dec! '): + ref = Reference.get_references(ra=188, dec=4.5, target='foo', section_id='bar') + + ref = Reference.get_references(ra=188, dec=4.5) + assert len(ref) == 1 + assert ref[0].id == ptf_ref.id + + ref = Reference.get_references(ra=188, dec=4.5, provenance_ids=ptf_ref.provenance_id) + assert len(ref) == 1 + assert ref[0].id == ptf_ref.id + + ref = Reference.get_references(ra=0, dec=0) + assert len(ref) == 0 + + ref = Reference.get_references(target='foo', section_id='bar') + assert len(ref) == 0 + + ref = Reference.get_references(ra=180, dec=4.5, provenance_ids=['foo', 'bar']) + assert len(ref) == 0 + + +def test_making_refsets(): + # make a new refset with a new name + name = uuid.uuid4().hex + maker = RefMaker(maker={'name': name, 'instruments': ['PTF']}) + min_number = maker.pars.min_number + max_number = maker.pars.max_number + + # we still haven't run the maker, so everything is empty + assert maker.im_provs is None + assert maker.ex_provs is None + assert maker.coadd_im_prov is None + assert maker.coadd_ex_prov is None + assert maker.ref_upstream_hash is None + + new_ref = maker.run(ra=0, dec=0, filter='R') + assert new_ref is None # cannot find a specific reference here + refset = maker.refset + + assert refset is not None # can produce a reference set without finding a reference + assert all(isinstance(p, Provenance) for p in maker.im_provs) + assert all(isinstance(p, Provenance) for p in maker.ex_provs) + assert isinstance(maker.coadd_im_prov, Provenance) + assert isinstance(maker.coadd_ex_prov, Provenance) + + up_hash1 = refset.upstream_hash + assert maker.ref_upstream_hash == up_hash1 + assert isinstance(up_hash1, str) + assert len(up_hash1) == 20 + assert len(refset.provenances) == 1 + assert refset.provenances[0].parameters['min_number'] == min_number + assert refset.provenances[0].parameters['max_number'] == max_number + assert 'name' not in refset.provenances[0].parameters # not a critical parameter! + assert 'description' not in refset.provenances[0].parameters # not a critical parameter! + + # now make a change to the maker's parameters (not the data production parameters) + maker.pars.min_number = min_number + 5 + maker.pars.allow_append = False # this should prevent us from appending to the existing ref-set + + with pytest.raises( + RuntimeError, match='Found a RefSet with the name .*, but it has a different provenance!' + ): + new_ref = maker.run(ra=0, dec=0, filter='R') + + maker.pars.allow_append = True # now it should be ok + new_ref = maker.run(ra=0, dec=0, filter='R') + assert new_ref is None # still can't find images there + + refset = maker.refset + up_hash2 = refset.upstream_hash + assert up_hash1 == up_hash2 # the underlying data MUST be the same + assert len(refset.provenances) == 2 + assert refset.provenances[0].parameters['min_number'] == min_number + assert refset.provenances[1].parameters['min_number'] == min_number + 5 + assert refset.provenances[0].parameters['max_number'] == max_number + assert refset.provenances[1].parameters['max_number'] == max_number + + # now try to make a new ref-set with a different name + name2 = uuid.uuid4().hex + maker.pars.name = name2 + new_ref = maker.run(ra=0, dec=0, filter='R') + assert new_ref is None # still can't find images there + + refset2 = maker.refset + assert len(refset2.provenances) == 1 + assert refset2.provenances[0].id == refset.provenances[1].id # these ref-sets share the same provenance! + + # now try to append with different data parameters: + maker.pipeline.extractor.pars['threshold'] = 3.14 + + with pytest.raises( + RuntimeError, match='Found a RefSet with the name .*, but it has a different upstream_hash!' + ): + new_ref = maker.run(ra=0, dec=0, filter='R') + + +def test_making_references(ptf_reference_images): + name = uuid.uuid4().hex + ref = None + ref5 = None + + try: + maker = RefMaker( + maker={ + 'name': name, + 'instruments': ['PTF'], + 'min_number': 4, + 'max_number': 10, + 'end_time': '2010-01-01', + } + ) + add_test_parameters(maker) # make sure we have a test parameter on everything + maker.coadd_pipeline.coadder.pars.test_parameter = uuid.uuid4().hex # do not load an existing image + + t0 = time.perf_counter() + ref = maker.run(ra=188, dec=4.5, filter='R') + first_time = time.perf_counter() - t0 + first_refset = maker.refset + first_image = ref.image + assert ref is not None + + # check that this ref is saved to the DB + with SmartSession() as session: + loaded_ref = session.scalars(sa.select(Reference).where(Reference.id == ref.id)).first() + assert loaded_ref is not None + + # now try to make a new ref with the same parameters + t0 = time.perf_counter() + ref2 = maker.run(ra=188, dec=4.5, filter='R') + second_time = time.perf_counter() - t0 + second_refset = maker.refset + second_image = ref2.image + assert second_time < first_time * 0.1 # should be much faster, we are reloading the reference set + assert ref2.id == ref.id + assert second_refset.id == first_refset.id + assert second_image.id == first_image.id + + # now try to make a new ref set with a new name + maker.pars.name = uuid.uuid4().hex + t0 = time.perf_counter() + ref3 = maker.run(ra=188, dec=4.5, filter='R') + third_time = time.perf_counter() - t0 + third_refset = maker.refset + third_image = ref3.image + assert third_time < first_time * 0.1 # should be faster, we are loading the same reference + assert third_refset.id != first_refset.id + assert ref3.id == ref.id + assert third_image.id == first_image.id + + # append to the same refset but with different reference parameters (image loading parameters) + maker.pars.max_number += 1 + t0 = time.perf_counter() + ref4 = maker.run(ra=188, dec=4.5, filter='R') + fourth_time = time.perf_counter() - t0 + fourth_refset = maker.refset + fourth_image = ref4.image + assert fourth_time < first_time * 0.1 # should be faster, we can still re-use the underlying coadd image + assert fourth_refset.id != first_refset.id + assert ref4.id != ref.id + assert fourth_image.id == first_image.id + + # now make the coadd image again with a different parameter for the data production + maker.coadd_pipeline.coadder.pars.flag_fwhm_factor *= 1.2 + maker.pars.name = uuid.uuid4().hex # MUST give a new name, otherwise it will not allow the new data parameters + t0 = time.perf_counter() + ref5 = maker.run(ra=188, dec=4.5, filter='R') + fifth_time = time.perf_counter() - t0 + fifth_refset = maker.refset + fifth_image = ref5.image + assert np.log10(fifth_time) == pytest.approx(np.log10(first_time), rel=0.2) # should take about the same time + assert ref5.id != ref.id + assert fifth_refset.id != first_refset.id + assert fifth_image.id != first_image.id + + finally: # cleanup + if ref is not None and ref.image is not None: + ref.image.delete_from_disk_and_database(remove_downstreams=True) + + # we don't have to delete ref2, ref3, ref4, because they depend on the same coadd image, and cascade should + # destroy them as soon as the coadd is removed + + if ref5 is not None and ref5.image is not None: + ref5.image.delete_from_disk_and_database(remove_downstreams=True) + + +def test_datastore_get_reference(ptf_datastore, ptf_ref, ptf_ref_offset): + with SmartSession() as session: + refset = session.scalars(sa.select(RefSet).where(RefSet.name == 'test_refset_ptf')).first() + assert refset is not None + assert len(refset.provenances) == 1 + assert refset.provenances[0].id == ptf_ref.provenance_id + + # append the newer reference to the refset + ptf_ref_offset = session.merge(ptf_ref_offset) + refset.provenances.append(ptf_ref_offset.provenance) + session.commit() + + ref = ptf_datastore.get_reference(provenances=refset.provenances, session=session) + + assert ref is not None + assert ref.id == ptf_ref.id + + # now offset the image that needs matching + ptf_datastore.image.ra_corner_00 -= 0.5 + ptf_datastore.image.ra_corner_01 -= 0.5 + ptf_datastore.image.ra_corner_10 -= 0.5 + ptf_datastore.image.ra_corner_11 -= 0.5 + + ref = ptf_datastore.get_reference(provenances=refset.provenances, session=session) + + assert ref is not None + assert ref.id == ptf_ref_offset.id + diff --git a/tests/pipeline/test_measuring.py b/tests/pipeline/test_measuring.py index 0b1a6487..d5c7b0a9 100644 --- a/tests/pipeline/test_measuring.py +++ b/tests/pipeline/test_measuring.py @@ -9,6 +9,8 @@ from improc.tools import make_gaussian +from tests.conftest import SKIP_WARNING_TESTS + @pytest.mark.flaky(max_runs=3) def test_measuring(measurer, decam_cutouts, decam_default_calibrators): @@ -219,12 +221,13 @@ def test_measuring(measurer, decam_cutouts, decam_default_calibrators): def test_warnings_and_exceptions(decam_datastore, measurer): - measurer.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + measurer.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - measurer.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'measuring'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + measurer.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'measuring'." in str(w.message) for w in record) measurer.pars.inject_exceptions = 1 measurer.pars.inject_warnings = 0 diff --git a/tests/pipeline/test_photo_cal.py b/tests/pipeline/test_photo_cal.py index ee274a61..cae423ea 100644 --- a/tests/pipeline/test_photo_cal.py +++ b/tests/pipeline/test_photo_cal.py @@ -8,6 +8,8 @@ from models.base import CODE_ROOT +from tests.conftest import SKIP_WARNING_TESTS + # os.environ['INTERACTIVE'] = '1' # for diagnostics only @@ -66,12 +68,13 @@ def test_decam_photo_cal( decam_datastore, photometor, blocking_plots ): def test_warnings_and_exceptions(decam_datastore, photometor): - photometor.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + photometor.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - photometor.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'photo_cal'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + photometor.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'photo_cal'." in str(w.message) for w in record) photometor.pars.inject_warnings = 0 photometor.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 556233ca..221f7322 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -17,10 +17,10 @@ from models.measurements import Measurements from models.report import Report -from util.logger import SCLogger - from pipeline.top_level import Pipeline +from tests.conftest import SKIP_WARNING_TESTS + def check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session, ds): """ @@ -170,16 +170,6 @@ def test_parameters( test_config ): 'cutting': { 'cutout_size': 666 }, 'measuring': { 'outlier_sigma': 3.5 } } - pipelinemodule = { - 'preprocessing': 'preprocessor', - 'extraction': 'extractor', - 'astro_cal': 'astrometor', - 'photo_cal': 'photometor', - 'subtraction': 'subtractor', - 'detection': 'detector', - 'cutting': 'cutter', - 'measuring': 'measurer' - } def check_override( new_values_dict, pars ): for key, value in new_values_dict.items(): @@ -199,15 +189,30 @@ def check_override( new_values_dict, pars ): assert check_override(overrides['measuring'], pipeline.measurer.pars) -def test_data_flow(decam_exposure, decam_reference, decam_default_calibrators, archive): +def test_running_without_reference(decam_exposure, decam_refset, decam_default_calibrators, pipeline_for_tests): + p = pipeline_for_tests + p.subtractor.pars.refset = 'test_refset_decam' # pointing out this ref set doesn't mean we have an actual reference + p.pars.save_before_subtraction = True # need this so images get saved even though it crashes on "no reference" + + with pytest.raises(ValueError, match='Cannot find a reference image corresponding to.*'): + ds = p.run(decam_exposure, 'N1') + ds.reraise() + + # make sure the data is saved + with SmartSession() as session: + im = session.scalars(sa.select(Image).where(Image.id == ds.image.id)).first() + assert im is not None + + +def test_data_flow(decam_exposure, decam_reference, decam_default_calibrators, pipeline_for_tests, archive): """Test that the pipeline runs end-to-end.""" exposure = decam_exposure ref = decam_reference sec_id = ref.section_id try: # cleanup the file at the end - p = Pipeline() - p.pars.save_before_subtraction = False + p = pipeline_for_tests + p.subtractor.pars.refset = 'test_refset_decam' assert p.extractor.pars.threshold != 3.14 assert p.detector.pars.threshold != 3.14 @@ -274,7 +279,7 @@ def test_data_flow(decam_exposure, decam_reference, decam_default_calibrators, a shutil.rmtree(os.path.join(archive.test_folder_path, '115'), ignore_errors=True) -def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_calibrators, archive): +def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_calibrators, pipeline_for_tests, archive): """ Test that adding a bitflag to the exposure propagates to all downstreams as they are created Does not check measurements, as they do not have the HasBitflagBadness Mixin. @@ -285,6 +290,7 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali try: # cleanup the file at the end p = Pipeline() + p.subtractor.pars.refset = 'test_refset_decam' p.pars.save_before_subtraction = False exposure.badness = 'banding' # add a bitflag to check for propagation @@ -301,6 +307,8 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.sub_image._upstream_bitflag == 2 assert ds.detections._upstream_bitflag == 2 assert ds.cutouts._upstream_bitflag == 2 + for m in ds.measurements: + assert m._upstream_bitflag == 2 # test part 2: Add a second bitflag partway through and check it propagates to downstreams @@ -325,7 +333,7 @@ def test_bitflag_propagation(decam_exposure, decam_reference, decam_default_cali assert ds.cutouts._upstream_bitflag == desired_bitflag for m in ds.measurements: assert m._upstream_bitflag == desired_bitflag - assert ds.image.bitflag == 2 # not in the downstream of sources + assert ds.image.bitflag == 2 # not in the downstream of sources # test part 3: test update_downstream_badness() function by adding and removing flags # and observing propagation @@ -397,6 +405,7 @@ def test_get_upstreams_and_downstreams(decam_exposure, decam_reference, decam_de try: # cleanup the file at the end p = Pipeline() + p.subtractor.pars.refset = 'test_refset_decam' ds = p.run(exposure, sec_id) # commit to DB using this session @@ -506,8 +515,10 @@ def test_datastore_delete_everything(decam_datastore): ).first() is None -def test_provenance_tree(pipeline_for_tests, decam_exposure, decam_datastore, decam_reference): +def test_provenance_tree(pipeline_for_tests, decam_refset, decam_exposure, decam_datastore, decam_reference): p = pipeline_for_tests + p.subtractor.pars.refset = 'test_refset_decam' + provs = p.make_provenance_tree(decam_exposure) assert isinstance(provs, dict) @@ -538,80 +549,90 @@ def test_provenance_tree(pipeline_for_tests, decam_exposure, decam_datastore, de def test_inject_warnings_errors(decam_datastore, decam_reference, pipeline_for_tests): from pipeline.top_level import PROCESS_OBJECTS p = pipeline_for_tests - obj_to_process_name = { - 'preprocessor': 'preprocessing', - 'extractor': 'detection', - 'backgrounder': 'backgrounding', - 'astrometor': 'astro_cal', - 'photometor': 'photo_cal', - 'subtractor': 'subtraction', - 'detector': 'detection', - 'cutter': 'cutting', - 'measurer': 'measuring', - } - for process, objects in PROCESS_OBJECTS.items(): - if isinstance(objects, str): - objects = [objects] - elif isinstance(objects, dict): - objects = list(set(objects.values())) # e.g., "extractor", "astrometor", "photometor" - - # first reset all warnings and errors - for obj in objects: - for _, objects2 in PROCESS_OBJECTS.items(): - if isinstance(objects2, str): - objects2 = [objects2] - elif isinstance(objects2, dict): - objects2 = list(set(objects2.values())) # e.g., "extractor", "astrometor", "photometor" - for obj2 in objects2: - getattr(p, obj2).pars.inject_exceptions = False - getattr(p, obj2).pars.inject_warnings = False - - # set the warning: - getattr(p, obj).pars.inject_warnings = True - - # run the pipeline - ds = p.run(decam_datastore) - expected = (f"{process}: Warning injected by pipeline parameters " - f"in process '{obj_to_process_name[obj]}'") - assert expected in ds.report.warnings - - # these are used to find the report later on - exp_id = ds.exposure_id - sec_id = ds.section_id - prov_id = ds.report.provenance_id - - # set the error instead - getattr(p, obj).pars.inject_warnings = False - getattr(p, obj).pars.inject_exceptions = True - # run the pipeline again, this time with an exception - - with pytest.raises( - RuntimeError, - match=f"Exception injected by pipeline parameters in process '{obj_to_process_name[obj]}'" - ): - ds = p.run(decam_datastore) - - # fetch the report object - with SmartSession() as session: - reports = session.scalars( - sa.select(Report).where( - Report.exposure_id == exp_id, - Report.section_id == sec_id, - Report.provenance_id == prov_id - ).order_by(Report.start_time.desc()) - ).all() - report = reports[0] # the last report is the one we just generated - assert len(reports) - 1 == report.num_prev_reports - assert not report.success - assert report.error_step == process - assert report.error_type == 'RuntimeError' - assert 'Exception injected by pipeline parameters' in report.error_message + p.subtractor.pars.refset = 'test_refset_decam' + + try: + obj_to_process_name = { + 'preprocessor': 'preprocessing', + 'extractor': 'detection', + 'backgrounder': 'backgrounding', + 'astrometor': 'astro_cal', + 'photometor': 'photo_cal', + 'subtractor': 'subtraction', + 'detector': 'detection', + 'cutter': 'cutting', + 'measurer': 'measuring', + } + for process, objects in PROCESS_OBJECTS.items(): + if isinstance(objects, str): + objects = [objects] + elif isinstance(objects, dict): + objects = list(set(objects.values())) # e.g., "extractor", "astrometor", "photometor" + + # first reset all warnings and errors + for obj in objects: + for _, objects2 in PROCESS_OBJECTS.items(): + if isinstance(objects2, str): + objects2 = [objects2] + elif isinstance(objects2, dict): + objects2 = list(set(objects2.values())) # e.g., "extractor", "astrometor", "photometor" + for obj2 in objects2: + getattr(p, obj2).pars.inject_exceptions = False + getattr(p, obj2).pars.inject_warnings = False + + if not SKIP_WARNING_TESTS: + # set the warning: + getattr(p, obj).pars.inject_warnings = True + + # run the pipeline + ds = p.run(decam_datastore) + expected = (f"{process}: Warning injected by pipeline parameters " + f"in process '{obj_to_process_name[obj]}'") + assert expected in ds.report.warnings + + # these are used to find the report later on + exp_id = ds.exposure_id + sec_id = ds.section_id + prov_id = ds.report.provenance_id + + # set the error instead + getattr(p, obj).pars.inject_warnings = False + getattr(p, obj).pars.inject_exceptions = True + # run the pipeline again, this time with an exception + + with pytest.raises( + RuntimeError, + match=f"Exception injected by pipeline parameters in process '{obj_to_process_name[obj]}'" + ): + ds = p.run(decam_datastore) + ds.reraise() + + # fetch the report object + with SmartSession() as session: + reports = session.scalars( + sa.select(Report).where( + Report.exposure_id == exp_id, + Report.section_id == sec_id, + Report.provenance_id == prov_id + ).order_by(Report.start_time.desc()) + ).all() + report = reports[0] # the last report is the one we just generated + assert len(reports) - 1 == report.num_prev_reports + assert not report.success + assert report.error_step == process + assert report.error_type == 'RuntimeError' + assert 'Exception injected by pipeline parameters' in report.error_message + + finally: + if 'ds' in locals(): + ds.read_exception() + ds.delete_everything() def test_multiprocessing_make_provenances_and_exposure(decam_exposure, decam_reference, pipeline_for_tests): from multiprocessing import SimpleQueue, Process process_list = [] - + pipeline_for_tests.subtractor.pars.refset = 'test_refset_decam' def make_provenances(exposure, pipeline, queue): provs = pipeline.make_provenance_tree(exposure) queue.put(provs) diff --git a/tests/pipeline/test_preprocessing.py b/tests/pipeline/test_preprocessing.py index 0095042d..e956e90c 100644 --- a/tests/pipeline/test_preprocessing.py +++ b/tests/pipeline/test_preprocessing.py @@ -3,12 +3,13 @@ import uuid import numpy as np -import sqlalchemy as sa from astropy.io import fits from models.base import FileOnDiskMixin, SmartSession from models.image import Image +from tests.conftest import SKIP_WARNING_TESTS + def test_preprocessing( provenance_decam_prep, decam_exposure, test_config, preprocessor, decam_default_calibrators, archive @@ -91,12 +92,13 @@ def test_preprocessing( def test_warnings_and_exceptions(decam_exposure, preprocessor, decam_default_calibrators, archive): - preprocessor.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + preprocessor.pars.inject_warnings = 1 - with pytest.warns(UserWarning) as record: - preprocessor.run(decam_exposure, 'N1') - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'preprocessing'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + preprocessor.run(decam_exposure, 'N1') + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'preprocessing'." in str(w.message) for w in record) preprocessor.pars.inject_warnings = 0 preprocessor.pars.inject_exceptions = 1 diff --git a/tests/pipeline/test_reffinding.py b/tests/pipeline/test_reffinding.py deleted file mode 100644 index 4f4ce2c1..00000000 --- a/tests/pipeline/test_reffinding.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -import math -from models.image import Image -from pipeline.data_store import DataStore - -def test_data_store_overlap_frac(): - dra = 0.75 - ddec = 0.375 - radec1 = [ ( 10., -3. ), ( 10., -45. ) , (10., -80.) ] - - # TODO : add tests where things aren't perfectly square - for ra, dec in radec1: - cd = math.cos( dec * math.pi / 180. ) - i1 = Image( ra = ra, dec = dec, - ra_corner_00 = ra - dra/2. / cd, - ra_corner_01 = ra - dra/2. / cd, - ra_corner_10 = ra + dra/2. / cd, - ra_corner_11 = ra + dra/2. / cd, - dec_corner_00 = dec - ddec/2., - dec_corner_10 = dec - ddec/2., - dec_corner_01 = dec + ddec/2., - dec_corner_11 = dec + ddec/2. ) - for frac, offx, offy in [ ( 1. , 0. , 0. ), - ( 0.5 , 0.5, 0. ), - ( 0.5 , -0.5, 0. ), - ( 0.5 , 0. , 0.5 ), - ( 0.5 , 0. , -0.5 ), - ( 0.25, 0.5, 0.5 ), - ( 0.25, -0.5, 0.5 ), - ( 0.25, 0.5, -0.5 ), - ( 0.25, -0.5, -0.5 ), - ( 0., 1., 0. ), - ( 0., -1., 0. ), - ( 0., 1., 0. ), - ( 0., -1., 0. ), - ( 0., -1., -1. ), - ( 0., 1., -1. ) ]: - ra2 = ra + offx * dra / cd - dec2 = dec + offy * ddec - i2 = Image( ra = ra2, dec = dec2, - ra_corner_00 = ra2 - dra/2. / cd, - ra_corner_01 = ra2 - dra/2. / cd, - ra_corner_10 = ra2 + dra/2. / cd, - ra_corner_11 = ra2 + dra/2. / cd, - dec_corner_00 = dec2 - ddec/2., - dec_corner_10 = dec2 - ddec/2., - dec_corner_01 = dec2 + ddec/2., - dec_corner_11 = dec2 + ddec/2. ) - assert DataStore._overlap_frac( i1, i2 ) == pytest.approx( frac, abs=0.01 ) - - diff --git a/tests/pipeline/test_subtraction.py b/tests/pipeline/test_subtraction.py index 5998802c..7fdd1699 100644 --- a/tests/pipeline/test_subtraction.py +++ b/tests/pipeline/test_subtraction.py @@ -6,6 +6,8 @@ from improc.tools import sigma_clipping +from tests.conftest import SKIP_WARNING_TESTS + def test_subtraction_data_products(ptf_ref, ptf_supernova_images, subtractor): assert len(ptf_supernova_images) == 2 @@ -19,8 +21,10 @@ def test_subtraction_data_products(ptf_ref, ptf_supernova_images, subtractor): # run the subtraction like you'd do in the real pipeline (calls get_reference and get_subtraction internally) subtractor.pars.test_parameter = uuid.uuid4().hex subtractor.pars.method = 'naive' + subtractor.pars.refset = 'test_refset_ptf' assert subtractor.pars.alignment['to_index'] == 'new' # make sure alignment is configured to new, not latest image ds = subtractor.run(image1) + ds.reraise() # make sure there are no exceptions from run() # check that we don't lazy load a subtracted image, but recalculate it assert subtractor.has_recalculated @@ -50,8 +54,10 @@ def test_subtraction_ptf_zogy(ptf_ref, ptf_supernova_images, subtractor): # run the subtraction like you'd do in the real pipeline (calls get_reference and get_subtraction internally) subtractor.pars.test_parameter = uuid.uuid4().hex subtractor.pars.method = 'zogy' # this is the default, but it might not always be + subtractor.pars.refset = 'test_refset_ptf' assert subtractor.pars.alignment['to_index'] == 'new' # make sure alignment is configured to new, not latest image ds = subtractor.run(image1) + ds.reraise() # make sure there are no exceptions from run() assert ds.sub_image is not None assert ds.sub_image.data is not None @@ -77,12 +83,14 @@ def test_subtraction_ptf_zogy(ptf_ref, ptf_supernova_images, subtractor): def test_warnings_and_exceptions(decam_datastore, decam_reference, subtractor, decam_default_calibrators): - subtractor.pars.inject_warnings = 1 + if not SKIP_WARNING_TESTS: + subtractor.pars.inject_warnings = 1 + subtractor.pars.refset = 'test_refset_decam' - with pytest.warns(UserWarning) as record: - subtractor.run(decam_datastore) - assert len(record) > 0 - assert any("Warning injected by pipeline parameters in process 'subtraction'." in str(w.message) for w in record) + with pytest.warns(UserWarning) as record: + subtractor.run(decam_datastore) + assert len(record) > 0 + assert any("Warning injected by pipeline parameters in process 'subtraction'." in str(w.message) for w in record) subtractor.pars.inject_warnings = 0 subtractor.pars.inject_exceptions = 1 @@ -90,4 +98,4 @@ def test_warnings_and_exceptions(decam_datastore, decam_reference, subtractor, d ds = subtractor.run(decam_datastore) ds.reraise() assert "Exception injected by pipeline parameters in process 'subtraction'." in str(excinfo.value) - ds.read_exception() \ No newline at end of file + ds.read_exception() diff --git a/tests/util/test_radec.py b/tests/util/test_radec.py index edd1429a..65a314e6 100644 --- a/tests/util/test_radec.py +++ b/tests/util/test_radec.py @@ -2,6 +2,7 @@ from util import radec + def test_parse_sexigesimal_degrees(): deg = radec.parse_sexigesimal_degrees( '15:32:25' ) assert deg == pytest.approx( 15.54027778, abs=1e-8 ) @@ -26,6 +27,15 @@ def test_parse_sexigesimal_degrees(): deg = radec.parse_sexigesimal_degrees( '-00:30:00', hours=True, positive=False ) assert deg == -7.5 + # make sure it fails on bad inputs: + with pytest.raises( ValueError, match='Error parsing'): + radec.parse_sexigesimal_degrees( '12:30:36:00' ) + with pytest.raises( ValueError, match='Error parsing'): + radec.parse_sexigesimal_degrees( '12:30' ) + with pytest.raises( ValueError, match='Error parsing'): + radec.parse_sexigesimal_degrees( 'foobar' ) + + def test_radec_to_gal_and_eclip(): gal_l, gal_b, ecl_lon, ecl_lat = radec.radec_to_gal_and_eclip( 210.53, -32.3 ) assert gal_l == pytest.approx( 319.86357776, abs=1e-8 ) diff --git a/util/radec.py b/util/radec.py index 9a156772..5c96e6af 100644 --- a/util/radec.py +++ b/util/radec.py @@ -9,7 +9,7 @@ ' *(?P[0-9]{1,2}(\.[0-9]*)?) *$' ) -def parse_sexigesimal_degrees( strval, hours=False, **kwargs ): +def parse_sexigesimal_degrees( strval, hours=False, positive=None ): """Parse [+-]dd:mm::ss to decimal degrees in the range [0, 360) or (-180, 180] Parameters @@ -28,16 +28,12 @@ def parse_sexigesimal_degrees( strval, hours=False, **kwargs ): float, the value in degrees """ - - keys = list( kwargs.keys() ) - if ( keys != [ 'positive' ] ) and ( keys != [] ): - raise RuntimeError( f'parse_sexigesimal_degrees: unknown keyword arguments ' - f'{[ k for k in keys if k != "positive"]}' ) - positive = kwargs['positive'] if 'positive' in keys else hours + if positive is None: + positive = hours match = _radecparse.search( strval ) if match is None: - raise RuntimeError( f"Error parsing {strval} for [+-]dd:mm::ss" ) + raise ValueError( f"Error parsing {strval} for [+-]dd:mm::ss" ) val = float(match.group('d')) + float(match.group('m'))/60. + float(match.group('s'))/3600. val *= -1 if match.group('sign') == '-' else 1 val *= 15. if hours else 1. @@ -74,6 +70,7 @@ def radec_to_gal_and_eclip( ra, dec ): return ( gal_l, gal_b, ecl_lon, ecl_lat ) + def parse_ra_deg_to_hms(ra): """ Convert an RA in degrees to a string in sexagesimal format (in hh:mm:ss). @@ -83,6 +80,7 @@ def parse_ra_deg_to_hms(ra): ra /= 15.0 # convert to hours return f"{int(ra):02d}:{int((ra % 1) * 60):02d}:{((ra % 1) * 60) % 1 * 60:05.2f}" + def parse_dec_deg_to_dms(dec): """ Convert a Dec in degrees to a string in sexagesimal format (in dd:mm:ss). @@ -93,6 +91,7 @@ def parse_dec_deg_to_dms(dec): f"{int(dec):+03d}:{int((dec % 1) * 60):02d}:{((dec % 1) * 60) % 1 * 60:04.1f}" ) + def parse_ra_hms_to_deg(ra): """ Convert the input right ascension from sexagesimal string (hh:mm:ss format) into a float of decimal degrees. @@ -108,6 +107,7 @@ def parse_ra_hms_to_deg(ra): return ra + def parse_dec_dms_to_deg(dec): """ Convert the input declination from sexagesimal string (dd:mm:ss format) into a float of decimal degrees. diff --git a/util/util.py b/util/util.py index 37051f0a..b04c3e0f 100644 --- a/util/util.py +++ b/util/util.py @@ -1,11 +1,8 @@ -import pathlib import collections.abc -import sys import os import pathlib import git -from collections import defaultdict import numpy as np from datetime import datetime @@ -13,10 +10,8 @@ from astropy.io import fits from astropy.time import Time -from astropy import units as u -from astropy.coordinates import SkyCoord -from models.base import SmartSession, safe_mkdir +from models.base import safe_mkdir def ensure_file_does_not_exist( filepath, delete=False ): @@ -164,7 +159,7 @@ def parse_dateobs(dateobs=None, output='datetime'): The dateobs to parse. output: str Choose one of the output formats: - 'datetime', 'Time', 'float', 'str'. + 'datetime', 'Time', 'float', 'mjd', 'str'. Returns ------- @@ -191,7 +186,7 @@ def parse_dateobs(dateobs=None, output='datetime'): return dateobs.datetime elif output == 'Time': return dateobs - elif output == 'float': + elif output in ['float', 'mjd']: return dateobs.mjd elif output == 'str': return dateobs.isot @@ -390,9 +385,16 @@ def parse_bool(text): """Check if a string of text that represents a boolean value is True or False.""" if text is None: return False + if isinstance(text, bool): + return text elif text.lower() in ['true', 'yes', '1']: return True elif text.lower() in ['false', 'no', '0']: return False else: raise ValueError(f'Cannot parse boolean value from "{text}"') + + +def env_as_bool(varname): + """Parse an environmental variable as a boolean.""" + return parse_bool(os.getenv(varname))