Skip to content

Commit

Permalink
Change to Co_Dict class and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
whohensee committed Jun 27, 2024
1 parent 9338a5f commit 66857ca
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 131 deletions.
85 changes: 46 additions & 39 deletions models/cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,32 @@
SeeChangeBase,
AutoIDMixin,
FileOnDiskMixin,
SpatiallyIndexed,
HasBitFlagBadness,
)
from models.enums_and_bitflags import CutoutsFormatConverter, cutouts_badness_inverse
from models.enums_and_bitflags import CutoutsFormatConverter
from models.source_list import SourceList


class Co_Dict(dict):
    """Cutouts dictionary used in Cutouts to store the per-cutout data dicts.

    Acts as a normal dictionary, except when a key is accessed with bracket
    notation (such as ``co_dict["source_index_7"]``): if that key is not
    present in the Co_Dict, it will search on disk for the requested data,
    and if found will silently load that data and return it.

    Must be assigned a Cutouts object to its ``cutouts`` attribute so that
    it knows how to look for data on disk.
    """

    def __init__(self, *args, **kwargs):
        # Must be assigned before the lazy-load path of __getitem__ is used.
        self.cutouts = None
        # Bug fix: the original called super().__init__(self, *args, **kwargs),
        # passing `self` as an extra positional mapping. dict accepts at most
        # one positional argument, so Co_Dict({'a': 1}) raised TypeError.
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        if key not in self:
            # Key missing in memory: try to lazy-load it from disk, but only
            # when a backing Cutouts with a saved file is attached. Guarding
            # on `self.cutouts is not None` keeps the normal KeyError (rather
            # than AttributeError) for unattached dictionaries.
            if self.cutouts is not None and self.cutouts.filepath is not None:
                self.cutouts.load_one_co_dict(key)
        return super().__getitem__(key)

class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, HasBitFlagBadness):

__tablename__ = 'cutouts'
Expand Down Expand Up @@ -129,7 +148,8 @@ def __init__(self, *args, **kwargs):
self._new_weight = None
self._new_flags = None

self._co_dict = {}
self.co_dict = Co_Dict()
self.co_dict.cutouts = self

self._bitflag = 0

Expand Down Expand Up @@ -159,7 +179,8 @@ def init_on_load(self):
self._new_weight = None
self._new_flags = None

self._co_dict = {}
self.co_dict = Co_Dict()
self.co_dict.cutouts = self

