From 869dbce0b568bf5d68806c05625d669b28719a18 Mon Sep 17 00:00:00 2001 From: Rob Knop Date: Thu, 22 Aug 2024 14:00:41 -0700 Subject: [PATCH] Add "load_defaults" option to upsert, use it in DataStore.save_and_commit --- models/base.py | 34 +++++++++++++++++++++++++++++++--- pipeline/data_store.py | 26 ++++++++++++++------------ tests/models/test_base.py | 16 ++++++++++++++++ 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/models/base.py b/models/base.py index 6dca9a50..ce0c5cc6 100644 --- a/models/base.py +++ b/models/base.py @@ -443,8 +443,12 @@ def _get_cols_and_vals_for_insert( self ): continue elif col.name == 'modified': val = datetime.datetime.now( tz=datetime.timezone.utc ) + if isinstance( col.type, sqlalchemy.dialects.postgresql.json.JSONB ) and ( val is not None ): val = json.dumps( val ) + elif isinstance( val, np.ndarray ): + val = list( val ) + if ( ( ( col.server_default is not None ) and ( col.nullable ) and ( val is None ) ) or ( val is not None ) @@ -488,7 +492,9 @@ def insert( self, session=None ): # other unintended consequences calling that SQLA function # might have. Third, now that we've moved defaults to be # database-side defaults, we'll get errors from SQLA if those - # fields aren't filled by trying to do an add. + # fields aren't filled by trying to do an add, whereas we + # should be fine with that as the database will just load + # the defaults. # # In any event, doing this manually dodges any weirdness associated # with objects attached, or not attached, to sessions. @@ -502,7 +508,7 @@ def insert( self, session=None ): sess.commit() - def upsert( self, session=None ): + def upsert( self, session=None, load_defaults=False ): """Insert an object into the database, or update it if it's already there (using _id as the primary key). Will *not* update self's fields with server default values! @@ -527,6 +533,11 @@ def upsert( self, session=None ): session: SQLAlchemy Session, default None Usually you don't want to pass this. 
+ load_defaults: bool, default False + Normally, will *not* update self's fields with server + default values. Set this to True for that to happen. (This + will trigger an additional read from the database.) + """ # Doing this manually because I don't think SQLAlchemy has a @@ -560,9 +571,17 @@ def upsert( self, session=None ): sess.execute( sa.text( q ), subdict ) sess.commit() + if load_defaults: + dbobj = self.__class__.get_by_id( self.id, session=sess ) + for col in sa.inspect( self.__class__ ).c: + if ( ( col.name == 'modified' ) or + ( ( col.server_default is not None ) and ( getattr( self, col.name ) is None ) ) + ): + setattr( self, col.name, getattr( dbobj, col.name ) ) + @classmethod - def upsert_list( cls, objects, session=None ): + def upsert_list( cls, objects, session=None, load_defaults=False ): """Like upsert, but for a bunch of objects in a list, and tries to be efficient about it. Do *not* use this with classes that have things like association @@ -596,6 +615,15 @@ def upsert_list( cls, objects, session=None ): sess.execute( sa.text( q ), subdict ) sess.commit() + if load_defaults: + for obj in objects: + dbobj = obj.__class__.get_by_id( obj.id, session=sess ) + for col in sa.inspect( obj.__class__).c: + if ( ( col.name == 'modified' ) or + ( ( col.server_default is not None ) and ( getattr( obj, col.name ) is None ) ) + ): + setattr( obj, col.name, getattr( dbobj, col.name ) ) + def _delete_from_database( self ): """Remove the object from the database. Don't call this, call delete_from_disk_and_database. diff --git a/pipeline/data_store.py b/pipeline/data_store.py index 98afef57..57b80efb 100644 --- a/pipeline/data_store.py +++ b/pipeline/data_store.py @@ -1726,10 +1726,12 @@ def save_and_commit(self, It will run an upsert on the database record for all data products. This means that if the object is not in the databse, - it will get added. If it already is in the databse, its fields - will be updated with what's in the objects in the DataStore. 
- Most of the time, this should be a null operation, as if we're - not inserting, we have all the fields that were already loaded. + it will get added. (In this case, the object is then reloaded + back from the database, so that the database-default fields will + be filled.) If it already is in the database, its fields will + be updated with what's in the objects in the DataStore. Most of + the time, this should be a null operation, as if we're not + inserting, we have all the fields that were already loaded. However, it does matter for self.image, as some fields (such as background level, fwhm, zp) get set during processes that happen after the image's record in the database is first created. @@ -1860,7 +1862,7 @@ def save_and_commit(self, # anyway. if self.exposure is not None: SCLogger.debug( "save_and_commit upserting exposure" ) - self.exposure.upsert() + self.exposure.upsert( load_defaults=True ) # commits.append( 'exposure' ) # exposure isn't in the commit bitflag @@ -1869,7 +1871,7 @@ def save_and_commit(self, if self.exposure is not None: self.image.exposure_id = self.exposure.id SCLogger.debug( "save_and_commit upserting image" ) - self.image.upsert() + self.image.upsert( load_defaults=True ) commits.append( 'image' ) # SourceList @@ -1877,7 +1879,7 @@ def save_and_commit(self, if self.image is not None: self.sources.image_id = self.image.id SCLogger.debug( "save_and_commit upserting sources" ) - self.sources.upsert() + self.sources.upsert( load_defaults=True ) commits.append( 'sources' ) # SourceList siblings @@ -1886,12 +1888,12 @@ def save_and_commit(self, if self.sources is not None: setattr( getattr( self, att ), 'sources_id', self.sources.id ) SCLogger.debug( f"save_and_commit upserting {att}" ) - getattr( self, att ).upsert() + getattr( self, att ).upsert( load_defaults=True ) commits.append( att ) # subtraction Image if self.sub_image is not None: - self.sub_image.upsert() + self.sub_image.upsert( load_defaults=True ) SCLogger.debug( 
"save_and_commit upserting sub_image" ) commits.append( 'sub_image' ) @@ -1900,7 +1902,7 @@ def save_and_commit(self, if self.sub_image is not None: self.detections.sources_id = self.sub_image.id SCLogger.debug( "save_and_commit detections" ) - self.detections.upsert() + self.detections.upsert( load_defaults=True ) commits.append( 'detections' ) # cutouts @@ -1908,7 +1910,7 @@ def save_and_commit(self, if self.detections is not None: self.cutouts.detections_id = self.detections.id SCLogger.debug( "save_and_commit upserting cutouts" ) - self.cutouts.upsert() + self.cutouts.upsert( load_defaults=True ) commits.append( 'cutouts' ) # measurements @@ -1916,7 +1918,7 @@ def save_and_commit(self, if self.cutouts is not None: for m in self.measurements: m.cutouts_id = self.cutouts.id - Measurements.upsert_list( self.measurements ) + Measurements.upsert_list( self.measurements, load_defaults=True ) SCLogger.debug( "save_and_commit measurements" ) commits.append( 'measurements' ) diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 87dad15a..6fce2188 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -176,6 +176,22 @@ def test_upsert( provenance_base ): assert len(multifound) == 2 assert set( [ i.id for i in multifound ] ) == set( uuidstodel ) + # Now verify that server-side values *do* get updated if we ask for it + + image.upsert( load_defaults=True ) + assert image.created_at is not None + assert image.modified is not None + assert image.created_at < image.modified + assert image._format == 1 + assert image.preproc_bitflag == 0 + + # Make sure they don't always revert to defaults + image._format = 2 + image.upsert( load_defaults=True ) + assert image._format == 2 + found = Image.get_by_id( image.id ) + assert found._format == 2 + finally: # Clean up with SmartSession() as sess: