Skip to content

Commit

Permalink
Provenance Tagging (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
rknop authored Jul 31, 2024
1 parent 1d54209 commit 480c291
Show file tree
Hide file tree
Showing 22 changed files with 1,274 additions and 296 deletions.
45 changes: 45 additions & 0 deletions alembic/versions/2024_07_25_1851-05bb57675701_provenancetag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""provenancetag
Revision ID: 05bb57675701
Revises: d86b7dee2172
Create Date: 2024-07-25 18:51:53.756271
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '05bb57675701'
down_revision = 'd86b7dee2172'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the provenance_tags table, its constraints, and its indexes."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        'provenance_tags',
        sa.Column('tag', sa.String(), nullable=False),
        sa.Column('provenance_id', sa.String(), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('modified', sa.DateTime(timezone=True), nullable=False),
        sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False),
        sa.ForeignKeyConstraint(['provenance_id'], ['provenances.id'],
                                name='provenance_tags_provenance_id_fkey',
                                ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('tag', 'provenance_id', name='_provenancetag_prov_tag_uc'),
    )
    # Index every column that lookups will filter on.
    for column in ('created_at', 'id', 'provenance_id', 'tag'):
        op.create_index(op.f(f'ix_provenance_tags_{column}'),
                        'provenance_tags', [column], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Remove the provenance_tags table; inverse of upgrade()."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop the indexes first, then the table itself.
    for column in ('tag', 'provenance_id', 'id', 'created_at'):
        op.drop_index(op.f(f'ix_provenance_tags_{column}'),
                      table_name='provenance_tags')
    op.drop_table('provenance_tags')
    # ### end Alembic commands ###
2 changes: 2 additions & 0 deletions default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ pipeline:
save_before_subtraction: true
# automatically save all the products at the end of the pipeline run
save_at_finish: true
# the ProvenanceTag that the products of the pipeline should be associated with
provenance_tag: current

preprocessing:
# these steps need to be done on the images: either they came like that or we do it in the pipeline
Expand Down
8 changes: 4 additions & 4 deletions improc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ def strip_wcs_keywords( hdr ):
"""

basematch = re.compile( "^C(RVAL|RPIX|UNIT|DELT|TYPE)[12]$" )
cdmatch = re.compile( "^CD[12]_[12]$" )
sipmatch = re.compile( "^[AB]P?_(ORDER|(\d+)_(\d+))$" )
tpvmatch = re.compile( "^P[CV]\d+_\d+$" )
basematch = re.compile( r"^C(RVAL|RPIX|UNIT|DELT|TYPE)[12]$" )
cdmatch = re.compile( r"^CD[12]_[12]$" )
sipmatch = re.compile( r"^[AB]P?_(ORDER|(\d+)_(\d+))$" )
tpvmatch = re.compile( r"^P[CV]\d+_\d+$" )

tonuke = set()
for kw in hdr.keys():
Expand Down
7 changes: 4 additions & 3 deletions models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ def get_all_database_objects(display=False, session=None):
A dictionary with the object class names as keys and the IDs list as values.
"""
from models.provenance import Provenance, CodeVersion, CodeHash
from models.provenance import Provenance, ProvenanceTag, CodeVersion, CodeHash
from models.datafile import DataFile
from models.knownexposure import KnownExposure, PipelineWorker
from models.exposure import Exposure
from models.image import Image
from models.source_list import SourceList
Expand All @@ -214,10 +215,10 @@ def get_all_database_objects(display=False, session=None):
from models.user import AuthUser, PasswordLink

models = [
CodeHash, CodeVersion, Provenance, DataFile, Exposure, Image,
CodeHash, CodeVersion, Provenance, ProvenanceTag, DataFile, Exposure, Image,
SourceList, PSF, WorldCoordinates, ZeroPoint, Cutouts, Measurements, Object,
CalibratorFile, CalibratorFileDownloadLock, CatalogExcerpt, Reference, SensorSection,
AuthUser, PasswordLink
AuthUser, PasswordLink, KnownExposure, PipelineWorker
]

output = {}
Expand Down
32 changes: 32 additions & 0 deletions models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,40 @@ def merge_all(self, session):
Must provide a session to merge into. Need to commit at the end.
Returns the merged image with all its products on the same session.
DEVELOPER NOTE: changing what gets merged in this function
requires a corresponding change in
pipeline/data_store.py::DataStore.save_and_commit
"""
new_image = self.safe_merge(session=session)

# Note -- this next block of code is useful for trying to debug
# sqlalchemy weirdness. However, because it calls the __repr__
# method of various objects, it actually causes tests to fail.
# In particular, there are tests that use 'ZTF' as the instrument,
# but the code has no ZTF instrument defined, so calling
# Image.__repr__ throws an error. As such, comment the
# code out below, but leave it here in case somebody wants
# to temporarily re-enable it for debugging purposes.
#
# import io
# strio = io.StringIO()
# strio.write( "In image.merge_all; objects in session:\n" )
# if len( session.new ) > 0 :
# strio.write( " NEW:\n" )
# for obj in session.new:
# strio.write( f" {obj}\n" )
# if len( session.dirty ) > 0:
# strio.write( " DIRTY:\n" )
# for obj in session.dirty:
# strio.write( f" {obj}\n" )
# if len( session.deleted ) > 0:
# strio.write( " DELETED:\n" )
# for obj in session.deleted:
# strio.write( f" {obj}\n" )
# SCLogger.debug( strio.getvalue() )

session.flush() # make sure new_image gets an ID

if self.sources is not None:
Expand Down
146 changes: 145 additions & 1 deletion models/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
import json
import base64
import hashlib
from collections import defaultdict
import sqlalchemy as sa
import sqlalchemy.orm as orm
from sqlalchemy import event
from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.exc import IntegrityError
from sqlalchemy.schema import UniqueConstraint

from util.util import get_git_hash

import models.base
from models.base import Base, SeeChangeBase, SmartSession, safe_merge
from models.base import Base, AutoIDMixin, SeeChangeBase, SmartSession, safe_merge


class CodeHash(Base):
Expand Down Expand Up @@ -375,3 +378,144 @@ def insert_new_dataset(mapper, connection, target):
foreign_keys="Provenance.code_version_id",
passive_deletes=True,
)


class ProvenanceTagExistsError(Exception):
    """Raised by ProvenanceTag.newtag when the requested tag already exists."""
    pass

class ProvenanceTag(Base, AutoIDMixin):
    """A human-readable tag to associate with provenances.

    A well-defined provenance tag will have a provenance defined for every step, but there will
    only be a *single* provenance for each step (except for referencing, where there could be
    multiple provenances defined).  The class method validate can check this for duplicates.
    """

    __tablename__ = "provenance_tags"

    __table_args__ = ( UniqueConstraint( 'tag', 'provenance_id', name='_provenancetag_prov_tag_uc' ), )

    tag = sa.Column(
        sa.String,
        nullable=False,
        index=True,
        doc='Human-readable tag name; one tag has many provenances associated with it.'
    )

    provenance_id = sa.Column(
        sa.ForeignKey( 'provenances.id', ondelete="CASCADE", name='provenance_tags_provenance_id_fkey' ),
        index=True,
        doc='Provenance ID. Each tag/process should only have one provenance.'
    )

    provenance = orm.relationship(
        'Provenance',
        cascade='save-update, merge, refresh-expire, expunge',
        lazy='selectin',
        doc=( "Provenance" )
    )

    def __repr__( self ):
        # Close both the paren and the angle bracket; the original left the
        #   paren unclosed.
        return ( '<ProvenanceTag('
                 f'tag={self.tag}, '
                 f'provenance_id={self.provenance_id})>' )

    @classmethod
    def newtag( cls, tag, provs, session=None ):
        """Add a new ProvenanceTag.  Will throw an error if it already exists.

        Usually, this is called from pipeline.top_level.make_provenance_tree, not directly.

        Always commits.

        Parameters
        ----------
        tag: str
          The human-readable provenance tag.  For cleanliness, should be ASCII, no spaces.

        provs: list of str or Provenance
          The provenances to include in this tag.  Usually, you want to make sure to include
          a provenance for every process in the pipeline: exposure, referencing, preprocessing,
          extraction, subtraction, detection, cutting, measuring, [TODO MORE: deepscore, alert]
          -oo- load_exposure, download, import_image, alignment or aligning, coaddition

        session: SQLAlchemy session, optional
          Wrapped by SmartSession; presumably a new session is opened if None — confirm
          against models.base.SmartSession.

        Raises
        ------
        ValueError if a string in provs is not a known Provenance ID.
        TypeError if provs holds anything other than str or Provenance.
        ProvenanceTagExistsError if the tag already exists.
        """

        with SmartSession( session ) as sess:
            # Resolve everything in provs to a set of provenance IDs to insert.
            provids = set()
            for prov in provs:
                if isinstance( prov, Provenance ):
                    provids.add( prov.id )
                elif isinstance( prov, str ):
                    provobj = sess.get( Provenance, prov )
                    if provobj is None:
                        raise ValueError( f"Unknown Provenance ID {prov}" )
                    provids.add( provobj.id )
                else:
                    raise TypeError( f"Everything in the provs list must be Provenance or str, not {type(prov)}" )

            try:
                # Make sure that this tag doesn't already exist.  To avoid race
                #   conditions of two processes creating it at once (which,
                #   given how we expect the code to be used, should probably
                #   not happen in practice), lock the table before searching
                #   and only unlock after inserting.
                sess.connection().execute( sa.text( "LOCK TABLE provenance_tags" ) )
                current = sess.query( ProvenanceTag ).filter( ProvenanceTag.tag == tag )
                if current.count() != 0:
                    sess.rollback()
                    raise ProvenanceTagExistsError( f"ProvenanceTag {tag} already exists." )

                for provid in provids:
                    sess.add( ProvenanceTag( tag=tag, provenance_id=provid ) )

                sess.commit()
            finally:
                # Make sure no lock is left behind; exiting the with block
                #   ought to do this, but be paranoid.
                sess.rollback()

    @classmethod
    def validate( cls, tag, processes=None, session=None ):
        """Verify that a given tag doesn't have multiply defined processes.

        One exception: referencing can have multiply defined processes.

        Raises a ValueError if things don't work.

        Parameters
        ----------
        tag: str
          The tag to validate

        processes: list of str
          The processes to make sure are present.  If None, won't make sure
          that any processes are present, will just make sure there are no
          duplicates.

        session: SQLAlchemy session, optional
          Wrapped by SmartSession; presumably a new session is opened if None.
        """

        # Processes for which multiple provenances under one tag are allowed.
        repeatok = { 'referencing' }

        with SmartSession( session ) as sess:
            # Columns must be passed as separate positional arguments; a single
            #   tuple is not a valid entity specification for Session.query.
            ptags = ( sess.query( ProvenanceTag.id, Provenance.process )
                      .filter( ProvenanceTag.provenance_id==Provenance.id )
                      .filter( ProvenanceTag.tag==tag )
                     ).all()

            # Count how many provenances each process has under this tag.
            count = defaultdict( int )
            for _, process in ptags:
                count[ process ] += 1

            multiples = [ p for p in count.keys() if count[p] > 1 and p not in repeatok ]
            if len(multiples) > 0:
                raise ValueError( f"Database integrity error: ProvenanceTag {tag} has more than one "
                                  f"provenance for processes {multiples}" )

            if processes is not None:
                missing = [ p for p in processes if p not in count.keys() ]
                if len( missing ) > 0:
                    raise ValueError( f"Some processes missing from ProvenanceTag {tag}: {missing}" )
15 changes: 13 additions & 2 deletions pipeline/data_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import warnings
import datetime
import sqlalchemy as sa
Expand Down Expand Up @@ -1379,6 +1380,9 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False,
True), as the image headers get "first-look" values, not
necessarily the latest and greatest if we tune either process.
DEVELOPER NOTE: this code has to stay synced properly with
models/image.py::Image.merge_all
Parameters
----------
exists_ok: bool, default False
Expand Down Expand Up @@ -1431,8 +1435,15 @@ def save_and_commit(self, exists_ok=False, overwrite=True, no_archive=False,
if obj is None:
continue

SCLogger.debug( f'save_and_commit considering a {obj.__class__.__name__} with filepath '
f'{obj.filepath if isinstance(obj,FileOnDiskMixin) else "<none>"}' )
strio = io.StringIO()
strio.write( f"save_and_commit of {att} considering a {obj.__class__.__name__}" )
if isinstance( obj, FileOnDiskMixin ):
strio.write( f" with filepath {obj.filepath}" )
elif isinstance( obj, list ):
strio.write( f" of types {[type(i) for i in obj]}" )
SCLogger.debug( strio.getvalue() )
# SCLogger.debug( f'save_and_commit of {att} considering a {obj.__class__.__name__} with filepath '
# f'{obj.filepath if isinstance(obj,FileOnDiskMixin) else "<none>"}' )

if isinstance(obj, FileOnDiskMixin):
mustsave = True
Expand Down
Loading

0 comments on commit 480c291

Please sign in to comment.