def __repr__(self):
return (
Expand All @@ -180,26 +201,12 @@ def get_data_dict_attributes(include_optional=True): # WHPR could rename get_da

return names

@property
def co_dict( self, ):
# Ok because of partial lazy loading with hdf5 and measurements only wanting their row,
# this one is complicated. I have set that if you use this attribute, it will ENSURE
# that you are given the entire dictionary, including checking it is the proper length
# using the sourcelist. co_dict_noload gives the current dict.
def load_all_co_data(self):
if self.sources.num_sources is None:
raise ValueError("The detections of this cutouts has no num_sources attr")
proper_length = self.sources.num_sources
if len(self._co_dict) != proper_length and self.filepath is not None:
if len(self.co_dict) != proper_length and self.filepath is not None:
self.load()
return self._co_dict

@co_dict.setter
def co_dict( self, value ):
self._co_dict = value

@property
def co_dict_noload(self):
return self._co_dict

@staticmethod
def from_detections(detections, provenance=None, **kwargs):
Expand Down Expand Up @@ -260,13 +267,13 @@ def invent_filepath(self):

return filename

def _save_dataset_dict_to_hdf5(self, co_dict, file, groupname):
def _save_dataset_dict_to_hdf5(self, co_subdict, file, groupname):
"""Save the one co_subdict from the co_dict of this Cutouts
into an HDF5 group for an open file.
Parameters
----------
co_dict: dict
co_subdict: dict
The subdict containing the data for a single cutout
file: h5py.File
The open HDF5 file to save to.
Expand All @@ -277,7 +284,7 @@ def _save_dataset_dict_to_hdf5(self, co_dict, file, groupname):
del file[groupname]

for key in self.get_data_dict_attributes():
data = co_dict.get(key)
data = co_subdict.get(key)

if data is not None:
file.create_dataset(
Expand All @@ -298,17 +305,17 @@ def save(self, filename=None, overwrite=True, **kwargs):
kwargs: dict
Any additional keyword arguments to pass to the FileOnDiskMixin.save method.
"""
if self._co_dict == {}:
if len(self.co_dict) == 0:
return None # do nothing

proper_length = self.sources.num_sources
if len(self._co_dict) != proper_length:
raise ValueError(f"Trying to save cutouts dict with {len(self._co_dict)}"
if len(self.co_dict) != proper_length:
raise ValueError(f"Trying to save cutouts dict with {len(self.co_dict)}"
f" subdicts, but SourceList has {proper_length} sources")

for key, value in self._co_dict.items():
for key, value in self.co_dict.items():
if not isinstance(value, dict):
raise TypeError("Each entry of _co_dict must be a dictionary")
raise TypeError("Each entry of co_dict must be a dictionary")

if filename is None:
filename = self.invent_filepath()
Expand All @@ -323,7 +330,7 @@ def save(self, filename=None, overwrite=True, **kwargs):

if self.format == 'hdf5':
with h5py.File(fullname, 'a') as file:
for key, value in self._co_dict.items():
for key, value in self.co_dict.items():
self._save_dataset_dict_to_hdf5(value, file, key)
elif self.format == 'fits':
raise NotImplementedError('Saving cutouts to fits is not yet implemented.')
Expand All @@ -344,17 +351,17 @@ def _load_dataset_dict_from_hdf5(self, file, groupname):
file: h5py.File
The open HDF5 file to load from.
groupname: str
The name of the group to load from. This should be "source_<number>"
The name of the group to load from. This should be "source_index_<number>"
"""

co_dict = {}
co_subdict = {}
found_data = False
for att in self.get_data_dict_attributes(): # remove source index for dict soon
if att in file[groupname]:
found_data = True
co_dict[att] = np.array(file[f'{groupname}/{att}'])
co_subdict[att] = np.array(file[f'{groupname}/{att}'])
if found_data:
return co_dict
return co_subdict

def load_one_co_dict(self, groupname, filepath=None):
"""Load data subdict for a single cutout into this Cutouts co_dict. This allows
Expand All @@ -366,7 +373,9 @@ def load_one_co_dict(self, groupname, filepath=None):
filepath = self.get_fullpath()

with h5py.File(filepath, 'r') as file:
self._co_dict[groupname] = self._load_dataset_dict_from_hdf5(file, groupname)
co_subdict = self._load_dataset_dict_from_hdf5(file, groupname)
if co_subdict is not None:
self.co_dict[groupname] = co_subdict
return None

def load(self, filepath=None):
Expand All @@ -385,13 +394,14 @@ def load(self, filepath=None):
if filepath is None:
raise ValueError("Could not find filepath to load")

self._co_dict = {}
self.co_dict = Co_Dict()
self.co_dict.cutouts = self

if os.path.exists(filepath): # WHPR revisit this check... necessary?
if self.format == 'hdf5':
with h5py.File(filepath, 'r') as file:
for groupname in file:
self._co_dict[groupname] = self._load_dataset_dict_from_hdf5(file, groupname)
self.co_dict[groupname] = self._load_dataset_dict_from_hdf5(file, groupname)

def get_upstreams(self, session=None):
"""Get the detections SourceList that was used to make this cutout. """
Expand All @@ -404,6 +414,3 @@ def get_downstreams(self, session=None, siblings=False):

with SmartSession(session) as session:
return session.scalars(sa.select(Measurements).where(Measurements.cutouts_id == self.id)).all()

def _get_inverse_badness(self):
return cutouts_badness_inverse
40 changes: 16 additions & 24 deletions models/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from models.base import Base, SeeChangeBase, SmartSession, AutoIDMixin, SpatiallyIndexed, HasBitFlagBadness
from models.cutouts import Cutouts
from models.enums_and_bitflags import measurements_badness_inverse

from improc.photometry import get_circle

Expand Down Expand Up @@ -43,7 +44,8 @@ class Measurements(Base, AutoIDMixin, SpatiallyIndexed, HasBitFlagBadness):
index_in_sources = sa.Column(
sa.Integer,
nullable=False,
doc="Index of this cutout in the source list (of detections in the difference image). "
doc="Index of the data for this Measurements"
"in the source list (of detections in the difference image). "
)

object_id = sa.Column(
Expand Down Expand Up @@ -410,19 +412,15 @@ def get_data_from_cutouts(self):
"""Populates this object with the cutout data arrays used in
calculations. This allows us to use, for example, self.sub_data
without having to look constantly back into the related Cutouts.
If that is not a concern, all such calls could instead refer back
to the Cutouts data.
"""
# QUESTION: I have chosen here to load the data into the Measurements object
# when needed rather than have to constantly refer back to something like
# self.cutouts.co_dict_noload[self.index_in_sources], although that would
# be a perfectly acceptable way to access the data, and potentially involve
# less wasted memory. Is there any advantage (speed, database usage, etc) to
# doing it this way, or would it be better to use the relationship directly
# and skip populating the data temporarily in this Measurements object?
Importantly, the data for this measurements should have already
been loaded by the Co_Dict class
"""
groupname = f'source_index_{self.index_in_sources}'

if not self.cutouts.co_dict.get(groupname):
raise ValueError(f"No subdict found for {groupname}")

co_data_dict = self.cutouts.co_dict[groupname] # get just the subdict with data for this

for att in Cutouts.get_data_dict_attributes():
Expand Down Expand Up @@ -595,6 +593,9 @@ def get_downstreams(self, session=None, siblings=False):
"""Get the downstreams of this Measurements"""
return []

def _get_inverse_badness(self):
return measurements_badness_inverse

@classmethod
def delete_list(cls, measurements_list, session=None, commit=True):
"""
Expand Down Expand Up @@ -628,21 +629,12 @@ def load_attribute(object, att):
if not hasattr(object, f'_{att}'):
raise AttributeError(f"The object {object} does not have the attribute {att}.")
if getattr(object, f'_{att}') is None:
if object.cutouts.co_dict_noload == {} and object.cutouts.filepath is None:
if len(object.cutouts.co_dict) == 0 and object.cutouts.filepath is None:
return None # objects just now created and not saved cannot lazy load data!

groupname = f'source_index_{object.index_in_sources}'
if groupname not in object.cutouts.co_dict_noload.keys():
# try and load the info for this measurement
object.cutouts.load_one_co_dict(groupname)
if groupname not in object.cutouts.co_dict_noload.keys():
raise ValueError("This measurements not found in Cutouts data dict")
if att not in object.cutouts.co_dict_noload[groupname].keys():
raise ValueError(f"No matching entry in dict for key {att}")
object.get_data_from_cutouts() # this does load ALL 9 data attributes
# into this object, but that should only
# ever happen once, as future calls will
# find the data and just return it
if object.cutouts.co_dict[groupname] is not None: # will check disk as Co_Dict
object.get_data_from_cutouts()

# after data is filled, should be able to just return it
return getattr(object, f'_{att}')
Expand Down
Loading

0 comments on commit 66857ca

Please sign in to comment.