Skip to content

Commit

Permalink
Add "load_defaults" option to upsert, use it in DataStore.save_and_co…
Browse files Browse the repository at this point in the history
…mmit
  • Loading branch information
rknop committed Aug 22, 2024
1 parent af7134b commit 869dbce
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 15 deletions.
34 changes: 31 additions & 3 deletions models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,12 @@ def _get_cols_and_vals_for_insert( self ):
continue
elif col.name == 'modified':
val = datetime.datetime.now( tz=datetime.timezone.utc )

if isinstance( col.type, sqlalchemy.dialects.postgresql.json.JSONB ) and ( val is not None ):
val = json.dumps( val )
elif isinstance( val, np.ndarray ):
val = list( val )

if ( ( ( col.server_default is not None ) and ( col.nullable ) and ( val is None ) )
or
( val is not None )
Expand Down Expand Up @@ -488,7 +492,9 @@ def insert( self, session=None ):
# other unintended consequences calling that SQLA function
# might have. Third, now that we've moved defaults to be
# database-side defaults, we'll get errors from SQLA if those
# fields aren't filled by trying to do an add.
# fields aren't filled by trying to do an add, whereas we
# should be fine with that as the database will just load
# the defaults.
#
# In any event, doing this manually dodges any weirdness associated
# with objects attached, or not attached, to sessions.
Expand All @@ -502,7 +508,7 @@ def insert( self, session=None ):
sess.commit()


def upsert( self, session=None ):
def upsert( self, session=None, load_defaults=False ):
"""Insert an object into the database, or update it if it's already there (using _id as the primary key).
Will *not* update self's fields with server default values!
Expand All @@ -527,6 +533,11 @@ def upsert( self, session=None ):
session: SQLAlchemy Session, default None
Usually you don't want to pass this.
load_defaults: bool, default False
Normally, will *not* update self's fields with server
default values. Set this to True for that to happen. (This
will trigger an additional read from the database.)
"""

# Doing this manually because I don't think SQLAlchemy has a
Expand Down Expand Up @@ -560,9 +571,17 @@ def upsert( self, session=None ):
sess.execute( sa.text( q ), subdict )
sess.commit()

if load_defaults:
dbobj = self.__class__.get_by_id( self.id, session=sess )
for col in sa.inspect( self.__class__ ).c:
if ( ( col.name == 'modified' ) or
( ( col.server_default is not None ) and ( getattr( self, col.name ) is None ) )
):
setattr( self, col.name, getattr( dbobj, col.name ) )


@classmethod
def upsert_list( cls, objects, session=None ):
def upsert_list( cls, objects, session=None, load_defaults=False ):
"""Like upsert, but for a bunch of objects in a list, and tries to be efficient about it.
Do *not* use this with classes that have things like association
Expand Down Expand Up @@ -596,6 +615,15 @@ def upsert_list( cls, objects, session=None ):
sess.execute( sa.text( q ), subdict )
sess.commit()

if load_defaults:
for obj in objects:
dbobj = obj.__class__.get_by_id( obj.id, session=sess )
for col in sa.inspect( obj.__class__).c:
if ( ( col.name == 'modified' ) or
( ( col.server_default is not None ) and ( getattr( obj, col.name ) is None ) )
):
setattr( obj, col.name, getattr( dbobj, col.name ) )


def _delete_from_database( self ):
"""Remove the object from the database. Don't call this, call delete_from_disk_and_database.
Expand Down
26 changes: 14 additions & 12 deletions pipeline/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,10 +1726,12 @@ def save_and_commit(self,
It will run an upsert on the database record for all data
products. This means that if the object is not in the database,
it will get added. If it already is in the database, its fields
will be updated with what's in the objects in the DataStore.
Most of the time, this should be a null operation, as if we're
not inserting, we have all the fields that were already loaded.
it will get added. (In this case, the object is then reloaded
back from the database, so that the database-default fields will
be filled.) If it already is in the database, its fields will
be updated with what's in the objects in the DataStore. Most of
the time, this should be a null operation, as if we're not
inserting, we have all the fields that were already loaded.
However, it does matter for self.image, as some fields (such as
background level, fwhm, zp) get set during processes that happen
after the image's record in the database is first created.
Expand Down Expand Up @@ -1860,7 +1862,7 @@ def save_and_commit(self,
# anyway.
if self.exposure is not None:
SCLogger.debug( "save_and_commit upserting exposure" )
self.exposure.upsert()
self.exposure.upsert( load_defaults=True )
# commits.append( 'exposure' )
# exposure isn't in the commit bitflag

Expand All @@ -1869,15 +1871,15 @@ def save_and_commit(self,
if self.exposure is not None:
self.image.exposure_id = self.exposure.id
SCLogger.debug( "save_and_commit upserting image" )
self.image.upsert()
self.image.upsert( load_defaults=True )
commits.append( 'image' )

# SourceList
if self.sources is not None:
if self.image is not None:
self.sources.image_id = self.image.id
SCLogger.debug( "save_and_commit upserting sources" )
self.sources.upsert()
self.sources.upsert( load_defaults=True )
commits.append( 'sources' )

# SourceList siblings
Expand All @@ -1886,12 +1888,12 @@ def save_and_commit(self,
if self.sources is not None:
setattr( getattr( self, att ), 'sources_id', self.sources.id )
SCLogger.debug( f"save_and_commit upserting {att}" )
getattr( self, att ).upsert()
getattr( self, att ).upsert( load_defaults=True )
commits.append( att )

# subtraction Image
if self.sub_image is not None:
self.sub_image.upsert()
self.sub_image.upsert( load_defaults=True )
SCLogger.debug( "save_and_commit upserting sub_image" )
commits.append( 'sub_image' )

Expand All @@ -1900,23 +1902,23 @@ def save_and_commit(self,
if self.sub_image is not None:
self.detections.sources_id = self.sub_image.id
SCLogger.debug( "save_and_commit detections" )
self.detections.upsert()
self.detections.upsert( load_defaults=True )
commits.append( 'detections' )

# cutouts
if self.cutouts is not None:
if self.detections is not None:
self.cutouts.detections_id = self.detections.id
SCLogger.debug( "save_and_commit upserting cutouts" )
self.cutouts.upsert()
self.cutouts.upsert( load_defaults=True )
commits.append( 'cutouts' )

# measurements
if ( self.measurements is not None ) and ( len(self.measurements) > 0 ):
if self.cutouts is not None:
for m in self.measurements:
m.cutouts_id = self.cutouts.id
Measurements.upsert_list( self.measurements )
Measurements.upsert_list( self.measurements, load_defaults=True )
SCLogger.debug( "save_and_commit measurements" )
commits.append( 'measurements' )

Expand Down
16 changes: 16 additions & 0 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,22 @@ def test_upsert( provenance_base ):
assert len(multifound) == 2
assert set( [ i.id for i in multifound ] ) == set( uuidstodel )

# Now verify that server-side values *do* get updated if we ask for it

image.upsert( load_defaults=True )
assert image.created_at is not None
assert image.modified is not None
assert image.created_at < image.modified
assert image._format == 1
assert image.preproc_bitflag == 0

# Make sure they don't always revert to defaults
image._format = 2
image.upsert( load_defaults=True )
assert image._format == 2
found = Image.get_by_id( image.id )
assert found._format == 2

finally:
# Clean up
with SmartSession() as sess:
Expand Down

0 comments on commit 869dbce

Please sign in to comment.