diff --git a/alembic/versions/2023_05_31_1639-4114e36a2555_image_model.py b/alembic/versions/2023_05_31_1639-4114e36a2555_image_model.py index da955e90..419fa7fe 100644 --- a/alembic/versions/2023_05_31_1639-4114e36a2555_image_model.py +++ b/alembic/versions/2023_05_31_1639-4114e36a2555_image_model.py @@ -78,7 +78,7 @@ def upgrade() -> None: op.create_index(op.f('ix_images_telescope'), 'images', ['telescope'], unique=False) op.create_index(op.f('ix_images_type'), 'images', ['type'], unique=False) op.create_foreign_key(None, 'images', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') - op.create_foreign_key(None, 'images', 'exposures', ['exposure_id'], ['id']) + op.create_foreign_key(None, 'images', 'exposures', ['exposure_id'], ['id'], ondelete='SET NULL') # ### end Alembic commands ### diff --git a/alembic/versions/2023_06_27_1350-b90b1e3ec58c_reference_table.py b/alembic/versions/2023_06_27_1350-b90b1e3ec58c_reference_table.py new file mode 100644 index 00000000..b532420e --- /dev/null +++ b/alembic/versions/2023_06_27_1350-b90b1e3ec58c_reference_table.py @@ -0,0 +1,213 @@ +"""reference table + +Revision ID: b90b1e3ec58c +Revises: 4114e36a2555 +Create Date: 2023-06-27 13:50:00.391100 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b90b1e3ec58c' +down_revision = '4114e36a2555' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('reference_images', + sa.Column('image_id', sa.BigInteger(), nullable=False), + sa.Column('target', sa.Text(), nullable=False), + sa.Column('filter', sa.Text(), nullable=False), + sa.Column('section_id', sa.Text(), nullable=False), + sa.Column('validity_start', sa.DateTime(), nullable=False), + sa.Column('validity_end', sa.DateTime(), nullable=False), + sa.Column('is_bad', sa.Boolean(), nullable=False), + sa.Column('bad_reason', sa.Text(), nullable=True), + sa.Column('bad_comment', sa.Text(), nullable=True), + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('modified', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['image_id'], ['images.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_reference_images_created_at'), 'reference_images', ['created_at'], unique=False) + op.create_index(op.f('ix_reference_images_filter'), 'reference_images', ['filter'], unique=False) + op.create_index(op.f('ix_reference_images_id'), 'reference_images', ['id'], unique=False) + op.create_index(op.f('ix_reference_images_image_id'), 'reference_images', ['image_id'], unique=False) + op.create_index(op.f('ix_reference_images_section_id'), 'reference_images', ['section_id'], unique=False) + op.create_index(op.f('ix_reference_images_target'), 'reference_images', ['target'], unique=False) + op.create_index(op.f('ix_reference_images_validity_end'), 'reference_images', ['validity_end'], unique=False) + op.create_index(op.f('ix_reference_images_validity_start'), 'reference_images', ['validity_start'], unique=False) + op.add_column('cutouts', sa.Column('source_list_id', sa.BigInteger(), nullable=False)) + op.add_column('cutouts', sa.Column('new_image_id', sa.BigInteger(), nullable=False)) + op.add_column('cutouts', sa.Column('ref_image_id', sa.BigInteger(), nullable=False)) + op.add_column('cutouts', sa.Column('sub_image_id', 
sa.BigInteger(), nullable=False)) + op.add_column('cutouts', sa.Column('pixel_x', sa.Integer(), nullable=False)) + op.add_column('cutouts', sa.Column('pixel_y', sa.Integer(), nullable=False)) + op.add_column('cutouts', sa.Column('provenance_id', sa.BigInteger(), nullable=False)) + op.add_column('cutouts', sa.Column('filepath', sa.Text(), nullable=False)) + op.add_column('cutouts', sa.Column('filepath_extensions', sa.ARRAY(sa.Text()), nullable=True)) + op.add_column('cutouts', sa.Column('format', sa.Enum('fits', 'hdf5', name='image_format'), nullable=False)) + op.add_column('cutouts', sa.Column('ra', sa.Double(), nullable=False)) + op.add_column('cutouts', sa.Column('dec', sa.Double(), nullable=False)) + op.add_column('cutouts', sa.Column('gallat', sa.Double(), nullable=True)) + op.add_column('cutouts', sa.Column('gallon', sa.Double(), nullable=True)) + op.add_column('cutouts', sa.Column('ecllat', sa.Double(), nullable=True)) + op.add_column('cutouts', sa.Column('ecllon', sa.Double(), nullable=True)) + op.create_index('cutouts_q3c_ang2ipix_idx', 'cutouts', [sa.text('q3c_ang2ipix(ra, dec)')], unique=False) + op.create_index(op.f('ix_cutouts_ecllat'), 'cutouts', ['ecllat'], unique=False) + op.create_index(op.f('ix_cutouts_filepath'), 'cutouts', ['filepath'], unique=True) + op.create_index(op.f('ix_cutouts_gallat'), 'cutouts', ['gallat'], unique=False) + op.create_index(op.f('ix_cutouts_new_image_id'), 'cutouts', ['new_image_id'], unique=False) + op.create_index(op.f('ix_cutouts_provenance_id'), 'cutouts', ['provenance_id'], unique=False) + op.create_index(op.f('ix_cutouts_ref_image_id'), 'cutouts', ['ref_image_id'], unique=False) + op.create_index(op.f('ix_cutouts_source_list_id'), 'cutouts', ['source_list_id'], unique=False) + op.create_index(op.f('ix_cutouts_sub_image_id'), 'cutouts', ['sub_image_id'], unique=False) + op.create_foreign_key(None, 'cutouts', 'images', ['ref_image_id'], ['id']) + op.create_foreign_key(None, 'cutouts', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'cutouts', 'images', ['new_image_id'], ['id']) + op.create_foreign_key(None, 'cutouts', 'images', ['sub_image_id'], ['id']) + op.create_foreign_key(None, 'cutouts', 'source_lists', ['source_list_id'], ['id']) + op.drop_index('exposure_q3c_ang2ipix_idx', table_name='exposures') + op.add_column('images', sa.Column('ref_image_id', sa.BigInteger(), nullable=True)) + op.add_column('images', sa.Column('new_image_id', sa.BigInteger(), nullable=True)) + op.drop_index('ix_images_combine_method', table_name='images') + op.create_index(op.f('ix_images_new_image_id'), 'images', ['new_image_id'], unique=False) + op.create_index(op.f('ix_images_ref_image_id'), 'images', ['ref_image_id'], unique=False) + op.create_foreign_key(None, 'images', 'images', ['new_image_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'images', 'images', ['ref_image_id'], ['id'], ondelete='CASCADE') + op.drop_column('images', 'combine_method') + op.add_column('measurements', sa.Column('cutouts_id', sa.BigInteger(), nullable=False)) + op.add_column('measurements', sa.Column('provenance_id', sa.BigInteger(), nullable=False)) + op.add_column('measurements', sa.Column('ra', sa.Double(), nullable=False)) + op.add_column('measurements', sa.Column('dec', sa.Double(), nullable=False)) + op.add_column('measurements', sa.Column('gallat', sa.Double(), nullable=True)) + op.add_column('measurements', sa.Column('gallon', sa.Double(), nullable=True)) + op.add_column('measurements', sa.Column('ecllat', 
sa.Double(), nullable=True)) + op.add_column('measurements', sa.Column('ecllon', sa.Double(), nullable=True)) + op.create_index(op.f('ix_measurements_cutouts_id'), 'measurements', ['cutouts_id'], unique=False) + op.create_index(op.f('ix_measurements_ecllat'), 'measurements', ['ecllat'], unique=False) + op.create_index(op.f('ix_measurements_gallat'), 'measurements', ['gallat'], unique=False) + op.create_index(op.f('ix_measurements_provenance_id'), 'measurements', ['provenance_id'], unique=False) + op.create_index('measurements_q3c_ang2ipix_idx', 'measurements', [sa.text('q3c_ang2ipix(ra, dec)')], unique=False) + op.create_foreign_key(None, 'measurements', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'measurements', 'cutouts', ['cutouts_id'], ['id']) + op.add_column('source_lists', sa.Column('image_id', sa.BigInteger(), nullable=False)) + op.add_column('source_lists', sa.Column('is_sub', sa.Boolean(), nullable=False)) + op.add_column('source_lists', sa.Column('provenance_id', sa.BigInteger(), nullable=False)) + op.add_column('source_lists', sa.Column('filepath', sa.Text(), nullable=False)) + op.add_column('source_lists', sa.Column('filepath_extensions', sa.ARRAY(sa.Text()), nullable=True)) + op.add_column('source_lists', sa.Column('format', sa.Enum('fits', 'hdf5', name='image_format'), nullable=False)) + op.create_index(op.f('ix_source_lists_filepath'), 'source_lists', ['filepath'], unique=True) + op.create_index(op.f('ix_source_lists_image_id'), 'source_lists', ['image_id'], unique=False) + op.create_index(op.f('ix_source_lists_provenance_id'), 'source_lists', ['provenance_id'], unique=False) + op.create_foreign_key(None, 'source_lists', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'source_lists', 'images', ['image_id'], ['id']) + op.add_column('world_coordinates', sa.Column('source_list_id', sa.BigInteger(), nullable=False)) + op.add_column('world_coordinates', sa.Column('provenance_id', sa.BigInteger(), nullable=False)) + op.create_index(op.f('ix_world_coordinates_provenance_id'), 'world_coordinates', ['provenance_id'], unique=False) + op.create_index(op.f('ix_world_coordinates_source_list_id'), 'world_coordinates', ['source_list_id'], unique=False) + op.create_foreign_key(None, 'world_coordinates', 'source_lists', ['source_list_id'], ['id']) + op.create_foreign_key(None, 'world_coordinates', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') + op.add_column('zero_points', sa.Column('source_list_id', sa.BigInteger(), nullable=False)) + op.add_column('zero_points', sa.Column('provenance_id', sa.BigInteger(), nullable=False)) + op.create_index(op.f('ix_zero_points_provenance_id'), 'zero_points', ['provenance_id'], unique=False) + op.create_index(op.f('ix_zero_points_source_list_id'), 'zero_points', ['source_list_id'], unique=False) + op.create_foreign_key(None, 'zero_points', 'provenances', ['provenance_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'zero_points', 'source_lists', ['source_list_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, 'zero_points', type_='foreignkey') + op.drop_constraint(None, 'zero_points', type_='foreignkey') + op.drop_index(op.f('ix_zero_points_source_list_id'), table_name='zero_points') + op.drop_index(op.f('ix_zero_points_provenance_id'), table_name='zero_points') + op.drop_column('zero_points', 'provenance_id') + op.drop_column('zero_points', 'source_list_id') + op.drop_constraint(None, 'world_coordinates', type_='foreignkey') + op.drop_constraint(None, 'world_coordinates', type_='foreignkey') + op.drop_index(op.f('ix_world_coordinates_source_list_id'), table_name='world_coordinates') + op.drop_index(op.f('ix_world_coordinates_provenance_id'), table_name='world_coordinates') + op.drop_column('world_coordinates', 'provenance_id') + op.drop_column('world_coordinates', 'source_list_id') + op.drop_constraint(None, 'source_lists', type_='foreignkey') + op.drop_constraint(None, 'source_lists', type_='foreignkey') + op.drop_index(op.f('ix_source_lists_provenance_id'), table_name='source_lists') + op.drop_index(op.f('ix_source_lists_image_id'), table_name='source_lists') + op.drop_index(op.f('ix_source_lists_filepath'), table_name='source_lists') + op.drop_column('source_lists', 'format') + op.drop_column('source_lists', 'filepath_extensions') + op.drop_column('source_lists', 'filepath') + op.drop_column('source_lists', 'provenance_id') + op.drop_column('source_lists', 'is_sub') + op.drop_column('source_lists', 'image_id') + op.drop_constraint(None, 'measurements', type_='foreignkey') + op.drop_constraint(None, 'measurements', type_='foreignkey') + op.drop_index('measurements_q3c_ang2ipix_idx', table_name='measurements') + op.drop_index(op.f('ix_measurements_provenance_id'), table_name='measurements') + op.drop_index(op.f('ix_measurements_gallat'), table_name='measurements') + op.drop_index(op.f('ix_measurements_ecllat'), table_name='measurements') + op.drop_index(op.f('ix_measurements_cutouts_id'), table_name='measurements') + op.drop_column('measurements', 'ecllon') + op.drop_column('measurements', 'ecllat') + op.drop_column('measurements', 'gallon') + op.drop_column('measurements', 'gallat') + op.drop_column('measurements', 'dec') + op.drop_column('measurements', 'ra') + op.drop_column('measurements', 'provenance_id') + op.drop_column('measurements', 'cutouts_id') + op.add_column('images', sa.Column('combine_method', postgresql.ENUM('coadd', 'subtraction', name='image_combine_method'), autoincrement=False, nullable=True)) + op.drop_constraint(None, 'images', type_='foreignkey') + op.drop_constraint(None, 'images', type_='foreignkey') + op.drop_index(op.f('ix_images_ref_image_id'), table_name='images') + op.drop_index(op.f('ix_images_new_image_id'), table_name='images') + op.create_index('ix_images_combine_method', 'images', ['combine_method'], unique=False) + op.drop_column('images', 'new_image_id') + op.drop_column('images', 'ref_image_id') + op.create_index('exposure_q3c_ang2ipix_idx', 'exposures', [sa.text('q3c_ang2ipix(ra, "dec")')], unique=False) + op.drop_constraint(None, 'cutouts', type_='foreignkey') + op.drop_constraint(None, 'cutouts', type_='foreignkey') + op.drop_constraint(None, 'cutouts', type_='foreignkey') + op.drop_constraint(None, 'cutouts', type_='foreignkey') + op.drop_constraint(None, 'cutouts', type_='foreignkey') + op.drop_index(op.f('ix_cutouts_sub_image_id'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_source_list_id'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_ref_image_id'), table_name='cutouts') + 
op.drop_index(op.f('ix_cutouts_provenance_id'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_new_image_id'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_gallat'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_filepath'), table_name='cutouts') + op.drop_index(op.f('ix_cutouts_ecllat'), table_name='cutouts') + op.drop_index('cutouts_q3c_ang2ipix_idx', table_name='cutouts') + op.drop_column('cutouts', 'ecllon') + op.drop_column('cutouts', 'ecllat') + op.drop_column('cutouts', 'gallon') + op.drop_column('cutouts', 'gallat') + op.drop_column('cutouts', 'dec') + op.drop_column('cutouts', 'ra') + op.drop_column('cutouts', 'format') + op.drop_column('cutouts', 'filepath_extensions') + op.drop_column('cutouts', 'filepath') + op.drop_column('cutouts', 'provenance_id') + op.drop_column('cutouts', 'pixel_y') + op.drop_column('cutouts', 'pixel_x') + op.drop_column('cutouts', 'sub_image_id') + op.drop_column('cutouts', 'ref_image_id') + op.drop_column('cutouts', 'new_image_id') + op.drop_column('cutouts', 'source_list_id') + op.drop_index(op.f('ix_reference_images_validity_start'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_validity_end'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_target'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_section_id'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_image_id'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_id'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_filter'), table_name='reference_images') + op.drop_index(op.f('ix_reference_images_created_at'), table_name='reference_images') + op.drop_table('reference_images') + # ### end Alembic commands ### diff --git a/default_config.yaml b/default_config.yaml index 6c1c04ff..17462523 100644 --- a/default_config.yaml +++ b/default_config.yaml @@ -26,6 +26,7 @@ storage: # so you can use e.g., {ra_int:03d} to get a 3 digit zero padded right ascension. # The name convention can also include subfolders (e.g., using {ra_int}/...). # The minimal set of fields to make the filenames unique include: - # short_name (instrument name), date, time, section_id, prov_id (the unique provenance ID) - name_convention: "{ra_int:03d}/{short_name}_{date}_{time}_{section_id:02d}_{filter}_{prov_id:03d}" + # short_name (instrument name), date, time, section_id, prov_hash + # (in this example, the first six characters of the provenance unique hash) + name_convention: "{ra_int:03d}/{short_name}_{date}_{time}_{section_id}_{filter}_{prov_hash:.6s}" diff --git a/models/base.py b/models/base.py index 84d2a20e..1cd97a17 100644 --- a/models/base.py +++ b/models/base.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm.exc import DetachedInstanceError import util.config as config @@ -48,7 +49,7 @@ def Session(): f'@{cfg.value("db.host")}:{cfg.value("db.port")}/{cfg.value("db.database")}') _engine = sa.create_engine(url, future=True, poolclass=sa.pool.NullPool) - _Session = sessionmaker(bind=_engine, expire_on_commit=True) + _Session = sessionmaker(bind=_engine, expire_on_commit=False) session = _Session() @@ -84,6 +85,38 @@ def SmartSession(input_session=None): ) +def safe_merge(session, obj): + """ + Only merge the object if it has a valid ID, + and if it does not exist on the session. + Otherwise, return the object itself. 
+ + Parameters + ---------- + session: sqlalchemy.orm.session.Session + The session to use for the merge. + obj: SeeChangeBase + The object to merge. + + Returns + ------- + obj: SeeChangeBase + The merged object, or the unmerged object + if it is already on the session or if it + doesn't have an ID. + """ + if obj is None: + return None + + if obj.id is None: + return obj + + if obj in session: + return obj + + return session.merge(obj) + + class SeeChangeBase: """Base class for all SeeChange classes.""" @@ -139,6 +172,52 @@ def get_attribute_list(self): return attrs + def recursive_merge(self, session, done_list=None): + """ + Recursively merge (using safe_merge) all the objects, + the parent objects (image, ref_image, new_image, etc.) + and the provenances of all of these, into the given session. + + Parameters + ---------- + session: sqlalchemy.orm.session.Session + The session to use for the merge. + done_list: list (optional) + A list of objects that have already been merged. + + Returns + ------- + SeeChangeBase + The merged object. + """ + if done_list is None: + done_list = set() + + if self in done_list: + return self + + obj = safe_merge(session, self) + done_list.add(obj) + + # only do the sub-properties if the object was already added to the session + attributes = ['provenance', 'exposure', 'image', 'ref_image', 'new_image', 'sub_image', 'source_list'] + + # recursively call this on the provenance and other parent objects + for att in attributes: + try: + sub_obj = getattr(self, att, None) + # go over lists: + if isinstance(sub_obj, list): + setattr(obj, att, [o.recursive_merge(session, done_list=done_list) for o in sub_obj]) + + if isinstance(sub_obj, SeeChangeBase): + setattr(obj, att, sub_obj.recursive_merge(session, done_list=done_list)) + + except DetachedInstanceError: + pass + + return obj + Base = declarative_base(cls=SeeChangeBase) @@ -329,7 +408,7 @@ def _validate_filepath(self, filepath): return filepath - def get_fullpath(self, download=True, as_list=False): + def get_fullpath(self, download=True, as_list=False, nofile=None): """ Get the full path of the file, or list of full paths of files if filepath_extensions is not None. @@ -360,6 +439,9 @@ def get_fullpath(self, download=True, as_list=False): as_list: bool Whether to return a list of filepaths, even if filepath_extensions=None. Default is False. + nofile: bool + Whether to check if the file exists on local disk. + Default is None, which means use the value of self.nofile. Returns ------- @@ -368,13 +450,16 @@ def get_fullpath(self, download=True, as_list=False): """ if self.filepath_extensions is None: if as_list: - return [self._get_fullpath_single(download)] + return [self._get_fullpath_single(download=download, nofile=nofile)] else: - return self._get_fullpath_single(download) + return self._get_fullpath_single(download=download, nofile=nofile) else: - return [self._get_fullpath_single(download, ext) for ext in self.filepath_extensions] + return [ + self._get_fullpath_single(download=download, ext=ext, nofile=nofile) + for ext in self.filepath_extensions + ] - def _get_fullpath_single(self, download=True, ext=None): + def _get_fullpath_single(self, download=True, ext=None, nofile=None): """ Get the full path of a single file. Will follow the same logic as get_fullpath(), @@ -388,13 +473,17 @@ def _get_fullpath_single(self, download=True, ext=None): Must have server_path defined. Default is True. ext: str Extension to add to the filepath. Default is None. 
- + nofile: bool + Whether to check if the file exists on local disk. + Default is None, which means use the value of self.nofile. Returns ------- str Full path to the file on local disk. """ - if not self.nofile and self.local_path is None: + if nofile is None: + nofile = self.nofile + if not nofile and self.local_path is None: raise ValueError("Local path not defined!") fname = self.filepath @@ -402,10 +491,10 @@ def _get_fullpath_single(self, download=True, ext=None): fname += ext fullname = os.path.join(self.local_path, fname) - if not self.nofile and not os.path.exists(fullname) and download and self.server_path is not None: + if not nofile and not os.path.exists(fullname) and download and self.server_path is not None: self._download_file(fname) - if not self.nofile and not os.path.exists(fullname): + if not nofile and not os.path.exists(fullname): raise FileNotFoundError(f"File {fullname} not found!") return fullname @@ -438,7 +527,8 @@ def remove_data_from_disk(self, remove_folders=True): """ if self.filepath is None: return - for f in self.get_fullpath(as_list=True): + # get the filepath, but don't check if the file exists! + for f in self.get_fullpath(as_list=True, nofile=True): if os.path.exists(f): os.remove(f) if remove_folders: @@ -459,6 +549,7 @@ class SpatiallyIndexed: """A mixin for tables that have ra and dec fields indexed via q3c.""" ra = sa.Column(sa.Double, nullable=False, doc='Right ascension in degrees') + dec = sa.Column(sa.Double, nullable=False, doc='Declination in degrees') gallat = sa.Column(sa.Double, index=True, doc="Galactic latitude of the target. ") diff --git a/models/cutouts.py b/models/cutouts.py index a4ebc262..b718a34a 100644 --- a/models/cutouts.py +++ b/models/cutouts.py @@ -1,6 +1,96 @@ -from models.base import Base +import sqlalchemy as sa +from sqlalchemy import orm +from models.base import Base, FileOnDiskMixin, SpatiallyIndexed + + +class Cutouts(Base, FileOnDiskMixin, SpatiallyIndexed): + + __tablename__ = 'cutouts' + + source_list_id = sa.Column( + sa.ForeignKey('source_lists.id'), + nullable=False, + index=True, + doc="ID of the source list this cutout is associated with. " + ) + + source_list = orm.relationship( + 'SourceList', + doc="The source list this cutout is associated with. " + ) + + new_image_id = sa.Column( + sa.ForeignKey('images.id'), + nullable=False, + index=True, + doc="ID of the new science image this cutout is associated with. " + ) + + new_image = orm.relationship( + 'Image', + primaryjoin="Cutouts.new_image_id==Image.id", + doc="The new science image this cutout is associated with. " + ) + + ref_image_id = sa.Column( + sa.ForeignKey('images.id'), + nullable=False, + index=True, + doc="ID of the reference image this cutout is associated with. " + ) + + ref_image = orm.relationship( + 'Image', + primaryjoin="Cutouts.ref_image_id==Image.id", + doc="The reference image this cutout is associated with. " + ) + + sub_image_id = sa.Column( + sa.ForeignKey('images.id'), + nullable=False, + index=True, + doc="ID of the subtraction image this cutout is associated with. " + ) + + sub_image = orm.relationship( + 'Image', + primaryjoin="Cutouts.sub_image_id==Image.id", + doc="The subtraction image this cutout is associated with. " + ) + + pixel_x = sa.Column( + sa.Integer, + nullable=False, + doc="X pixel coordinate of the center of the cutout. " + ) + + pixel_y = sa.Column( + sa.Integer, + nullable=False, + doc="Y pixel coordinate of the center of the cutout. 
" + ) + + provenance_id = sa.Column( + sa.ForeignKey('provenances.id', ondelete="CASCADE"), + nullable=False, + index=True, + doc=( + "ID of the provenance of this cutout. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this cutout. " + ) + ) + + provenance = orm.relationship( + 'Provenance', + cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', + doc=( + "Provenance of this cutout. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this cutout. " + ) + ) -class Cutouts(Base): - __tablename__ = 'cutouts' \ No newline at end of file diff --git a/models/image.py b/models/image.py index a331ba2d..1164fa5c 100644 --- a/models/image.py +++ b/models/image.py @@ -3,6 +3,7 @@ from sqlalchemy import orm from sqlalchemy.types import Enum from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm.exc import DetachedInstanceError from astropy.time import Time from astropy.wcs import WCS @@ -30,7 +31,7 @@ class Image(Base, FileOnDiskMixin, SpatiallyIndexed): __tablename__ = 'images' exposure_id = sa.Column( - sa.ForeignKey('exposures.id'), + sa.ForeignKey('exposures.id', ondelete='SET NULL'), nullable=True, index=True, doc=( @@ -62,26 +63,72 @@ class Image(Base, FileOnDiskMixin, SpatiallyIndexed): ) ) + ref_image_id = sa.Column( + sa.ForeignKey('images.id', ondelete="CASCADE"), + nullable=True, + index=True, + doc=( + "ID of the reference image used to produce a difference image. " + "Only set for difference images. This usually refers to a coadd image. " + ) + ) - @property - def is_multi_image(self): - if self.exposure is not None: - return False - elif self.source_images is not None and len(self.source_images) > 0: - return True - else: - return None # for new objects that have not defined either exposure or source_images + ref_image = orm.relationship( + 'Image', + cascade='save-update, merge, refresh-expire, expunge', + primaryjoin='images.c.ref_image_id == images.c.id', + uselist=False, + remote_side='Image.id', + doc=( + "Reference image used to produce a difference image. " + "Only set for difference images. This usually refers to a coadd image. " + ) + ) - combine_method = sa.Column( - Enum("coadd", "subtraction", name='image_combine_method'), + new_image_id = sa.Column( + sa.ForeignKey('images.id', ondelete="CASCADE"), nullable=True, index=True, doc=( - "Type of combination used to produce this multi-image object. " - "One of: coadd, subtraction. " + "ID of the new image used to produce a difference image. " + "Only set for difference images. This usually refers to a regular image. " + ) + ) + + new_image = orm.relationship( + 'Image', + cascade='save-update, merge, refresh-expire, expunge', + primaryjoin='images.c.new_image_id == images.c.id', + uselist=False, + remote_side='Image.id', + doc=( + "New image used to produce a difference image. " + "Only set for difference images. This usually refers to a regular image. 
" ) ) + @property + def is_coadd(self): + try: + if self.source_images is not None and len(self.source_images) > 0: + return True + except DetachedInstanceError: + if not self.is_sub and self.exposure_id is None: + return True + + return False + + @property + def is_sub(self): + try: + if self.ref_image is not None and self.new_image is not None: + return True + except DetachedInstanceError: + if self.ref_image_id is not None and self.new_image_id is not None: + return True + + return False + type = sa.Column( im_type_enum, # defined in models/exposure.py nullable=False, @@ -106,6 +153,7 @@ def is_multi_image(self): provenance = orm.relationship( 'Provenance', cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', doc=( "Provenance of this image. " "The provenance will contain a record of the code version" @@ -134,6 +182,14 @@ def is_multi_image(self): ) ) + @property + def observation_time(self): + """Translation of the MJD column to datetime object.""" + if self.mjd is None: + return None + else: + return Time(self.mjd, format='mjd').datetime + @property def start_mjd(self): """Time of the beginning of the exposure, or set of exposures (equal to mjd). """ @@ -332,6 +388,109 @@ def from_exposure(cls, exposure, section_id): return new + @classmethod + def from_images(cls, images): + """ + Create a new Image object from a list of other Image objects. + This is the first step in making a multi-image (usually a coadd). + The output image doesn't have any data, and is created with + nofile=True. It is up to the calling application to fill in the + data, flags, weight, etc. using the appropriate preprocessing tools. + After that, the data needs to be saved to file, and only then + can the new Image be added to the database. + + Parameters + ---------- + images: list of Image objects + The images to combine into a new Image object. + + Returns + ------- + output: Image + The new Image object. It would not have any data variables or filepath. + """ + if len(images) < 1: + raise ValueError("Must provide at least one image to combine.") + + output = Image(nofile=True) + + # for each attribute, check that all the images have the same value + for att in ['section_id', 'instrument', 'telescope', 'type', 'filter', 'project', 'target']: + values = set([getattr(image, att) for image in images]) + if len(values) != 1: + raise ValueError(f"Cannot combine images with different {att} values: {values}") + output.__setattr__(att, values.pop()) + # TODO: should RA and Dec also be exactly the same?? + output.ra = images[0].ra + output.dec = images[0].dec + + # exposure time is usually added together + output.exp_time = sum([image.exp_time for image in images]) + + # start MJD and end MJD + output.mjd = min([image.mjd for image in images]) + output.end_mjd = max([image.end_mjd for image in images]) + + # TODO: what about the header? should we combine them somehow? + output.header = images[0].header + output.raw_header = images[0].raw_header + + output.source_images = images + + # Note that "data" is not filled by this method, also the provenance is empty! + return output + + @classmethod + def from_ref_and_new(cls, ref, new): + """ + Create a new Image object from a reference Image object and a new Image object. + This is the first step in making a difference image. + The output image doesn't have any data, and is created with + nofile=True. It is up to the calling application to fill in the + data, flags, weight, etc. using the appropriate preprocessing tools. 
+ After that, the data needs to be saved to file, and only then + can the new Image be added to the database. + + Parameters + ---------- + ref: Image object + The reference image to use. + new: Image object + The new image to use. + + Returns + ------- + output: Image + The new Image object. It would not have any data variables or filepath. + """ + output = Image(nofile=True) + + # for each attribute, check the two images have the same value + for att in ['section_id', 'instrument', 'telescope', 'type', 'filter', 'project', 'target']: + ref_value = getattr(ref, att) + new_value = getattr(new, att) + + if att == 'section_id': + ref_value = str(ref_value) + new_value = str(new_value) + if ref_value != new_value: + raise ValueError( + f"Cannot combine images with different {att} values: " + f"{ref_value} and {new_value}. " + ) + output.__setattr__(att, new_value) + # TODO: should RA and Dec also be exactly the same?? + + # get some more attributes from the new image + for att in ['exp_time', 'mjd', 'end_mjd', 'header', 'raw_header', 'ra', 'dec']: + output.__setattr__(att, getattr(new, att)) + + output.ref_image = ref + output.new_image = new + + # Note that "data" is not filled by this method, also the provenance is empty! + return output + @property def instrument_object(self): if self.instrument is not None: @@ -346,19 +505,21 @@ def instrument_object(self, value): def __repr__(self): + type_str = self.type + if self.is_coadd: + type_str += " (coadd)" + + if self.is_sub: + type_str += " (sub)" + output = ( f"Image(id: {self.id}, " - f"type: {self.type}, " + f"type: {type_str}, " f"exp: {self.exp_time}s, " f"filt: {self.filter}, " f"from: {self.instrument}/{self.telescope}" ) - multi_type = str(self.combine_method) if self.is_multi_image else None - - if multi_type is not None: - output += f", multi_type: {multi_type}" - output += ")" return output @@ -375,6 +536,9 @@ def invent_filename(self): # in which case we will need to parse it somehow, e.g., using some blocks # like , , etc. + if self.provenance is None: + raise ValueError("Cannot invent filename for image without provenance.") + t = Time(self.mjd, format='mjd', scale='utc').datetime short_name = self.instrument_object.get_short_instrument_name() @@ -395,12 +559,13 @@ def invent_filename(self): dec_frac = int(dec_frac) section_id = self.section_id - prov_id = self.provenance_id + prov_hash = self.provenance.unique_hash - default_convention = "{short_name}_{date}_{time}_{section_id:02d}_{filter}_{prov_id:03d}" + default_convention = "{short_name}_{date}_{time}_{section_id}_{filter}_{prov_hash:.6s}" cfg = config.Config.get() name_convention = cfg.value('storage.images.name_convention', default=None) + if name_convention is None: name_convention = default_convention @@ -418,7 +583,7 @@ def invent_filename(self): dec_int_pm=dec_int_pm, dec_frac=dec_frac, section_id=section_id, - prov_id=prov_id, + prov_hash=prov_hash, ) return filename @@ -438,8 +603,8 @@ def save(self, filename=None): if self.data is None: raise RuntimeError("The image data is not loaded. Cannot save.") - if self.provenance_id is None: - raise RuntimeError("The image provenance_id is not set. Cannot save.") + if self.provenance is None: + raise RuntimeError("The image provenance is not set. 
Cannot save.") if filename is None: filename = self.invent_filename() diff --git a/models/measurements.py b/models/measurements.py index 57d66ec0..7f60670e 100644 --- a/models/measurements.py +++ b/models/measurements.py @@ -1,6 +1,43 @@ -from models.base import Base +import sqlalchemy as sa +from sqlalchemy import orm +from models.base import Base, SpatiallyIndexed + + +class Measurements(Base, SpatiallyIndexed): + + __tablename__ = 'measurements' + + cutouts_id = sa.Column( + sa.ForeignKey('cutouts.id'), + nullable=False, + index=True, + doc="ID of the cutout this measurement is associated with. " + ) + + cutouts = orm.relationship( + 'Cutouts', + doc="The cutout this measurement is associated with. " + ) + + provenance_id = sa.Column( + sa.ForeignKey('provenances.id', ondelete="CASCADE"), + nullable=False, + index=True, + doc="ID of the provenance of this measurement. " + ) + + provenance = orm.relationship( + 'Provenance', + cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', + doc="The provenance of this measurement. " + ) + + # TODO: we need to decide what columns are actually saved. + # E.g., should we save a single flux or an array/JSONB of fluxes? + # Same thing for scores (e.g., R/B). + # Are analytical cuts saved with the "scores"? + # What about things like centroid positions / PSF widths? -class Measurements(Base): - __tablename__ = 'measurements' \ No newline at end of file diff --git a/models/provenance.py b/models/provenance.py index 9c1d66cc..c93ea8b9 100644 --- a/models/provenance.py +++ b/models/provenance.py @@ -1,4 +1,5 @@ import json +import base64 import hashlib import sqlalchemy as sa from sqlalchemy import event @@ -7,7 +8,7 @@ from pipeline.utils import get_git_hash -from models.base import Base, SmartSession +from models.base import Base, SeeChangeBase, SmartSession, safe_merge class CodeHash(Base): @@ -66,48 +67,6 @@ def update(self, session=None): class Provenance(Base): __tablename__ = "provenances" - def __init__(self, process=None, code_version=None, parameters=None, upstreams=None): - """ - Create a provenance object. - - Parameters - ---------- - process: str - Name of the process that created this provenance object. - Examples can include: "calibration", "subtraction", "source extraction" or just "level1". - code_version: CodeVersion - Version of the code used to create this provenance object. - parameters: dict - Dictionary of parameters used in the process. - Include only the critical parameters that affect the final products. - upstreams: list of Provenance - List of provenance objects that this provenance object is dependent on. - """ - if process is None: - raise ValueError('Provenance must have a process name. ') - else: - self.process = process - if not isinstance(code_version, CodeVersion): - raise ValueError(f'Code version must be a models.CodeVersion. 
Got {type(code_version)}.') - else: - self.code_version = code_version - - if parameters is None: - self.parameters = {} - else: - self.parameters = parameters - - if upstreams is None: - self.upstreams = [] - else: - if not isinstance(upstreams, list): - self.upstreams = [upstreams] - if len(upstreams) > 0: - if isinstance(upstreams[0], Provenance): - self.upstreams = upstreams - else: - raise ValueError('upstreams must be a list of Provenance objects') - process = sa.Column( sa.String, nullable=False, @@ -141,20 +100,20 @@ def __init__(self, process=None, code_version=None, parameters=None, upstreams=N secondary=provenance_self_association_table, primaryjoin='provenances.c.id == provenance_upstreams.c.downstream_id', secondaryjoin='provenances.c.id == provenance_upstreams.c.upstream_id', - back_populates="downstreams", + # back_populates="downstreams", passive_deletes=True, - lazy='selectin', # should be able to get upstream_ids without a session! + lazy='selectin', # should be able to get upstream_hashes without a session! ) - downstreams = relationship( - "Provenance", - secondary=provenance_self_association_table, - primaryjoin='provenances.c.id == provenance_upstreams.c.upstream_id', - secondaryjoin='provenances.c.id == provenance_upstreams.c.downstream_id', - back_populates="upstreams", - passive_deletes=True, - # can add lazy='selectin' here, but probably not need it - ) + # downstreams = relationship( + # "Provenance", + # secondary=provenance_self_association_table, + # primaryjoin='provenances.c.id == provenance_upstreams.c.upstream_id', + # secondaryjoin='provenances.c.id == provenance_upstreams.c.downstream_id', + # back_populates="upstreams", + # passive_deletes=True, + # # can add lazy='selectin' here, but probably not need it + # ) CodeVersion.provenances = relationship( "Provenance", @@ -182,22 +141,152 @@ def upstream_ids(self): ids.sort() return ids + @property + def upstream_hashes(self): + if self.upstreams is None: + return [] + else: + hashes = set([u.unique_hash for u in self.upstreams]) + hashes = list(hashes) + hashes.sort() + return hashes + + def __init__(self, process=None, code_version=None, parameters=None, upstreams=None): + """ + Create a provenance object. + + Parameters + ---------- + process: str + Name of the process that created this provenance object. + Examples can include: "calibration", "subtraction", "source extraction" or just "level1". + code_version: CodeVersion + Version of the code used to create this provenance object. + parameters: dict + Dictionary of parameters used in the process. + Include only the critical parameters that affect the final products. + upstreams: list of Provenance + List of provenance objects that this provenance object is dependent on. + """ + SeeChangeBase.__init__(self) + + if process is None: + raise ValueError('Provenance must have a process name. ') + else: + self.process = process + if not isinstance(code_version, CodeVersion): + raise ValueError(f'Code version must be a models.CodeVersion. 
Got {type(code_version)}.') + else: + self.code_version = code_version + + if parameters is None: + self.parameters = {} + else: + self.parameters = parameters + + if upstreams is None: + self.upstreams = [] + else: + if not isinstance(upstreams, list): + self.upstreams = [upstreams] + if len(upstreams) > 0: + if isinstance(upstreams[0], Provenance): + self.upstreams = upstreams + else: + raise ValueError('upstreams must be a list of Provenance objects') + + def __repr__(self): + return ( + f'<Provenance(process="{self.process}", parameters={self.parameters}, hash="{self.unique_hash}")>' + ) + def update_hash(self): """ - Update the unique_hash using the code_version, parameters and upstream_ids. + Update the unique_hash using the code_version, parameters and upstream_hashes. """ - if self.process is None or self.parameters is None or self.upstream_ids is None or self.code_version is None: - raise ValueError('Provenance must have process, code_version, parameters and upstream_ids defined. ') + if self.process is None or self.parameters is None or self.code_version is None: + raise ValueError('Provenance must have process, code_version, and parameters defined. ') superdict = dict( process=self.process, parameters=self.parameters, - upstream_ids=self.upstream_ids, + upstream_hashes=self.upstream_hashes, code_version=self.code_version.version ) json_string = json.dumps(superdict, sort_keys=True) - self.unique_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest() + self.unique_hash = base64.urlsafe_b64encode(hashlib.sha256(json_string.encode("utf-8")).digest()).decode()[:20] + + @classmethod + def get_code_version(cls, session=None): + """ + Get the most relevant or latest code version. + Tries to match the current git hash with a CodeHash + instance, but if that doesn't work (e.g., if the + code is running on a machine without git) then + the latest CodeVersion is returned. + + Parameters + ---------- + session: SmartSession + SQLAlchemy session object. If None, a new session is created, + and closed as soon as the function finishes. + + Returns + ------- + code_version: CodeVersion + CodeVersion object + """ + code_hash = session.scalars(sa.select(CodeHash).where(CodeHash.hash == get_git_hash())).first() + if code_hash is not None: + code_version = code_hash.code_version + else: + code_version = session.scalars(sa.select(CodeVersion).order_by(CodeVersion.version.desc())).first() + + return code_version + + def recursive_merge(self, session, done_list=None): + """ + Recursively merge this object, its CodeVersion, + and any upstream/downstream provenances into + the given session. + + Parameters + ---------- + session: SmartSession + SQLAlchemy session object to merge into. + + Returns + ------- + merged_provenance: Provenance + The merged provenance object. 
+ """ + if done_list is None: + done_list = set() + + if self in done_list: + return self + + merged_self = safe_merge(session, self) + done_list.add(merged_self) + + merged_self.code_version = safe_merge(session, merged_self.code_version) + + merged_self.upstreams = [ + u.recursive_merge(session, done_list=done_list) for u in merged_self.upstreams if u is not None + ] + + # merged_self.downstreams = [ + # d.recursive_merge(session, done_list=done_list) for d in merged_self.downstreams if d is not None + # ] + + return merged_self @event.listens_for(Provenance, "before_insert") diff --git a/models/references.py b/models/references.py new file mode 100644 index 00000000..d0ec093d --- /dev/null +++ b/models/references.py @@ -0,0 +1,90 @@ +import sqlalchemy as sa +from sqlalchemy import orm + +from models.base import Base + + +class ReferenceEntry(Base): + """ + A table that refers to each reference Image object, + based on the validity time range, and the object/field it is targeting. + """ + + __tablename__ = 'reference_images' + + image_id = sa.Column( + sa.ForeignKey('images.id', ondelete='CASCADE'), + nullable=False, + index=True, + doc="ID of the reference image this object is referring to. " + ) + + image = orm.relationship( + 'Image', + doc="The reference image this entry is referring to. " + ) + + target = sa.Column( + sa.Text, + nullable=False, + index=True, + doc=( + 'Name of the target object or field id. ' + 'This string is used to match the reference to new images, ' + 'e.g., by matching the field ID on a pre-defined grid of fields. ' + ) + ) + + filter = sa.Column( + sa.Text, + nullable=False, + index=True, + doc="Filter used to make the images for this reference image. " + ) + + section_id = sa.Column( + sa.Text, + nullable=False, + index=True, + doc="Section ID of the reference image. " + ) + + validity_start = sa.Column( + sa.DateTime, + nullable=False, + index=True, + doc="The start of the validity time range of this reference image. " + ) + + validity_end = sa.Column( + sa.DateTime, + nullable=False, + index=True, + doc="The end of the validity time range of this reference image. " + ) + + is_bad = sa.Column( + sa.Boolean, + nullable=False, + default=False, + doc="Whether this reference image is bad. " + ) + + bad_reason = sa.Column( + sa.Text, + nullable=True, + doc=( + "The reason why this reference image is bad. " + "Should be a single pharse or a comma-separated list of reasons. " + ) + ) + + bad_comment = sa.Column( + sa.Text, + nullable=True, + doc="Any additional comments about why this reference image is bad. " + ) + + # this table doesn't have provenance. + # The underlying image will have its own provenance for the "coaddition" process. + diff --git a/models/source_list.py b/models/source_list.py index 30d294ec..b5692f91 100644 --- a/models/source_list.py +++ b/models/source_list.py @@ -1,6 +1,62 @@ +import uuid +import sqlalchemy as sa +from sqlalchemy import orm -from models.base import Base +from models.base import Base, FileOnDiskMixin -class SourceList(Base): +class SourceList(Base, FileOnDiskMixin): + __tablename__ = 'source_lists' + + image_id = sa.Column( + sa.ForeignKey('images.id'), + nullable=False, + index=True, + doc="ID of the image this source list was generated from. " + ) + + image = orm.relationship( + 'Image', + doc="The image this source list was generated from. 
" + ) + + is_sub = sa.Column( + sa.Boolean, + nullable=False, + default=False, + doc=( + "Whether this source list is from a subtraction image (detections), " + "or from a regular image (sources, the default). " + ) + ) + + provenance_id = sa.Column( + sa.ForeignKey('provenances.id', ondelete="CASCADE"), + nullable=False, + index=True, + doc=( + "ID of the provenance of this source list. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this source list. " + ) + ) + + provenance = orm.relationship( + 'Provenance', + cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', + doc=( + "Provenance of this source list. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this source list. " + ) + ) + + def save(self): + """ + Save this source list to the database. + """ + # TODO: Must implement this at some point! + self.filepath = uuid.uuid4().hex + diff --git a/models/world_coordinates.py b/models/world_coordinates.py index 0fd7d1a3..5cb4abc8 100644 --- a/models/world_coordinates.py +++ b/models/world_coordinates.py @@ -1,6 +1,44 @@ +import sqlalchemy as sa +from sqlalchemy import orm + from models.base import Base class WorldCoordinates(Base): __tablename__ = 'world_coordinates' + + source_list_id = sa.Column( + sa.ForeignKey('source_lists.id'), + nullable=False, + index=True, + doc="ID of the source list this world coordinate system is associated with. " + ) + + source_list = orm.relationship( + 'SourceList', + doc="The source list this world coordinate system is associated with. " + ) + + provenance_id = sa.Column( + sa.ForeignKey('provenances.id', ondelete="CASCADE"), + nullable=False, + index=True, + doc=( + "ID of the provenance of this world coordinate system. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this world coordinate system. " + ) + ) + + provenance = orm.relationship( + 'Provenance', + cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', + doc=( + "Provenance of this world coordinate system. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this world coordinate system. " + ) + ) + diff --git a/models/zero_point.py b/models/zero_point.py index 90009fdf..70acfc21 100644 --- a/models/zero_point.py +++ b/models/zero_point.py @@ -1,6 +1,43 @@ +import sqlalchemy as sa +from sqlalchemy import orm + from models.base import Base class ZeroPoint(Base): __tablename__ = 'zero_points' + + source_list_id = sa.Column( + sa.ForeignKey('source_lists.id'), + nullable=False, + index=True, + doc="ID of the source list this zero point is associated with. " + ) + + source_list = orm.relationship( + 'SourceList', + doc="The source list this zero point is associated with. " + ) + + provenance_id = sa.Column( + sa.ForeignKey('provenances.id', ondelete="CASCADE"), + nullable=False, + index=True, + doc=( + "ID of the provenance of this zero point. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this zero point. " + ) + ) + + provenance = orm.relationship( + 'Provenance', + cascade='save-update, merge, refresh-expire, expunge', + lazy='selectin', + doc=( + "Provenance of this zero point. " + "The provenance will contain a record of the code version" + "and the parameters used to produce this zero point. 
" + ) + ) diff --git a/pipeline/__init__.py b/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pipeline/astrometry.py b/pipeline/astro_cal.py similarity index 82% rename from pipeline/astrometry.py rename to pipeline/astro_cal.py index 360ad14e..c7aa9ad7 100644 --- a/pipeline/astrometry.py +++ b/pipeline/astro_cal.py @@ -5,7 +5,7 @@ from models.world_coordinates import WorldCoordinates -class ParsAstrometry(Parameters): +class ParsAstroCalibrator(Parameters): def __init__(self, **kwargs): super().__init__() self.cross_match_catalog = self.add_par( @@ -21,13 +21,12 @@ def __init__(self, **kwargs): self.override(kwargs) def get_process_name(self): - return 'astrometry' + return 'astro_cal' -class Astrometry: +class AstroCalibrator: def __init__(self, **kwargs): - self.pars = ParsAstrometry() - + self.pars = ParsAstroCalibrator() def run(self, *args, **kwargs): """ @@ -60,6 +59,14 @@ def run(self, *args, **kwargs): # TODO: save a WorldCoordinates object to database # TODO: update the image's FITS header with the wcs + wcs = WorldCoordinates() + wcs.source_list = sources + if wcs.provenance is None: + wcs.provenance = prov + else: + if wcs.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for wcs and provenance!') + # add the resulting object to the data store ds.wcs = wcs diff --git a/pipeline/cutter.py b/pipeline/cutting.py similarity index 90% rename from pipeline/cutter.py rename to pipeline/cutting.py index e660d9df..697726b7 100644 --- a/pipeline/cutter.py +++ b/pipeline/cutting.py @@ -68,6 +68,12 @@ def run(self, *args, **kwargs): # Commit the results to the database. # add the resulting list to the data store + if cutout_list.provenance is None: + cutout_list.provenance = prov + else: + if cutout_list.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for cutout_list and provenance!') + ds.cutouts = cutout_list # make sure this is returned to be used in the next step diff --git a/pipeline/data_store.py b/pipeline/data_store.py index 84615061..d89868fc 100644 --- a/pipeline/data_store.py +++ b/pipeline/data_store.py @@ -1,14 +1,15 @@ import sqlalchemy as sa -from pipeline.utils import get_git_hash, get_latest_provenance, parse_session +from pipeline.utils import get_latest_provenance, parse_session -from models.base import SmartSession -from models.provenance import CodeHash, CodeVersion, Provenance +from models.base import SmartSession, FileOnDiskMixin, safe_merge +from models.provenance import CodeVersion, Provenance from models.exposure import Exposure from models.image import Image from models.source_list import SourceList from models.world_coordinates import WorldCoordinates from models.zero_point import ZeroPoint +from models.references import ReferenceEntry from models.cutouts import Cutouts from models.measurements import Measurements @@ -16,20 +17,20 @@ UPSTREAM_NAMES = { 'preprocessing': [], 'extraction': ['preprocessing'], - 'astrometry': ['extraction'], - 'calibration': ['extraction', 'astrometry'], - 'subtraction': ['preprocessing', 'extraction', 'astrometry', 'calibration'], + 'astro_cal': ['extraction'], + 'photo_cal': ['extraction', 'astro_cal'], + 'subtraction': ['preprocessing', 'extraction', 'astro_cal', 'photo_cal'], 'detection': ['subtraction'], 'cutting': ['detection'], - 'measurement': ['detection', 'calibration'], + 'measurement': ['detection', 'photo_cal'], } UPSTREAM_OBJECTS = { 'preprocessing': 'image', 'coaddition': 'image', 'extraction': 'sources', - 
'astrometry': 'wcs', - 'calibration': 'zp', + 'astro_cal': 'wcs', + 'photo_cal': 'zp', 'subtraction': 'sub_image', 'detection': 'detections', 'cutting': 'cutouts', @@ -60,8 +61,8 @@ def from_args(*args, **kwargs): if len(args) == 1 and isinstance(args[0], DataStore): return args[0], None if ( - len(args) == 2 and isinstance(args[0], DataStore) - and isinstance(args[1], (sa.orm.session.Session, SmartSession)) + len(args) == 2 and isinstance(args[0], DataStore) and + (isinstance(args[1], sa.orm.session.Session) or args[1] is None) ): return args[0], args[1] else: @@ -123,15 +124,12 @@ def parse_args(self, *args, **kwargs): the function that received the session as one of the arguments. If no session is given, will return None. """ - if len(args) == 0: - raise ValueError('Must provide at least one argument to DataStore constructor.') - if len(args) == 1 and isinstance(args[0], DataStore): # if the only argument is a DataStore, copy it self.__dict__ = args[0].__dict__.copy() return - output_session = parse_session(*args, **kwargs) + args, kwargs, output_session = parse_session(*args, **kwargs) # remove any provenances from the args list for arg in args: @@ -141,7 +139,9 @@ def parse_args(self, *args, **kwargs): # parse the args list arg_types = [type(arg) for arg in args] - if arg_types == [int, int] or arg_types == [int, str]: # exposure_id, section_id + if arg_types == []: # no arguments, quietly skip + pass + elif arg_types == [int, int] or arg_types == [int, str]: # exposure_id, section_id self.exposure_id, self.section_id = args elif arg_types == [int]: self.image_id = args[0] @@ -179,58 +179,58 @@ def __setattr__(self, key, value): """ Check some of the inputs before saving them. """ + if value is not None: + if key in ['exposure_id', 'image_id'] and not isinstance(value, int): + raise ValueError(f'{key} must be an integer, got {type(value)}') - if key in ['exposure_id', 'image_id'] and not isinstance(value, int): - raise ValueError(f'{key} must be an integer, got {type(value)}') + if key in ['section_id'] and not isinstance(value, (int, str)): + raise ValueError(f'{key} must be an integer or a string, got {type(value)}') - if key in ['section_id'] and not isinstance(value, (int, str)): - raise ValueError(f'{key} must be an integer or a string, got {type(value)}') + if key == 'image' and not isinstance(value, Image): + raise ValueError(f'image must be an Image object, got {type(value)}') - if key == 'image' and not isinstance(value, Image): - raise ValueError(f'image must be an Image object, got {type(value)}') + if key == 'sources' and not isinstance(value, SourceList): + raise ValueError(f'sources must be a SourceList object, got {type(value)}') - if key == 'sources' and not isinstance(value, SourceList): - raise ValueError(f'sources must be a SourceList object, got {type(value)}') + if key == 'wcs' and not isinstance(value, WorldCoordinates): + raise ValueError(f'WCS must be a WorldCoordinates object, got {type(value)}') - if key == 'wcs' and not isinstance(value, WorldCoordinates): - raise ValueError(f'WCS must be a WorldCoordinates object, got {type(value)}') + if key == 'zp' and not isinstance(value, ZeroPoint): + raise ValueError(f'ZP must be a ZeroPoint object, got {type(value)}') - if key == 'zp' and not isinstance(value, ZeroPoint): - raise ValueError(f'ZP must be a ZeroPoint object, got {type(value)}') + if key == 'ref_image' and not isinstance(value, Image): + raise ValueError(f'ref_image must be an Image object, got {type(value)}') - if key == 'ref_image' and not 
isinstance(value, Image): - raise ValueError(f'ref_image must be an Image object, got {type(value)}') + if key == 'sub_image' and not isinstance(value, Image): + raise ValueError(f'sub_image must be a Image object, got {type(value)}') - if key == 'sub_image' and not isinstance(value, Image): - raise ValueError(f'sub_image must be a Image object, got {type(value)}') + if key == 'detections' and not isinstance(value, SourceList): + raise ValueError(f'detections must be a SourceList object, got {type(value)}') - if key == 'detections' and not isinstance(value, SourceList): - raise ValueError(f'detections must be a SourceList object, got {type(value)}') + if key == 'cutouts' and not isinstance(value, list): + raise ValueError(f'cutouts must be a list of Cutout objects, got {type(value)}') - if key == 'cutouts' and not isinstance(value, list): - raise ValueError(f'cutouts must be a list of Cutout objects, got {type(value)}') + if key == 'cutouts' and not all([isinstance(c, Cutouts) for c in value]): + raise ValueError(f'cutouts must be a list of Cutouts objects, got list with {[type(c) for c in value]}') - if key == 'cutouts' and not all([isinstance(c, Cutouts) for c in value]): - raise ValueError(f'cutouts must be a list of Cutouts objects, got list with {[type(c) for c in value]}') + if key == 'measurements' and not isinstance(value, list): + raise ValueError(f'measurements must be a list of Measurements objects, got {type(value)}') - if key == 'measurements' and not isinstance(value, list): - raise ValueError(f'measurements must be a list of Measurements objects, got {type(value)}') - - if key == 'measurements' and not all([isinstance(m, Measurements) for m in value]): - raise ValueError( - f'measurements must be a list of Measurement objects, got list with {[type(m) for m in value]}' - ) + if key == 'measurements' and not all([isinstance(m, Measurements) for m in value]): + raise ValueError( + f'measurements must be a list of Measurement objects, got list with {[type(m) for m in value]}' + ) - if key == 'upstream_provs' and not isinstance(value, list): - raise ValueError(f'upstream_provs must be a list of Provenance objects, got {type(value)}') + if key == 'upstream_provs' and not isinstance(value, list): + raise ValueError(f'upstream_provs must be a list of Provenance objects, got {type(value)}') - if key == 'upstream_provs' and not all([isinstance(p, Provenance) for p in value]): - raise ValueError( - f'upstream_provs must be a list of Provenance objects, got list with {[type(p) for p in value]}' - ) + if key == 'upstream_provs' and not all([isinstance(p, Provenance) for p in value]): + raise ValueError( + f'upstream_provs must be a list of Provenance objects, got list with {[type(p) for p in value]}' + ) - if key == 'session' and not isinstance(value, (sa.orm.session.Session, SmartSession)): - raise ValueError(f'Session must be a SQLAlchemy session or SmartSession, got {type(value)}') + if key == 'session' and not isinstance(value, (sa.orm.session.Session, SmartSession)): + raise ValueError(f'Session must be a SQLAlchemy session or SmartSession, got {type(value)}') super().__setattr__(key, value) @@ -247,6 +247,13 @@ def get_inputs(self): def get_provenance(self, process, pars_dict, upstream_provs=None, session=None): """ Get the provenance for a given process. + Will try to find a provenance that matches the current code version + and the parameter dictionary, and if it doesn't find it, + it will create a new Provenance object. 
+ + This function should be called externally by applications + using the DataStore, to get the provenance for a given process, + or to make it if it doesn't exist. Parameters ---------- @@ -276,24 +283,23 @@ def get_provenance(self, process, pars_dict, upstream_provs=None, session=None): ------- prov: Provenance The provenance for the given process. - """ if upstream_provs is None: upstream_provs = self.upstream_provs with SmartSession(session) as session: - # check if this code version exists - code_hash = session.scalars(sa.select(CodeHash).where(CodeHash.hash == get_git_hash())).first() - if code_hash is None: - raise ValueError('Cannot find code hash!') + code_version = Provenance.get_code_version(session=session) + if code_version is None: + # this "null" version should never be used in production + code_version = CodeVersion(version='v0.0.0') + code_version.update() # try to add current git hash to version object # check if we can find the upstream provenances upstreams = [] - for name in self.UPSTREAM_NAMES[process]: - obj = getattr(self, UPSTREAM_OBJECTS[name], None) - + for name in UPSTREAM_NAMES[process]: # first try to load an upstream that was given explicitly: - if name in [p.process for p in upstream_provs]: + obj = getattr(self, UPSTREAM_OBJECTS[name], None) + if upstream_provs is not None and name in [p.process for p in upstream_provs]: prov = [p for p in upstream_provs if p.process == name][0] # second, try to get a provenance from objects saved to the store: @@ -311,10 +317,13 @@ def get_provenance(self, process, pars_dict, upstream_provs=None, session=None): upstreams.append(prov) + if len(upstreams) != len(UPSTREAM_NAMES[process]): + raise ValueError(f'Could not find all upstream provenances for process {process}.') + # we have a code version object and upstreams, we can make a provenance prov = Provenance( process=process, - code_version=code_hash.code_version, + code_version=code_version, parameters=pars_dict, upstreams=upstreams, ) @@ -330,25 +339,36 @@ def get_provenance(self, process, pars_dict, upstream_provs=None, session=None): return prov - def _get_provenance_fallback(self, process, session=None): + def _get_provanance_for_an_upstream(self, process, session=None): """ Get the provenance for a given process, without knowing the parameters or code version. This simply looks for a matching provenance in the upstream_provs attribute, and if it is not there, it will call the latest provenance - from the database. + (for that process) from the database. + This is used to get the provenance of upstream objects, + only when those objects are not found in the store. + Example: when looking for the upstream provenance of a + photo_cal process, the upstream process is preprocess, + so this function will look for the preprocess provenance. + If the ZP object is from the DB then there must be provenance + objects for the Image that was used to create it. + If the ZP was just created, the Image should also be + in memory even if the provenance is not on DB yet, + in which case this function should not be called. This will raise if no provenance can be found. 
""" # see if it is in the upstream_provs - prov_list = [p for p in self.upstream_provs if p.process == process] - provenance = prov_list[0] if len(prov_list) > 0 else None + if self.upstream_provs is not None: + prov_list = [p for p in self.upstream_provs if p.process == process] + provenance = prov_list[0] if len(prov_list) > 0 else None + else: + provenance = None # try getting the latest from the database if provenance is None: # check latest provenance provenance = get_latest_provenance(process, session=session) - if provenance is None: # still can't find anything! - raise ValueError(f'Cannot find the "{process}" provenance!') return provenance @@ -404,9 +424,10 @@ def get_image(self, provenance=None, session=None): The image object, or None if no matching image is found. """ + process_name = 'preprocessing' # we were explicitly asked for a specific image id: if self.image_id is not None: - if self.image is not None and isinstance(self.image, Image) and self.image.id == self.image_id: + if isinstance(self.image, Image) and self.image.id == self.image_id: pass # return self.image at the end of function... else: # not found in local memory, get from DB with SmartSession(session) as session: @@ -418,33 +439,41 @@ def get_image(self, provenance=None, session=None): # this option is for when we are not sure which image id we need elif self.exposure_id is not None and self.section_id is not None: - - # must compare the image to the current provenance - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('preprocessing', session=session) - + # check if self.image is the correct image: if ( - self.image is not None and isinstance(self.image, Image) - and self.image.exposure_id == self.exposure_id and self.image.section_id == str(self.section_id) + isinstance(self.image, Image) and self.image.exposure_id == self.exposure_id + and self.image.section_id == str(self.section_id) ): # make sure the image has the correct provenance if self.image is not None: if self.image.provenance is None: raise ValueError('Image has no provenance!') - - # a mismatch of provenance and cached image: - if self.image.provenance.unique_hash != provenance.unique_hash: - self.image = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached image: + if self.image.provenance.unique_hash != provenances[0].unique_hash: + self.image = None # this must be an old image, get a new one if self.image is None: # load from DB - with SmartSession(session) as session: - self.image = session.scalars( - sa.select(Image).where( - Image.exposure_id == self.exposure_id, - Image.section_id == str(self.section_id), - Image.provenance.has(unique_hash=provenance.unique_hash) - ) - ).first() + # this happens when the image is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if we can't find a provenance, then we don't need to load from DB + with SmartSession(session) as session: + self.image = session.scalars( + sa.select(Image).where( + Image.exposure_id == self.exposure_id, + Image.section_id == str(self.section_id), + 
Image.provenance.has(unique_hash=provenance.unique_hash) + ) + ).first() else: raise ValueError('Cannot get processed image without exposure_id and section_id or image_id!') @@ -476,30 +505,40 @@ def get_sources(self, provenance=None, session=None): or None if no matching source list is found. """ - # not in memory, look for it on the DB + process_name = 'extraction' + # if sources exists in memory, check the provenance is ok if self.sources is not None: - - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('extraction', session=session) - - # make sure the wcs has the correct provenance + # make sure the sources object has the correct provenance if self.sources.provenance is None: raise ValueError('SourceList has no provenance!') - # a mismatch of provenance and cached image: - if self.sources.provenance.unique_hash != provenance.unique_hash: - self.sources = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one {process_name} provenance found!') + if len(provenances) == 1: + # a mismatch of given provenance and self.sources' provenance: + if self.sources.provenance.unique_hash != provenances[0].unique_hash: + self.sources = None # this must be an old sources object, get a new one + # not in memory, look for it on the DB if self.sources is None: - with SmartSession(session) as session: - image = self.get_image(session=session) - self.sources = session.scalars( - sa.select(SourceList).where( - SourceList.image_id == image.id, - SourceList.is_sub.is_(False), - SourceList.provenance.has(unique_hash=provenance.unique_hash), - ) - ).first() + # this happens when the source list is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if we can't find a provenance, then we don't need to load from DB + with SmartSession(session) as session: + image = self.get_image(session=session) + self.sources = session.scalars( + sa.select(SourceList).where( + SourceList.image_id == image.id, + SourceList.is_sub.is_(False), + SourceList.provenance.has(unique_hash=provenance.unique_hash), + ) + ).first() return self.sources @@ -515,7 +554,7 @@ def get_wcs(self, provenance=None, session=None): This provenance should be consistent with the current code version and critical parameters. If none is given, will use the latest provenance - for the "astrometry" process. + for the "astro_cal" process. session: sqlalchemy.orm.session.Session or SmartSession An optional session to use for the database query. If not given, will open a new session and close it at @@ -527,28 +566,37 @@ def get_wcs(self, provenance=None, session=None): The WCS object, or None if no matching WCS is found. 
""" + process_name = 'astro_cal' # make sure the wcs has the correct provenance if self.wcs is not None: - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('astrometry', session=session) - if self.wcs.provenance is None: raise ValueError('WorldCoordinates has no provenance!') - - # a mismatch of provenance and cached image: - if self.wcs.provenance.unique_hash != provenance.unique_hash: - self.wcs = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached wcs: + if self.wcs.provenance.unique_hash != provenances[0].unique_hash: + self.wcs = None # this must be an old wcs object, get a new one # not in memory, look for it on the DB if self.wcs is None: with SmartSession(session) as session: - sources = self.get_sources(session=session) - self.wcs = session.scalars( - sa.select(WorldCoordinates).where( - WorldCoordinates.source_list_id == sources.id, - WorldCoordinates.provenance.has(unique_hash=provenance.unique_hash), - ) - ).first() + # this happens when the wcs is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + sources = self.get_sources(session=session) + self.wcs = session.scalars( + sa.select(WorldCoordinates).where( + WorldCoordinates.source_list_id == sources.id, + WorldCoordinates.provenance.has(unique_hash=provenance.unique_hash), + ) + ).first() return self.wcs @@ -564,7 +612,7 @@ def get_zp(self, provenance=None, session=None): This provenance should be consistent with the current code version and critical parameters. If none is given, will use the latest provenance - for the "calibration" process. + for the "photo_cal" process. session: sqlalchemy.orm.session.Session or SmartSession An optional session to use for the database query. If not given, will open a new session and close it at @@ -574,32 +622,40 @@ def get_zp(self, provenance=None, session=None): ------- wcs: ZeroPoint object The photometric calibration object, or None if no matching ZP is found. 
- """ + process_name = 'photo_cal' # make sure the zp has the correct provenance if self.zp is not None: - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('calibration', session=session) - if self.zp.provenance is None: raise ValueError('ZeroPoint has no provenance!') - # a mismatch of provenance and cached image: - if self.zp.provenance.unique_hash != provenance.unique_hash: - self.zp = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached zp: + if self.zp.provenance.unique_hash != provenances[0].unique_hash: + self.zp = None # this must be an old zp, get a new one # not in memory, look for it on the DB if self.zp is None: with SmartSession(session) as session: sources = self.get_sources(session=session) # TODO: do we also need the astrometric solution (to query for the ZP)? - - self.zp = session.scalars( - sa.select(ZeroPoint).where( - ZeroPoint.source_list_id == sources.id, - ZeroPoint.provenance.has(unique_hash=provenance.unique_hash), - ) - ).first() + # this happens when the wcs is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + self.zp = session.scalars( + sa.select(ZeroPoint).where( + ZeroPoint.source_list_id == sources.id, + ZeroPoint.provenance.has(unique_hash=provenance.unique_hash), + ) + ).first() return self.zp @@ -610,7 +666,7 @@ def get_reference_image(self, provenance=None, session=None): Parameters ---------- provenance: Provenance object - The provenance to use for the subtraction. + The provenance to use for the coaddition. This provenance should be consistent with the current code version and critical parameters. If none is given, will use the latest provenance @@ -630,12 +686,28 @@ def get_reference_image(self, provenance=None, session=None): with SmartSession(session) as session: image = self.get_image(session=session) - self.ref_image = session.scalars( - sa.select(Image).where( - # TODO: we need to figure out exactly how to match reference to image + + ref_entry = session.scalars( + sa.select(ReferenceEntry).where( + sa.or_( + ReferenceEntry.validity_start.is_(None), + ReferenceEntry.validity_start <= image.observation_time + ), + sa.or_( + ReferenceEntry.validity_end.is_(None), + ReferenceEntry.validity_end >= image.observation_time + ), + ReferenceEntry.filter == image.filter, + ReferenceEntry.target == image.target, + ReferenceEntry.is_bad.is_(False), ) ).first() + if ref_entry is None: + raise ValueError(f'No reference image found for image {image.id}') + + self.ref_image = ref_entry.image + return self.ref_image def get_subtraction(self, provenance=None, session=None): @@ -662,30 +734,40 @@ def get_subtraction(self, provenance=None, session=None): or None if no matching subtraction image is found. 
""" + process_name = 'subtraction' # make sure the subtraction has the correct provenance if self.sub_image is not None: - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('subtraction', session=session) - if self.sub_image.provenance is None: raise ValueError('Subtraction image has no provenance!') - - # a mismatch of provenance and cached image: - if self.sub_image.provenance.unique_hash != provenance.unique_hash: - self.sub_image = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) > 0: + # a mismatch of provenance and cached subtraction image: + if self.sub_image.provenance.unique_hash != provenances[0].unique_hash: + self.sub_image = None # this must be an old subtraction image, need to get a new one # not in memory, look for it on the DB if self.sub_image is None: with SmartSession(session) as session: image = self.get_image(session=session) ref = self.get_reference_image(session=session) - self.sub_image = session.scalars( - sa.select(Image).where( - Image.ref_id == ref.id, - Image.new_id == image.id, - Image.provenance.has(unique_hash=provenance.unique_hash), - ) - ).first() + + # this happens when the subtraction is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + self.sub_image = session.scalars( + sa.select(Image).where( + Image.ref_image_id == ref.id, + Image.new_image_id == image.id, + Image.provenance.has(unique_hash=provenance.unique_hash), + ) + ).first() return self.sub_image @@ -701,7 +783,7 @@ def get_detections(self, provenance=None, session=None): This provenance should be consistent with the current code version and critical parameters. If none is given, will use the latest provenance - for the "extraction" process. + for the "detection" process. session: sqlalchemy.orm.session.Session or SmartSession An optional session to use for the database query. If not given, will open a new session and close it at @@ -710,34 +792,44 @@ def get_detections(self, provenance=None, session=None): Returns ------- sl: SourceList object - The list of sources for this image (the catalog), + The list of sources for this subtraction image (the catalog), or None if no matching source list is found. 
""" + process_name = 'detection' # not in memory, look for it on the DB if self.detections is not None: - - if provenance is None: # check if in upstream_provs/database - provenance = self._get_provenance_fallback('detection', session=session) - # make sure the wcs has the correct provenance if self.detections.provenance is None: raise ValueError('SourceList has no provenance!') - # a mismatch of provenance and cached image: - if self.detections.provenance.unique_hash != provenance.unique_hash: - self.detections = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached detections: + if self.detections.provenance.unique_hash != provenances[0].unique_hash: + self.detections = None # this must be an old detections object, need to get a new one if self.detections is None: with SmartSession(session) as session: - image = self.get_image(session=session) - self.detections = session.scalars( - sa.select(SourceList).where( - SourceList.image_id == image.id, - SourceList.is_sub.is_(True), - SourceList.provenance.has(unique_hash=provenance.unique_hash), - ) - ).first() + sub_image = self.get_subtraction(session=session) + + # this happens when the wcs is required as an upstream for another process (but isn't in memory) + if provenance is None: # check if in upstream_provs/database + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + self.detections = session.scalars( + sa.select(SourceList).where( + SourceList.image_id == sub_image.id, + SourceList.is_sub.is_(True), + SourceList.provenance.has(unique_hash=provenance.unique_hash), + ) + ).first() return self.detections @@ -764,29 +856,39 @@ def get_cutouts(self, provenance=None, session=None): The list of measurements, or None if no matching measurements are found. 
""" + process_name = 'cutting' # make sure the cutouts have the correct provenance if self.cutouts is not None: - if provenance is None: - provenance = self._get_provenance_fallback('measurement', session=session) - if any([c.provenance is None for c in self.cutouts]): raise ValueError('One of the Cutouts has no provenance!') - # a mismatch of provenance and cached image: - if any([c.provenance.unique_hash != provenance.unique_hash for c in self.cutouts]): - self.cutouts = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached cutouts: + if any([c.provenance.unique_hash != provenances[0].unique_hash for c in self.cutouts]): + self.cutouts = None # this must be an old cutouts list, need to get a new one # not in memory, look for it on the DB if self.cutouts is None: with SmartSession(session) as session: - image = self.get_subtraction(session=session) + sub_image = self.get_subtraction(session=session) - self.cutouts = session.scalars( - sa.select(Cutouts).where( - Cutouts.sub_image_id == image.id, - Cutouts.provenance.has(unique_hash=provenance.unique_hash), - ) - ).all() + # this happens when the cutouts are required as an upstream for another process (but aren't in memory) + if provenance is None: + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + self.cutouts = session.scalars( + sa.select(Cutouts).where( + Cutouts.sub_image_id == sub_image.id, + Cutouts.provenance.has(unique_hash=provenance.unique_hash), + ) + ).all() return self.cutouts @@ -813,28 +915,140 @@ def get_measurements(self, provenance=None, session=None): The list of measurements, or None if no matching measurements are found. 
""" + process_name = 'measurement' # make sure the measurements have the correct provenance if self.measurements is not None: - if provenance is None: - provenance = self._get_provenance_fallback('measurement', session=session) - if any([m.provenance is None for m in self.measurements]): raise ValueError('One of the Measurements has no provenance!') - # a mismatch of provenance and cached image: - if any([m.provenance.unique_hash != provenance.unique_hash for m in self.measurements]): - self.measurements = None + if self.upstream_provs is not None: + provenances = [p for p in self.upstream_provs if p.process == process_name] + else: + provenances = [] + if len(provenances) > 1: + raise ValueError(f'More than one "{process_name}" provenance found!') + if len(provenances) == 1: + # a mismatch of provenance and cached image: + if any([m.provenance.unique_hash != provenances[0].unique_hash for m in self.measurements]): + self.measurements = None # not in memory, look for it on the DB if self.measurements is None: with SmartSession(session) as session: cutouts = self.get_cutouts(session=session) cutout_ids = [c.id for c in cutouts] - self.measurements = session.scalars( - sa.select(Measurements).where( - Measurements.cutouts_id.in_(cutout_ids), - Measurements.provenance.has(unique_hash=provenance.unique_hash), - ) - ).all() + + # this happens when the measurements are required as an upstream (but aren't in memory) + if provenance is None: + provenance = self._get_provanance_for_an_upstream(process_name, session=session) + + if provenance is not None: # if None, it means we can't find it on the DB + self.measurements = session.scalars( + sa.select(Measurements).where( + Measurements.cutouts_id.in_(cutout_ids), + Measurements.provenance.has(unique_hash=provenance.unique_hash), + ) + ).all() return self.measurements + + def get_all_data_products(self, output='dict'): + """ + Get all the data products associated with this Exposure. + By default, this returns a dict with named entries. + If using output='list', will return a flattened list of all + objects, including lists (e.g., Cutouts will be concatenated, + no nested). Any None values will be removed. + + Parameters + ---------- + output: str, optional + The output format. Can be 'dict' or 'list'. + Default is 'dict'. + + Returns + ------- + data_products: dict or list + A dict with named entries, or a flattened list of all + objects, including lists (e.g., Cutouts will be concatenated, + no nested). Any None values will be removed. + """ + attributes = ['exposure', 'image', 'sources', 'wcs', 'zp', 'sub_image', 'detections', 'cutouts', 'measurements'] + result = {att: getattr(self, att) for att in attributes} + if output == 'dict': + return result + if output == 'list': + list_result = [] + for k, v in result.items(): + if isinstance(v, list): + list_result.extend(v) + else: + list_result.append(v) + + return [v for v in list_result if v is not None] + + else: + raise ValueError(f'Unknown output format: {output}') + + def save_and_commit(self, session=None): + """ + Go over all the data products and add them to the session. + If any of the data products are associated with a file on disk, + that would be saved as well. + + Parameters + ---------- + session: sqlalchemy.orm.session.Session or SmartSession + An optional session to use for the database query. + If not given, will open a new session and close it at + the end of the function. 
+ Note that this method calls session.commit() + """ + with SmartSession(session) as session: + autoflush_state = session.autoflush + try: + # session.autoflush = False + for obj in self.get_all_data_products(output='list'): + # print(f'saving {obj} with provenance: {getattr(obj, "provenance", None)}') + + if isinstance(obj, FileOnDiskMixin): + obj.save() + + obj = obj.recursive_merge(session) + session.add(obj) + + session.commit() + finally: + session.autoflush = autoflush_state + + def delete_everything(self, session=None): + """ + Delete everything associated with this sub-image. + All data products in the data store are removed from the DB, + and all files on disk are deleted. + + Parameters + ---------- + session: sqlalchemy.orm.session.Session or SmartSession + An optional session to use for the database query. + If not given, will open a new session and close it at + the end of the function. + Note that this method calls session.commit() + """ + with SmartSession(session) as session: + autoflush_state = session.autoflush + try: + session.autoflush = False + for obj in self.get_all_data_products(output='list'): + # if hasattr(obj, 'provenance'): + # print(f'Deleting {obj} with provenance= {obj.provenance}') + obj = safe_merge(session, obj) + if isinstance(obj, FileOnDiskMixin): + obj.remove_data_from_disk() + if obj in session: + session.delete(obj) + + session.commit() + finally: + session.autoflush = autoflush_state + diff --git a/pipeline/detector.py b/pipeline/detection.py similarity index 75% rename from pipeline/detector.py rename to pipeline/detection.py index c5c11692..7ef83af8 100644 --- a/pipeline/detector.py +++ b/pipeline/detection.py @@ -32,7 +32,7 @@ def __init__(self, **kwargs): self.override(kwargs) - def _get_process_name(self): + def get_process_name(self): if self.subtraction: return 'detection' else: @@ -64,7 +64,7 @@ def run(self, *args, **kwargs): # or load using the provenance given in the # data store's upstream_provs, or just use # the most recent provenance for "subtraction" - image = ds.get_subtraction_image(session=session) + image = ds.get_subtraction(session=session) if image is None: raise ValueError( @@ -72,7 +72,15 @@ def run(self, *args, **kwargs): ) detections = self.extract_sources(image) + detections.image = image + if detections.provenance is None: + detections.provenance = prov + else: + if detections.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for detections and provenance!') + + detections.is_sub = True ds.detections = detections else: # regular image @@ -89,6 +97,12 @@ def run(self, *args, **kwargs): raise ValueError(f'Cannot find an image corresponding to the datastore inputs: {ds.get_inputs()}') sources = self.extract_sources(image) + sources.image = image + if sources.provenance is None: + sources.provenance = prov + else: + if sources.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for sources and provenance!') ds.sources = sources @@ -101,4 +115,12 @@ def extract_sources(self, image): sources = SourceList() - return sources \ No newline at end of file + return sources + + +if __name__ == '__main__': + from models.base import Session + from models.provenance import Provenance + session = Session() + source_lists = session.scalars(sa.select(SourceList)).all() + prov = session.scalars(sa.select(Provenance)).all() diff --git a/pipeline/measurer.py b/pipeline/measurement.py similarity index 86% rename from pipeline/measurer.py rename to pipeline/measurement.py index 
e4616cd7..42e565e6 100644 --- a/pipeline/measurer.py +++ b/pipeline/measurement.py @@ -58,9 +58,9 @@ def run(self, *args, **kwargs): prov = ds.get_provenance(self.pars.get_process_name(), self.pars.get_critical_pars(), session=session) # try to find some measurements in memory or in the database: - ments = ds.get_measurements(prov, session=session) + measurements = ds.get_measurements(prov, session=session) - if ments is None: # must create a new list of Measurements + if measurements is None: # must create a new list of Measurements # use the latest source list in the data store, # or load using the provenance given in the @@ -80,7 +80,13 @@ def run(self, *args, **kwargs): # Commit the results to the database. # add the resulting list to the data store - ds.measurements = ments + if measurements.provenance is None: + measurements.provenance = prov + else: + if measurements.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for measurements and provenance!') + + ds.measurements = measurements # make sure this is returned to be used in the next step return ds diff --git a/pipeline/parameters.py b/pipeline/parameters.py index ab4f7036..5198b7e3 100644 --- a/pipeline/parameters.py +++ b/pipeline/parameters.py @@ -12,6 +12,7 @@ # If the embedded object's Parameters doesn't have any of these # then that key is just skipped + class Parameters: """ Keep track of parameters for any of the pipeline classes. @@ -116,13 +117,6 @@ def __init__(self, **kwargs): self.__critical__ = {} self.__aliases__ = {} - self.code_version = self.add_par( - "code_version", - 'v0.0.0', - str, - "Version of the code used to produce the output products.", - ) - self.verbose = self.add_par( "verbose", 0, int, "Level of verbosity (0=quiet).", critical=False ) @@ -632,13 +626,18 @@ def get_process_name(self): """ raise NotImplementedError("Must be implemented in subclass.") - def get_provenance(self, prov_cache=None, session=None): + # TODO: seems like this is no longer used, instead call DataStore.get_provenance() + def get_provenance(self, code_version=None, prov_cache=None, session=None): """ Get a Provenance object based on the parameters and code version. Parameters ---------- + code_version: str + The version of the code that was used to generate + the provenance. If not given, will use the version + of the current code. prov_cache: dict A dictionary of Provenance objects, from which the relevant upstream ids can be retrieved. If not given, will be filled @@ -667,7 +666,7 @@ def get_provenance(self, prov_cache=None, session=None): upstreams.append(upstream_prov) process = self.get_process_name() # only works in subclasses! - cv = session.scalars(sa.select(CodeVersion).where(CodeVersion.name == self.code_version)).first() + cv = session.scalars(sa.select(CodeVersion).where(CodeVersion.name == code_version)).first() if cv is not None: cv.update() # update the current commit hash @@ -679,7 +678,7 @@ def get_provenance(self, prov_cache=None, session=None): if cv is None: # TODO: should this generate a new code version? Should that be done manually? 
- raise ValueError(f'Cannot find code version "{self.code_version}" for process "{process}"') + raise ValueError(f'Cannot find code version "{code_version}" for process "{process}"') # now that we have a code version object we can make a provenance prov = Provenance( diff --git a/pipeline/calibrator.py b/pipeline/photo_cal.py similarity index 84% rename from pipeline/calibrator.py rename to pipeline/photo_cal.py index 0addff41..c8e1884b 100644 --- a/pipeline/calibrator.py +++ b/pipeline/photo_cal.py @@ -5,7 +5,7 @@ from models.zero_point import ZeroPoint -class ParsCalibrator(Parameters): +class ParsPhotCalibrator(Parameters): def __init__(self, **kwargs): super().__init__() self.cross_match_catalog = self.add_par( @@ -21,12 +21,12 @@ def __init__(self, **kwargs): self.override(kwargs) def get_process_name(self): - return 'calibration' + return 'photo_cal' -class Calibrator: +class PhotCalibrator: def __init__(self, **kwargs): - self.pars = ParsCalibrator() + self.pars = ParsPhotCalibrator() def run(self, *args, **kwargs): """ @@ -55,6 +55,7 @@ def run(self, *args, **kwargs): raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}') wcs = ds.get_wcs(session=session) + if wcs is None: raise ValueError( f'Cannot find an astrometric solution corresponding to the datastore inputs: {ds.get_inputs()}' @@ -65,6 +66,14 @@ def run(self, *args, **kwargs): # TODO: save a ZeroPoint object to database # TODO: update the image's FITS header with the zp + zp = ZeroPoint() + zp.source_list = sources + if zp.provenance is None: + zp.provenance = prov + else: + if zp.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for zp and provenance!') + # update the data store with the new ZeroPoint ds.zp = zp diff --git a/pipeline/preprocessor.py b/pipeline/preprocessing.py similarity index 76% rename from pipeline/preprocessor.py rename to pipeline/preprocessing.py index 1e175dcf..054c308e 100644 --- a/pipeline/preprocessor.py +++ b/pipeline/preprocessing.py @@ -1,4 +1,4 @@ - +import numpy as np import sqlalchemy as sa from models.base import SmartSession @@ -48,13 +48,22 @@ def run(self, *args, **kwargs): if image is None: # need to make new image exposure = ds.get_raw_exposure(session=session) - # TODO: get the CCD image from the exposure - image = Image(exposure_id=exposure.id, section_id=ds.section_id, provenance=ds.provenances['preprocessing']) + # get the CCD image from the exposure + image = Image.from_exposure(exposure, ds.section_id) + image.data = image.raw_data - np.median(image.raw_data) # TODO: replace this! 
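# --- Editor's note: illustrative sketch only, not part of this patch. ----------
# The line above only subtracts the image median as a stand-in; the
# dark/flat/sky-subtraction TODO a few lines below might eventually be filled in
# roughly like this.  The `bias`, `dark` and `flat` calibration frames are
# hypothetical inputs, not existing attributes anywhere in this codebase.
import numpy as np

def apply_basic_calibration(raw, exp_time, bias=None, dark=None, flat=None):
    """Bias/dark/flat corrections followed by a crude median sky subtraction."""
    data = np.asarray(raw, dtype=float)
    if bias is not None:
        data = data - bias                 # remove the electronic offset
    if dark is not None:
        data = data - dark * exp_time      # dark current scales with exposure time
    if flat is not None:
        data = data / flat                 # pixel-to-pixel sensitivity correction
    return data - np.median(data)          # same placeholder sky subtraction as above
# --------------------------------------------------------------------------------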
if image is None: raise ValueError('Image cannot be None at this point!') - # TODO: apply dark/flat/sky subtraction + # TODO: apply dark/flat/sky subtraction + # right now this is just a placeholder: + + + if image.provenance is None: + image.provenance = prov + else: + if image.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for image and provenance!') ds.image = image diff --git a/pipeline/subtractor.py b/pipeline/subtraction.py similarity index 83% rename from pipeline/subtractor.py rename to pipeline/subtraction.py index fe07e9f3..a46a505b 100644 --- a/pipeline/subtractor.py +++ b/pipeline/subtraction.py @@ -61,7 +61,13 @@ def run(self, *args, **kwargs): f'Cannot find a reference image corresponding to the datastore inputs: {ds.get_inputs()}' ) - sub_image = image - ref # TODO: I think there should be a little more to it than this :) + sub_image = Image.from_ref_and_new(ref, image) + sub_image.data = image.data - ref.data # TODO: implement the subtraction algorithm here + if sub_image.provenance is None: + sub_image.provenance = prov + else: + if sub_image.provenance.unique_hash != prov.unique_hash: + raise ValueError('Provenance mismatch for sub_image and provenance!') ds.sub_image = sub_image diff --git a/pipeline/pipeline.py b/pipeline/top_level.py similarity index 56% rename from pipeline/pipeline.py rename to pipeline/top_level.py index a718be63..6359ffca 100644 --- a/pipeline/pipeline.py +++ b/pipeline/top_level.py @@ -2,13 +2,13 @@ from pipeline.parameters import Parameters from pipeline.data_store import DataStore -from pipeline.preprocessor import Preprocessor -from pipeline.astrometry import Astrometry -from pipeline.calibrator import Calibrator -from pipeline.subtractor import Subtractor -from pipeline.detector import Detector -from pipeline.cutter import Cutter -from pipeline.measurer import Measurer +from pipeline.preprocessing import Preprocessor +from pipeline.astro_cal import AstroCalibrator +from pipeline.photo_cal import PhotCalibrator +from pipeline.subtraction import Subtractor +from pipeline.detection import Detector +from pipeline.cutting import Cutter +from pipeline.measurement import Measurer # should this come from db.py instead? @@ -37,52 +37,53 @@ def __init__(self, **kwargs): self.pars.augment(kwargs.get('pipeline', {})) # dark/flat and sky subtraction tools - preprocessor_config = config.get('preprocessor', {}) - preprocessor_config.update(kwargs.get('preprocessor', {})) - self.pars.add_defaults_to_dict(preprocessor_config) - self.preprocessor = Preprocessor(**preprocessor_config) + preprocessing_config = config.get('preprocessing', {}) + preprocessing_config.update(kwargs.get('preprocessing', {})) + self.pars.add_defaults_to_dict(preprocessing_config) + self.preprocessor = Preprocessor(**preprocessing_config) # source detection ("extraction" for the regular image!) 
- extractor_config = config.get('extractor', {}) - extractor_config.update(kwargs.get('extractor', {})) - self.pars.add_defaults_to_dict(extractor_config) - self.extractor = Detector(**extractor_config) + extraction_config = config.get('extraction', {}) + extraction_config.update(kwargs.get('extraction', {})) + self.pars.add_defaults_to_dict(extraction_config) + self.extractor = Detector(**extraction_config) # astrometric fit using a first pass of sextractor and then astrometric fit to Gaia - astrometry_config = config.get('astrometry', {}) - astrometry_config.update(kwargs.get('astrometry', {})) - self.pars.add_defaults_to_dict(astrometry_config) - self.astrometry = Astrometry(**astrometry_config) + astro_cal_config = config.get('astro_cal', {}) + astro_cal_config.update(kwargs.get('astro_cal', {})) + self.pars.add_defaults_to_dict(astro_cal_config) + self.astro_cal = AstroCalibrator(**astro_cal_config) # photometric calibration: - calibrator_config = config.get('calibrator', {}) - calibrator_config.update(kwargs.get('calibrator', {})) - self.pars.add_defaults_to_dict(calibrator_config) - self.calibrator = Calibrator(**calibrator_config) + photo_cal_config = config.get('photo_cal', {}) + photo_cal_config.update(kwargs.get('photo_cal', {})) + self.pars.add_defaults_to_dict(photo_cal_config) + self.photo_cal = PhotCalibrator(**photo_cal_config) # reference fetching and image subtraction - subtractor_config = config.get('subtractor', {}) - subtractor_config.update(kwargs.get('subtractor', {})) - self.pars.add_defaults_to_dict(subtractor_config) - self.subtractor = Subtractor(**subtractor_config) + subtraction_config = config.get('subtraction', {}) + subtraction_config.update(kwargs.get('subtraction', {})) + self.pars.add_defaults_to_dict(subtraction_config) + self.subtractor = Subtractor(**subtraction_config) # source detection ("detection" for the subtracted image!) 
- detector_config = config.get('detector', {}) - detector_config.update(kwargs.get('detector', {})) - self.pars.add_defaults_to_dict(detector_config) - self.detector = Detector(**detector_config) + detection_config = config.get('detection', {}) + detection_config.update(kwargs.get('detection', {})) + self.pars.add_defaults_to_dict(detection_config) + self.detector = Detector(**detection_config) + self.detector.pars.subtraction = True # produce cutouts for detected sources: - cutter_config = config.get('cutter', {}) - cutter_config.update(kwargs.get('cutter', {})) - self.pars.add_defaults_to_dict(cutter_config) - self.cutter = Cutter(**cutter_config) + cutting_config = config.get('cutting', {}) + cutting_config.update(kwargs.get('cutting', {})) + self.pars.add_defaults_to_dict(cutting_config) + self.cutter = Cutter(**cutting_config) # measure photometry, analytical cuts, and deep learning models on the Cutouts: - measurer_config = config.get('measurer', {}) - measurer_config.update(kwargs.get('extractor', {})) - self.pars.add_defaults_to_dict(measurer_config) - self.measurer = Measurer(**measurer_config) + measurement_config = config.get('measurement', {}) + measurement_config.update(kwargs.get('extraction', {})) + self.pars.add_defaults_to_dict(measurement_config) + self.measurer = Measurer(**measurement_config) def run(self, *args, **kwargs): """ @@ -100,20 +101,20 @@ def run(self, *args, **kwargs): ds = self.extractor.run(ds, session) # find astrometric solution, save WCS into Image object and FITS headers - ds = self.astrometry.run(ds, session) + ds = self.astro_cal.run(ds, session) # cross-match against photometric catalogs and get zero point, save into Image object and FITS headers - ds = self.calibrator.run(ds, session) + ds = self.photo_cal.run(ds, session) # fetch reference images and subtract them, save SubtractedImage objects to DB and disk ds = self.subtractor.run(ds, session) - # make cutouts of all the sources in the "detections" source list - ds = self.cutter.run(ds, session) - # find sources, generate a source list for detections ds = self.detector.run(ds, session) + # make cutouts of all the sources in the "detections" source list + ds = self.cutter.run(ds, session) + # extract photometry, analytical cuts, and deep learning models on the Cutouts: ds = self.measurer.run(ds, session) @@ -127,3 +128,4 @@ def run_with_session(self): """ with SmartSession() as session: self.run(session=session) + diff --git a/pipeline/utils.py b/pipeline/utils.py index e8d0c119..09bacada 100644 --- a/pipeline/utils.py +++ b/pipeline/utils.py @@ -46,7 +46,10 @@ def get_latest_provenance(process_name, session=None): process_name: str Name of the process that created this provenance object. Examples can include: "calibration", "subtraction", "source extraction" or just "level1". - session: sqlalchemy.orm.session.Session or SmartSession + session: sqlalchemy.orm.session.Session + Session to use to query the database. + If not given, a new session will be created, + and will be closed at the end of the function. 
Returns ------- diff --git a/tests/conftest.py b/tests/conftest.py index 424a1950..0bb9f69b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,10 +7,14 @@ import sqlalchemy as sa +from astropy.time import Time + from models.base import SmartSession, CODE_ROOT from models.provenance import CodeVersion, Provenance from models.exposure import Exposure from models.image import Image +from models.references import ReferenceEntry + def rnd_str(n): return ''.join(np.random.choice(list('abcdefghijklmnopqrstuvwxyz'), n)) @@ -18,8 +22,11 @@ def rnd_str(n): @pytest.fixture(scope="session", autouse=True) def code_version(): - cv = CodeVersion(version="test_v1.0.0") - cv.update() + with SmartSession() as session: + cv = session.scalars(sa.select(CodeVersion).where(CodeVersion.version == 'test_v1.0.0')).first() + if cv is None: + cv = CodeVersion(version="test_v1.0.0") + cv.update() yield cv @@ -45,7 +52,7 @@ def provenance_base(code_version): yield p with SmartSession() as session: - session.execute(sa.delete(Provenance).where(Provenance.id == pid)) + # session.execute(sa.delete(Provenance).where(Provenance.id == pid)) session.commit() @@ -77,7 +84,7 @@ def factory(): f"Demo_test_{rnd_str(5)}.fits", section_id=0, exp_time=np.random.randint(1, 4) * 10, # 10 to 40 seconds - mjd=np.random.uniform(58300, 58500), + mjd=np.random.uniform(58000, 58500), filter=np.random.choice(list('grizY')), ra=np.random.uniform(0, 360), dec=np.random.uniform(-90, 90), @@ -109,6 +116,7 @@ def make_exposure_file(exposure): if fullname is not None and os.path.isfile(fullname): os.remove(fullname) + @pytest.fixture def exposure(exposure_factory): e = exposure_factory() @@ -159,3 +167,72 @@ def demo_image(exposure): session.execute(sa.delete(Image).where(Image.id == im.id)) session.commit() im.remove_data_from_disk(remove_folders=True) + + +@pytest.fixture +def reference_entry(exposure_factory, provenance_base, provenance_extra): + ref_entry = None + try: # remove files and DB entries at the end + filter = np.random.choice(list('grizY')) + target = rnd_str(6) + ra = np.random.uniform(0, 360) + dec = np.random.uniform(-90, 90) + images = [] + + for i in range(5): + exp = exposure_factory() + + exp.filter = filter + exp.target = target + exp.project = "coadd_test" + exp.ra = ra + exp.dec = dec + + exp.update_instrument() + im = Image.from_exposure(exp, section_id=0) + im.data = im.raw_data - np.median(im.raw_data) + im.provenance = provenance_base + im.ra = ra + im.dec = dec + im.save() + images.append(im) + + # TODO: replace with a "from_images" method? 
+ ref = Image.from_images(images) + ref.data = np.mean(np.array([im.data for im in images]), axis=0) + + provenance_extra.process = 'coaddition' + ref.provenance = provenance_extra + ref.save() + + ref_entry = ReferenceEntry() + ref_entry.image = ref + ref_entry.validity_start = Time(50000, format='mjd', scale='utc').isot + ref_entry.validity_end = Time(58500, format='mjd', scale='utc').isot + ref_entry.section_id = 0 + ref_entry.filter = filter + ref_entry.target = target + + with SmartSession() as session: + session.add(ref_entry) + session.commit() + + yield ref_entry + + finally: # cleanup + if ref_entry is not None: + with SmartSession() as session: + ref_entry = session.merge(ref_entry) + ref = ref_entry.image + for im in ref.source_images: + exp = im.exposure + exp.remove_data_from_disk() + im.remove_data_from_disk() + session.delete(exp) + session.delete(im) + ref.remove_data_from_disk() + session.delete(ref) # should also delete ref_entry + + session.commit() + + diff --git a/tests/models/test_image.py b/tests/models/test_image.py index c4bd3ef1..bfd662d9 100644 --- a/tests/models/test_image.py +++ b/tests/models/test_image.py @@ -88,7 +88,7 @@ def test_image_no_null_values(provenance_base): def test_image_enum_values(demo_image, provenance_base): data_filename = None with SmartSession() as session: - demo_image.provenance_id = provenance_base.id + demo_image.provenance = provenance_base with pytest.raises(RuntimeError, match='The image data is not loaded. Cannot save.'): demo_image.save() @@ -98,23 +98,6 @@ def test_image_enum_values(demo_image, provenance_base): assert os.path.exists(data_filename) try: - assert demo_image.combine_method is None - - with pytest.raises(DataError, match='invalid input value for enum image_combine_method: "foo"'): - demo_image.combine_method = 'foo' - session.add(demo_image) - session.commit() - session.rollback() - - for method in ['coadd', 'subtraction']: - demo_image.combine_method = method - session.add(demo_image) - session.commit() - - demo_image.combine_method = 'subtraction' - session.add(demo_image) - session.commit() - with pytest.raises(DataError, match='invalid input value for enum image_type: "foo"'): demo_image.type = 'foo' session.add(demo_image) @@ -135,7 +118,7 @@ def test_image_enum_values(demo_image, provenance_base): def test_image_coordinates(): - image = Image('foo.fits', ra=None, dec=None, nofile=True) + image = Image('coordinates.fits', ra=None, dec=None, nofile=True) assert image.ecllat is None assert image.ecllon is None assert image.gallat is None @@ -144,13 +127,13 @@ def test_image_coordinates(): with pytest.raises(ValueError, match='Object must have RA and Dec set'): image.calculate_coordinates() - image = Image('foo.fits', ra=123.4, dec=None, nofile=True) + image = Image('coordinates.fits', ra=123.4, dec=None, nofile=True) assert image.ecllat is None assert image.ecllon is None assert image.gallat is None assert image.gallon is None - image = Image('foo.fits', ra=123.4, dec=56.78, nofile=True) + image = Image('coordinates.fits', ra=123.4, dec=56.78, nofile=True) assert abs(image.ecllat - 35.846) < 0.01 assert abs(image.ecllon - 111.838) < 0.01 assert abs(image.gallat - 33.542) < 0.01 @@ -176,8 +159,8 @@ def test_image_from_exposure(exposure, provenance_base): assert im.telescope == exposure.telescope assert im.project == exposure.project assert im.target == exposure.target - assert im.combine_method is None - assert not im.is_multi_image + assert not im.is_coadd + assert not im.is_sub assert im.id is None # need to 
commit to get IDs assert im.exposure_id is None # need to commit to get IDs assert im.source_images == [] @@ -207,7 +190,7 @@ def test_image_from_exposure(exposure, provenance_base): session.rollback() # must add the filepath! - im.filepath = 'foo.fits' + im.filepath = 'foo_exposure.fits' session.add(im) session.commit() @@ -244,29 +227,17 @@ def test_image_with_multiple_source_images(exposure, exposure2, provenance_base) # get a couple of images from exposure objects im1 = Image.from_exposure(exposure, section_id=0) im2 = Image.from_exposure(exposure2, section_id=0) + im2.filter = im1.filter + im2.target = im1.target im1.provenance = provenance_base im1.filepath = 'foo1.fits' im2.provenance = provenance_base im2.filepath = 'foo2.fits' - # make a new image from the two (we still don't have a coadd method for this) - im = Image( - exp_time=im1.exp_time + im2.exp_time, - mjd=im1.mjd, - end_mjd=im2.end_mjd, - filter=im1.filter, - instrument=im1.instrument, - telescope=im1.telescope, - project=im1.project, - target=im1.target, - combine_method='coadd', - section_id=im1.section_id, - ra=im1.ra, - dec=im1.dec, - filepath='foo.fits' - ) - im.source_images = [im1, im2] + # make a coadd image from the two + im = Image.from_images([im1, im2]) + im.filepath = 'foo.fits' im.provenance = provenance_base try: @@ -280,7 +251,7 @@ def test_image_with_multiple_source_images(exposure, exposure2, provenance_base) im_id = im.id assert im_id is not None assert im.exposure_id is None - assert im.is_multi_image + assert im.is_coadd assert im.source_images == [im1, im2] assert np.isclose(im.mid_mjd, (im1.mjd + im2.mjd) / 2) @@ -289,14 +260,81 @@ def test_image_with_multiple_source_images(exposure, exposure2, provenance_base) assert im1_id is not None assert im1.exposure_id is not None assert im1.exposure_id == exposure.id - assert not im1.is_multi_image + assert not im1.is_coadd + assert im1.source_images == [] + + im2_id = im2.id + assert im2_id is not None + assert im2.exposure_id is not None + assert im2.exposure_id == exposure2.id + assert not im2.is_coadd + assert im2.source_images == [] + + finally: # make sure to clean up all images + for id_ in [im_id, im1_id, im2_id]: + if id_ is not None: + with SmartSession() as session: + im = session.scalars(sa.select(Image).where(Image.id == id_)).first() + session.delete(im) + session.commit() + + +def test_image_subtraction(exposure, exposure2, provenance_base): + exposure.update_instrument() + exposure2.update_instrument() + + # make sure exposures are in chronological order... 
+ if exposure.mjd > exposure2.mjd: + exposure, exposure2 = exposure2, exposure + + # get a couple of images from exposure objects + im1 = Image.from_exposure(exposure, section_id=0) + im2 = Image.from_exposure(exposure2, section_id=0) + im2.filter = im1.filter + im2.target = im1.target + + im1.provenance = provenance_base + im1.filepath = 'foo1.fits' + im2.provenance = provenance_base + im2.filepath = 'foo2.fits' + + # make a coadd image from the two + im = Image.from_ref_and_new(im1, im2) + im.filepath = 'foo.fits' + im.provenance = provenance_base + + try: + im_id = None + im1_id = None + im2_id = None + with SmartSession() as session: + session.add(im) + session.commit() + + im_id = im.id + assert im_id is not None + assert im.exposure_id is None + assert im.is_sub + assert im.ref_image == im1 + assert im.ref_image_id == im1.id + assert im.new_image == im2 + assert im.new_image_id == im2.id + assert im.mjd == im2.mjd + assert im.exp_time == im2.exp_time + + # make sure source images are pulled into the database too + im1_id = im1.id + assert im1_id is not None + assert im1.exposure_id is not None + assert im1.exposure_id == exposure.id + assert not im1.is_coadd assert im1.source_images == [] im2_id = im2.id assert im2_id is not None assert im2.exposure_id is not None assert im2.exposure_id == exposure2.id - assert not im2.is_multi_image + assert not im2.is_coadd assert im2.source_images == [] finally: # make sure to clean up all images @@ -310,12 +348,12 @@ def test_image_with_multiple_source_images(exposure, exposure2, provenance_base) def test_image_filename_conventions(demo_image, provenance_base): demo_image.data = np.float32(demo_image.raw_data) - demo_image.provenance_id = provenance_base.id + demo_image.provenance = provenance_base # use the naming convention in the config file demo_image.save() - assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d{2}_._\d{3}\.image\.fits', demo_image.get_fullpath()[0]) + assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d+_.+_.{6}\.image\.fits', demo_image.get_fullpath()[0]) for f in demo_image.get_fullpath(as_list=True): assert os.path.isfile(f) os.remove(f) @@ -327,7 +365,7 @@ def test_image_filename_conventions(demo_image, provenance_base): try: cfg.set_value('storage.images.name_convention', None) demo_image.save() - assert re.search(r'Demo_\d{8}_\d{6}_\d{2}_._\d{3}\.image\.fits', demo_image.get_fullpath()[0]) + assert re.search(r'Demo_\d{8}_\d{6}_\d+_.+_.{6}\.image\.fits', demo_image.get_fullpath()[0]) for f in demo_image.get_fullpath(as_list=True): assert os.path.isfile(f) os.remove(f) @@ -364,7 +402,7 @@ def test_image_filename_conventions(demo_image, provenance_base): def test_image_multifile(demo_image, provenance_base): demo_image.data = np.float32(demo_image.raw_data) demo_image.flags = np.random.randint(0, 100, size=demo_image.raw_data.shape, dtype=np.uint32) - demo_image.provenance_id = provenance_base.id + demo_image.provenance = provenance_base cfg = config.Config.get() single_fileness = cfg.value('storage.images.single_file') # store initial value @@ -374,7 +412,7 @@ def test_image_multifile(demo_image, provenance_base): cfg.set_value('storage.images.single_file', True) demo_image.save() - assert re.match(r'\d{3}/Demo_\d{8}_\d{6}_\d{2}_._\d{3}\.fits', demo_image.filepath) + assert re.match(r'\d{3}/Demo_\d{8}_\d{6}_\d+_.+_.{6}\.fits', demo_image.filepath) files = demo_image.get_fullpath(as_list=True) assert len(files) == 1 @@ -398,13 +436,13 @@ def test_image_multifile(demo_image, provenance_base): cfg.set_value('storage.images.single_file', 
False) demo_image.save() - assert re.match(r'\d{3}/Demo_\d{8}_\d{6}_\d{2}_._\d{3}', demo_image.filepath) + assert re.match(r'\d{3}/Demo_\d{8}_\d{6}_\d+_.+_.{6}', demo_image.filepath) fullnames = demo_image.get_fullpath(as_list=True) assert len(fullnames) == 2 assert os.path.isfile(fullnames[0]) - assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d{2}_._\d{3}\.image\.fits', fullnames[0]) + assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d+_.+_.{6}\.image\.fits', fullnames[0]) with fits.open(fullnames[0]) as hdul: assert len(hdul) == 1 # image data is saved on the primary HDU assert hdul[0].header['NAXIS'] == 2 @@ -412,7 +450,7 @@ def test_image_multifile(demo_image, provenance_base): assert np.array_equal(hdul[0].data, demo_image.data) assert os.path.isfile(fullnames[1]) - assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d{2}_._\d{3}\.flags\.fits', fullnames[1]) + assert re.search(r'\d{3}/Demo_\d{8}_\d{6}_\d+_.+_.{6}\.flags\.fits', fullnames[1]) with fits.open(fullnames[1]) as hdul: assert len(hdul) == 1 assert hdul[0].header['NAXIS'] == 2 diff --git a/tests/models/test_provenance.py b/tests/models/test_provenance.py index dea33332..94a20b0d 100644 --- a/tests/models/test_provenance.py +++ b/tests/models/test_provenance.py @@ -108,7 +108,7 @@ def test_provenances(code_version): assert pid1 is not None assert p.unique_hash is not None assert isinstance(p.unique_hash, str) - assert len(p.unique_hash) == 64 + assert len(p.unique_hash) == 20 hash = p.unique_hash p2 = Provenance( @@ -125,7 +125,7 @@ def test_provenances(code_version): assert pid2 is not None assert p2.unique_hash is not None assert isinstance(p2.unique_hash, str) - assert len(p2.unique_hash) == 64 + assert len(p2.unique_hash) == 20 assert p2.unique_hash != hash finally: with SmartSession() as session: @@ -156,7 +156,7 @@ def test_unique_provenance_hash(code_version): assert pid is not None assert p.unique_hash is not None assert isinstance(p.unique_hash, str) - assert len(p.unique_hash) == 64 + assert len(p.unique_hash) == 20 hash = p.unique_hash p2 = Provenance( @@ -203,7 +203,7 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra): assert pid1 is not None assert p1.unique_hash is not None assert isinstance(p1.unique_hash, str) - assert len(p1.unique_hash) == 64 + assert len(p1.unique_hash) == 20 hash = p1.unique_hash p2 = Provenance( @@ -220,7 +220,7 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra): new_ids.append(pid2) assert p2.unique_hash is not None assert isinstance(p2.unique_hash, str) - assert len(p2.unique_hash) == 64 + assert len(p2.unique_hash) == 20 # added a new upstream, so the hash should be different assert p2.unique_hash != hash @@ -242,9 +242,9 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra): assert p3_recovered is not None # check that the downstreams of our fixture provenances have been updated too - base_downstream_ids = [p.id for p in provenance_base.downstreams] - assert all([pid in base_downstream_ids for pid in [pid1, pid2]]) - assert pid2 in [p.id for p in provenance_extra.downstreams] + # base_downstream_ids = [p.id for p in provenance_base.downstreams] + # assert all([pid in base_downstream_ids for pid in [pid1, pid2]]) + # assert pid2 in [p.id for p in provenance_extra.downstreams] finally: session.execute(sa.delete(Provenance).where(Provenance.id.in_(new_ids))) @@ -255,9 +255,12 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra): cv = 
diff --git a/tests/models/test_provenance.py b/tests/models/test_provenance.py
index dea33332..94a20b0d 100644
--- a/tests/models/test_provenance.py
+++ b/tests/models/test_provenance.py
@@ -108,7 +108,7 @@ def test_provenances(code_version):
         assert pid1 is not None
         assert p.unique_hash is not None
         assert isinstance(p.unique_hash, str)
-        assert len(p.unique_hash) == 64
+        assert len(p.unique_hash) == 20
         hash = p.unique_hash

         p2 = Provenance(
@@ -125,7 +125,7 @@ def test_provenances(code_version):
         assert pid2 is not None
         assert p2.unique_hash is not None
         assert isinstance(p2.unique_hash, str)
-        assert len(p2.unique_hash) == 64
+        assert len(p2.unique_hash) == 20
         assert p2.unique_hash != hash
     finally:
         with SmartSession() as session:
@@ -156,7 +156,7 @@ def test_unique_provenance_hash(code_version):
         assert pid is not None
         assert p.unique_hash is not None
         assert isinstance(p.unique_hash, str)
-        assert len(p.unique_hash) == 64
+        assert len(p.unique_hash) == 20
         hash = p.unique_hash

         p2 = Provenance(
@@ -203,7 +203,7 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra):
         assert pid1 is not None
         assert p1.unique_hash is not None
         assert isinstance(p1.unique_hash, str)
-        assert len(p1.unique_hash) == 64
+        assert len(p1.unique_hash) == 20
         hash = p1.unique_hash

         p2 = Provenance(
@@ -220,7 +220,7 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra):
         new_ids.append(pid2)
         assert p2.unique_hash is not None
         assert isinstance(p2.unique_hash, str)
-        assert len(p2.unique_hash) == 64
+        assert len(p2.unique_hash) == 20
         # added a new upstream, so the hash should be different
         assert p2.unique_hash != hash
@@ -242,9 +242,9 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra):
         assert p3_recovered is not None

         # check that the downstreams of our fixture provenances have been updated too
-        base_downstream_ids = [p.id for p in provenance_base.downstreams]
-        assert all([pid in base_downstream_ids for pid in [pid1, pid2]])
-        assert pid2 in [p.id for p in provenance_extra.downstreams]
+        # base_downstream_ids = [p.id for p in provenance_base.downstreams]
+        # assert all([pid in base_downstream_ids for pid in [pid1, pid2]])
+        # assert pid2 in [p.id for p in provenance_extra.downstreams]

     finally:
         session.execute(sa.delete(Provenance).where(Provenance.id.in_(new_ids)))
@@ -255,9 +255,12 @@ def test_upstream_relationship(code_version, provenance_base, provenance_extra):
         cv = session.scalars(sa.select(CodeVersion).where(CodeVersion.id == code_version.id)).first()
         assert cv is not None

-        # the deletion of the new provenances should have cascaded to the downstreams
-        base_downstream_ids = [p.id for p in provenance_base.downstreams]
-        assert all([pid not in base_downstream_ids for pid in new_ids])
-        extra_downstream_ids = [p.id for p in provenance_extra.downstreams]
-        assert all([pid not in extra_downstream_ids for pid in new_ids])
+        # # the deletion of the new provenances should have cascaded to the downstreams
+        # session.refresh(provenance_base)
+        # base_downstream_ids = [p.id for p in provenance_base.downstreams]
+        # assert all([pid not in base_downstream_ids for pid in new_ids])
+        #
+        # session.refresh(provenance_extra)
+        # extra_downstream_ids = [p.id for p in provenance_extra.downstreams]
+        # assert all([pid not in extra_downstream_ids for pid in new_ids])
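The provenance tests above now expect unique_hash to be 20 characters long rather than the 64 characters of a full SHA-256 hex digest. As a rough sketch only, one way to get a string of that length is to base64-encode a SHA-256 digest and truncate it; this is an assumption for illustration, not the actual hashing code in models/provenance.py, which is not part of this diff:

    import base64
    import hashlib

    def short_hash(payload: str, length: int = 20) -> str:
        """Illustration only: derive a fixed-length, URL-safe hash from a string."""
        digest = hashlib.sha256(payload.encode('utf-8')).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8')[:length]

    # A 20-character string, matching what the updated assertions check for.
    assert len(short_hash('{"process": "preprocessing", "parameters": {}}')) == 20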
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
new file mode 100644
index 00000000..e3ac9c92
--- /dev/null
+++ b/tests/pipeline/test_pipeline.py
@@ -0,0 +1,170 @@
+import os
+
+import sqlalchemy as sa
+
+from models.base import SmartSession, FileOnDiskMixin
+from models.provenance import Provenance
+from models.exposure import Exposure
+from models.image import Image
+from models.source_list import SourceList
+from models.world_coordinates import WorldCoordinates
+from models.zero_point import ZeroPoint
+from models.cutouts import Cutouts
+from models.measurements import Measurements
+
+from pipeline.top_level import Pipeline
+
+
+def match_exposure_to_reference_entry(exposure, reference_entry):
+    """Make sure the exposure has the same target, project, filter, and section_id as the reference image."""
+    exposure.target = reference_entry.target
+    exposure.project = reference_entry.image.project
+    exposure.filter = reference_entry.filter
+
+
+def check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session, ds):
+    """
+    Check that all the required objects are saved on the database
+    and in the datastore, after running the entire pipeline.
+
+    Parameters
+    ----------
+    exp_id: int
+        The exposure ID.
+    sec_id: str or int
+        The section ID.
+    ref_id: int
+        The reference image ID.
+    session: sqlalchemy.orm.session.Session
+        The database session
+    ds: datastore.DataStore
+        The datastore object
+    """
+    im = session.scalars(
+        sa.select(Image).where(Image.exposure_id == exp_id, Image.section_id == str(sec_id))
+    ).first()
+    assert im is not None
+    assert ds.image.id == im.id
+
+    sl = session.scalars(
+        sa.select(SourceList).where(SourceList.image_id == im.id, SourceList.is_sub.is_(False))
+    ).first()
+    assert sl is not None
+    assert ds.sources.id == sl.id
+
+    wcs = session.scalars(
+        sa.select(WorldCoordinates).where(WorldCoordinates.source_list_id == sl.id)
+    ).first()
+    assert wcs is not None
+    assert ds.wcs.id == wcs.id
+
+    zp = session.scalars(
+        sa.select(ZeroPoint).where(ZeroPoint.source_list_id == sl.id)
+    ).first()
+    assert zp is not None
+    assert ds.zp.id == zp.id
+
+    sub = session.scalars(
+        sa.select(Image).where(Image.new_image_id == im.id, Image.ref_image_id == ref_id)
+    ).first()
+
+    assert sub is not None
+    assert ds.sub_image.id == sub.id
+
+    sl = session.scalars(
+        sa.select(SourceList).where(SourceList.image_id == sub.id, SourceList.is_sub.is_(True))
+    ).first()
+
+    assert sl is not None
+    assert ds.detections.id == sl.id
+
+    # TODO: add the cutouts and measurements, but we need to produce them first!
+
+
+def test_data_flow(exposure, reference_entry):
+    """Test that the pipeline runs end-to-end."""
+    sec_id = reference_entry.section_id
+
+    ds = None
+    try:  # cleanup the file at the end
+        # add the exposure to DB and use that ID to run the pipeline
+        with SmartSession() as session:
+            reference_entry = session.merge(reference_entry)
+            match_exposure_to_reference_entry(exposure, reference_entry)
+
+            session.add(exposure)
+            session.commit()
+            exp_id = exposure.id
+
+            filename = exposure.get_fullpath()
+            open(filename, 'a').close()
+            ref_id = reference_entry.image.id
+
+        p = Pipeline()
+        ds = p.run(exp_id, sec_id)
+
+        # commit to DB using this session
+        with SmartSession() as session:
+            ds.save_and_commit(session=session)
+
+        # use a new session to query for the results
+        with SmartSession() as session:
+            # check that everything is in the database
+            provs = session.scalars(sa.select(Provenance)).all()
+            assert len(provs) > 0
+            prov_processes = [p.process for p in provs]
+            expected_processes = ['preprocessing', 'extraction', 'astro_cal', 'photo_cal', 'subtraction', 'detection']
+            for process in expected_processes:
+                assert process in prov_processes
+
+            check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session, ds)
+
+        # feed the pipeline the same data, but missing the upstream data
+        attributes = ['exposure', 'image', 'sources', 'wcs', 'zp', 'sub_image', 'detections']
+
+        for i in range(len(attributes)):
+            for j in range(i):
+                setattr(ds, attributes[j], None)  # get rid of all data up to the current attribute
+
+            ds = p.run(ds)
+
+            # commit to DB using this session
+            with SmartSession() as session:
+                ds.save_and_commit(session=session)
+
+            # use a new session to query for the results
+            with SmartSession() as session:
+                check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session, ds)
+
+        print(ds.image.filepath)
+        print(ds.sub_image.filepath)
+        # make sure we can remove the data from the end to the beginning and recreate it
+        for i in range(len(attributes)):
+            for j in range(i):
+                # print(f'i= {i}, j= {j}. Removing attribute: {attributes[-j-1]}')
+
+                obj = getattr(ds, attributes[-j-1])
+                with SmartSession() as session:
+                    obj = obj.recursive_merge(session=session)
+                    if isinstance(obj, FileOnDiskMixin):
+                        obj.remove_data_from_disk()
+                    session.delete(obj)
+                    session.commit()
+
+                setattr(ds, attributes[-j-1], None)
+
+            ds = p.run(ds)
+
+            # commit to DB using this session
+            with SmartSession() as session:
+                ds.save_and_commit(session=session)
+
+            # use a new session to query for the results
+            with SmartSession() as session:
+                check_datastore_and_database_have_everything(exp_id, sec_id, ref_id, session, ds)
+
+    finally:
+        if ds is not None:
+            ds.delete_everything()
+
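Assuming the repository's usual pytest configuration (and that the exposure and reference_entry fixtures are available in the test environment), the new end-to-end test can be run on its own, for example:

    import pytest

    # Select only the new pipeline test; -x stops at the first failure.
    pytest.main(['tests/pipeline/test_pipeline.py::test_data_flow', '-x'])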