Skip to content

Commit

Permalink
Bitflag and loading changes
Browse files Browse the repository at this point in the history
  • Loading branch information
whohensee committed Jun 27, 2024
1 parent 3b36904 commit 9338a5f
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
7 changes: 0 additions & 7 deletions models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,6 @@ def remove_data_from_disk(self, remove_folders=True, remove_downstreams=False):
that have remove_data_from_disk() implemented, and call it.
Default is False.
"""
# SCLogger.debug(f"START remove_data_from_disk on {self}")
if self.filepath is not None:
# get the filepath, but don't check if the file exists!
for f in self.get_fullpath(as_list=True, nofile=True):
Expand All @@ -1265,8 +1264,6 @@ def remove_data_from_disk(self, remove_folders=True, remove_downstreams=False):
d[0].delete_list(d, remove_local=True, archive=False, database=False)
except NotImplementedError as e:
pass # if this object does not implement get_downstreams, it is ok
# SCLogger.debug(f"FINISH remove_data_from_disk on {self}")


def delete_from_archive(self, remove_downstreams=False):
"""Delete the file from the archive, if it exists.
Expand Down Expand Up @@ -1348,17 +1345,13 @@ def delete_from_disk_and_database(
if session is None and not commit:
raise RuntimeError("When session=None, commit must be True!")

# breakpoint()
SeeChangeBase.delete_from_database(self, session=session, commit=commit, remove_downstreams=remove_downstreams)

# breakpoint()
self.remove_data_from_disk(remove_folders=remove_folders, remove_downstreams=remove_downstreams)

# breakpoint()
if archive:
self.delete_from_archive(remove_downstreams=remove_downstreams)


# make sure these are set to null just in case we fail
# to commit later on, we will at least know something is wrong
self.filepath_extensions = None
Expand Down
8 changes: 4 additions & 4 deletions models/enums_and_bitflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,22 +410,22 @@ def string_to_bitflag(value, dictionary):
bg_badness_inverse = {EnumConverter.c(v): k for k, v in bg_badness_dict.items()}


# these are the ways a Cutouts object is allowed to be bad
cutouts_badness_dict = {
# these are the ways a Measurements object is allowed to be bad
measurements_badness_dict = {
41: 'cosmic ray',
42: 'ghost',
43: 'satellite',
44: 'offset',
45: 'bad pixel',
46: 'bleed trail',
}
cutouts_badness_inverse = {EnumConverter.c(v): k for k, v in cutouts_badness_dict.items()}
measurements_badness_inverse = {EnumConverter.c(v): k for k, v in measurements_badness_dict.items()}


# join the badness:
data_badness_dict = {}
data_badness_dict.update(image_badness_dict)
data_badness_dict.update(cutouts_badness_dict)
data_badness_dict.update(measurements_badness_dict)
data_badness_dict.update(source_list_badness_dict)
data_badness_dict.update(psf_badness_dict)
data_badness_dict.update(bg_badness_dict)
Expand Down
2 changes: 2 additions & 0 deletions pipeline/cutting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def run(self, *args, **kwargs):

# try to find some measurements in memory or in the database:
cutouts = ds.get_cutouts(prov, session=session)
if cutouts is not None:
cutouts.load_all_co_data()

if cutouts is None or cutouts.co_dict == {}:

Expand Down
1 change: 1 addition & 0 deletions pipeline/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,7 @@ def get_cutouts(self, provenance=None, session=None):
provenance = self._get_provenance_for_an_upstream(process_name, session)

if self.cutouts is not None:
self.cutouts.load_all_co_data()
if self.cutouts.co_dict == {}:
self.cutouts = None # TODO: what about images that actually don't have any detections?

Expand Down
10 changes: 6 additions & 4 deletions pipeline/measuring.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,16 @@ def run(self, *args, **kwargs):
raise ValueError(f'Cannot find a source list corresponding to the datastore inputs: {ds.get_inputs()}')

cutouts = ds.get_cutouts(session=session)
if cutouts is not None:
cutouts.load_all_co_data()

# prepare the filter bank for this batch of cutouts
if self._filter_psf_fwhm is None or self._filter_psf_fwhm != cutouts.sources.image.get_psf().fwhm_pixels:
self.make_filter_bank(cutouts.co_dict["source_index_0"]["sub_data"].shape[0], cutouts.sources.image.get_psf().fwhm_pixels)

# go over each cutouts object and produce a measurements object
measurements_list = []
for key, co_dict in cutouts.co_dict.items():
for key, co_subdict in cutouts.co_dict.items():
m = Measurements(cutouts=cutouts)
# make sure to remember which cutout belongs to this measurement,
# before either of them is in the DB and then use the cutouts_id instead
Expand All @@ -231,7 +233,7 @@ def run(self, *args, **kwargs):
ignore_bits |= 2 ** BitFlagConverter.convert(badness)

# remove the bad pixels that we want to ignore
flags = co_dict['sub_flags'].astype('uint16') & ~np.array(ignore_bits).astype('uint16')
flags = co_subdict['sub_flags'].astype('uint16') & ~np.array(ignore_bits).astype('uint16')

annulus_radii_pixels = self.pars.annulus_radii
if self.pars.annulus_units == 'fwhm':
Expand All @@ -240,8 +242,8 @@ def run(self, *args, **kwargs):

# TODO: consider if there are any additional parameters that photometry needs
output = iterative_cutouts_photometry(
co_dict['sub_data'],
co_dict['sub_weight'],
co_subdict['sub_data'],
co_subdict['sub_weight'],
flags,
radii=m.aper_radii,
annulus=annulus_radii_pixels,
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
c2 = Cutouts()
c2.filepath = ds.cutouts.filepath
c2.sources = ds.cutouts.sources # necessary for co_dict
c2.load() # explicitly load co_dict
c2.load_all_co_data() # explicitly load co_dict

co_subdict2 = c2.co_dict[subdict_key]

Expand Down Expand Up @@ -85,6 +85,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
ds.cutouts = session.merge(ds.cutouts)
session.commit() # QUESTION: does this necessitate cleanup in the finally block?

ds.cutouts.load_all_co_data() # need to re-load after merge
assert ds.cutouts is not None
assert len(ds.cutouts.co_dict) > 0

Expand All @@ -96,6 +97,7 @@ def test_make_save_load_cutouts(decam_detection_list, cutter):
loaded_cutouts = loaded_cutouts[0]

# make sure data is correct
loaded_cutouts.load_all_co_data()
co_subdict = loaded_cutouts.co_dict[subdict_key]
for im in ['sub', 'ref', 'new']:
for att in ['data', 'weight', 'flags']:
Expand Down

0 comments on commit 9338a5f

Please sign in to comment.