Skip to content

Commit

Permalink
Make enum converters a class inheritance to avoid repeated convert fu…
Browse files Browse the repository at this point in the history
…nctions
  • Loading branch information
rknop authored Sep 15, 2023
1 parent 16d7af6 commit e17d4c4
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 185 deletions.
11 changes: 5 additions & 6 deletions models/cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from sqlalchemy.ext.hybrid import hybrid_property

from models.base import Base, SeeChangeBase, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed
from models.enums_and_bitflags import cutouts_format_dict, cutouts_format_converter

from models.enums_and_bitflags import CutoutsFormatConverter

class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed):

Expand All @@ -14,23 +13,23 @@ class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed):
_format = sa.Column(
sa.SMALLINT,
nullable=False,
default=cutouts_format_converter('fits'),
default=CutoutsFormatConverter.convert('fits'),
doc="Format of the file on disk. Should be fits, hdf5, csv or npy. "
"Saved as integer but is converter to string when loaded. "
)

@hybrid_property
def format(self):
return cutouts_format_converter(self._format)
return CutoutsFormatConverter.convert(self._format)

@format.expression
def format(cls):
# ref: https://stackoverflow.com/a/25272425
return sa.case(cutouts_format_dict, value=cls._format)
return sa.case(CutoutsFormatConverter.dict, value=cls._format)

@format.setter
def format(self, value):
self._format = cutouts_format_converter(value)
self._format = CutoutsFormatConverter.convert(value)

source_list_id = sa.Column(
sa.ForeignKey('source_lists.id', name='cutouts_source_list_id_fkey'),
Expand Down
289 changes: 147 additions & 142 deletions models/enums_and_bitflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,164 +2,169 @@
Here we put all the dictionaries and conversion functions for getting/setting enums and bitflags.
"""

from util.classproperty import classproperty

def c(keyword):
"""Convert the key to something more compatible. """
return keyword.lower().replace(' ', '')


# This is the master format dictionary, that contains all file types for
# all data models. Each model will get a subset of this dictionary.
file_format_dict = {
1: 'fits',
2: 'hdf5',
3: 'csv',
4: 'json',
5: 'yaml',
6: 'xml',
7: 'pickle',
8: 'parquet',
9: 'npy',
10: 'npz',
11: 'avro',
12: 'netcdf',
13: 'jpg',
14: 'png',
15: 'pdf',
}
class EnumConverter:
"""Base class for creating an (effective) enum that is saved to the database as an int.
allowed_image_formats = ['fits', 'hdf5']
image_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_image_formats}
image_format_inverse = {c(v): k for k, v in image_format_dict.items()}
This avoids the pain of dealing with Postgres enums and migrations.
To use this:
def image_format_converter(value):
"""
Convert between an image format string (e.g., "fits" or "hdf5")
to the corresponding integer key (e.g., 1 or 2). If given a string,
will return the integer key, and if given a number (float or int)
will return the corresponding string.
String identification is case insensitive and ignores spaces.
1. Create a subclass of EnumConverter (called <class> here).
If given None, will return None.
"""
if isinstance(value, str):
if c(value) not in image_format_inverse:
raise ValueError(f'Image format must be one of {image_format_inverse.keys()}, not {value}')
return image_format_inverse[c(value)]
elif isinstance(value, (int, float)):
if value not in image_format_dict:
raise ValueError(f'Image format integer key must be one of {image_format_dict.keys()}, not {value}')
return image_format_dict[value]
elif value is None:
return None
else:
raise ValueError(f'Image format must be integer/float key or string value, not {type(value)}')
2. Define the _dict property of that subclass to have the mapping from integer to string.
3. If not all of the strings in the _dict are allowed formats for
this class, define a property _allowed_values with a list of
values (strings) that are allowed. (See, for example,
ImageFormatConverter.) If they are all allowed formats, you can
instead just define _allowed_values as None. (See, for example,
ImageTypeConverter.)
allowed_cutout_formats = ['fits', 'hdf5', 'jpg', 'png']
cutouts_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_cutout_formats}
cutouts_format_inverse = {c(v): k for k, v in cutouts_format_dict.items()}
4. Make sure that every class has its own initialized values of
_dict_filtered and _dict_inverse, both initialized to None.
(This is necessary because we're using these two class variables
as mutable variables, so we have to make sure that inheritance
doesn't confuse the different classes with each other.)
5. In the database model that uses the enum, create fields and properties like:
def cutouts_format_converter(value):
"""
Convert between a cutouts format string (e.g., "fits" or "hdf5")
to the corresponding integer key (e.g., 1 or 2). If given a string,
will return the integer key, and if given a number (float or int)
will return the corresponding string.
String identification is case insensitive and ignores spaces.
_format = sa.Column( sa.SMALLINT, nullable=False, default=<class>.convert('<default_value>' )
If given None, will return None.
"""
if isinstance(value, str):
if c(value) not in cutouts_format_inverse:
raise ValueError(f'Cutouts format must be one of {cutouts_format_inverse.keys()}, not {value}')
return cutouts_format_inverse[c(value)]
elif isinstance(value, (int, float)):
if value not in cutouts_format_dict:
raise ValueError(f'Cutouts format integer key must be one of {cutouts_format_dict.keys()}, not {value}')
return cutouts_format_dict[value]
elif value is None:
return None
else:
raise ValueError(f'Cutouts format must be integer/float key or string value, not {type(value)}')
@hybrid_property
def format(self):
return <class>.convert( self._format )
@format.expression
def format(cls):
return sa.case( <class>.dict, value=cls._format )
allowed_source_list_formats = ['npy', 'csv', 'hdf5', 'parquet', 'fits']
source_list_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_source_list_formats}
source_list_format_inverse = {c(v): k for k, v in source_list_format_dict.items()}
@format.setter
def format( self, value ):
self._format = <class>.convert( value )
6. Anywhere in code where you want to convert between the string and
the corresponding integer key (in either direction), just call
<class>.convert( value )
def source_list_format_converter(value):
"""
Convert between a source list format string (e.g., "fits" or "npy")
to the corresponding integer key (e.g., 1 or 9). If given a string,
will return the integer key, and if given a number (float or int)
will return the corresponding string.
String identification is case insensitive and ignores spaces.

If given None, will return None.
"""
_dict = {}
_allowed_values = None
_dict_filtered = None
_dict_inverse = None

@classmethod
def c( cls, keyword ):
"""Convert the key to something more compatible. """
return keyword.lower().replace(' ', '')

@classproperty
def dict( cls ):
if cls._dict_filtered is None:
if cls._allowed_values is None:
cls._dict_filtered = cls._dict
else:
cls._dict_filtered = { k: v for k, v in cls._dict.items() if v in cls._allowed_values }
return cls._dict_filtered

@classproperty
def dict_inverse( cls ):
if cls._dict_inverse is None:
cls._dict_inverse = { cls.c(v): k for k, v in cls._dict.items() }
return cls._dict_inverse

@classmethod
def convert( cls, value ):
"""Convert between a string and corresponding integer key.
If given a string, will return the integer key. If given an
integer key, will return the corresponding string. String
identification is case-insensitive and ignores spaces.
if isinstance(value, str):
if c(value) not in source_list_format_inverse:
raise ValueError(f'Source list format must be one of {source_list_format_inverse.keys()}, not {value}')
return source_list_format_inverse[c(value)]
elif isinstance(value, (int, float)):
if value not in source_list_format_dict:
raise ValueError(f'Source list format integer key must be one of {source_list_format_dict.keys()}, not {value}')
return source_list_format_dict[value]
elif value is None:
return None
else:
raise ValueError(f'Source list format must be integer/float key or string value, not {type(value)}')


image_type_dict = {
1: 'Sci',
2: 'ComSci',
3: 'Diff',
4: 'ComDiff',
5: 'Bias',
6: 'ComBias',
7: 'Dark',
8: 'ComDark',
9: 'DomeFlat',
10: 'ComDomeFlat',
11: 'SkyFlat',
12: 'ComSkyFlat',
13: 'TwiFlat',
14: 'ComTwiFlat',
}
image_type_inverse = {c(v): k for k, v in image_type_dict.items()}


def image_type_converter(value):
"""
Convert between an image type string (e.g., "Sci" or "Diff")
to the corresponding integer key (e.g., 1 or 3). If given a string,
will return the integer key, and if given a number (float or int)
will return the corresponding string.
String identification is case insensitive, and ignores spaces.
If given None, will return None.
If given None, will return None.
"""
if isinstance(value, str):
if c(value) not in image_type_inverse:
raise ValueError(f'Image type must be one of {image_type_inverse.keys()}, not {value}')
return image_type_inverse[c(value)]
elif isinstance(value, (int, float)):
if value not in image_type_dict:
raise ValueError(f'Image type integer key must be one of {image_type_dict.keys()}, not {value}')
return image_type_dict[value]
elif value is None:
return None
else:
raise ValueError(f'Image type must be integer/float key or string value, not {type(value)}')
"""
if isinstance(value, str):
if cls.c(value) not in cls.dict_inverse:
raise ValueError(f'{cls.__name__} must be one of {cls.dict_inverse.keys()}, not {value}')
return cls.dict_inverse[cls.c(value)]
elif isinstance(value, (int, float)):
if value not in cls.dict:
raise ValueError(f'{cls.__name__} integer key must be one of {cls.dict.keys()}, not {value}')
return cls.dict[value]
elif value is None:
return None
else:
raise ValueError(f'{cls.__name__} must be integer/float key or string value, not {type(value)}')


class FormatConverter( EnumConverter ):
# This is the master format dictionary, that contains all file types for
# all data models. Each model will get a subset of this dictionary.
_dict = {
1: 'fits',
2: 'hdf5',
3: 'csv',
4: 'json',
5: 'yaml',
6: 'xml',
7: 'pickle',
8: 'parquet',
9: 'npy',
10: 'npz',
11: 'avro',
12: 'netcdf',
13: 'jpg',
14: 'png',
15: 'pdf',
16: 'fitsldac',
}
_allowed_values = None
_dict_filtered = None
_dict_inverse = None

class ImageFormatConverter( FormatConverter ):
_allowed_values = ['fits', 'hdf5']
_dict_filtered = None
_dict_inverse = None

class CutoutsFormatConverter( FormatConverter ):
_dict = ImageFormatConverter._dict
_allowed_values = ['fits', 'hdf5', 'jpg', 'png']
_dict_filtered = None
_dict_inverse = None

class SourceListFormatConverter( FormatConverter ):
_allowed_values = ['npy', 'csv', 'hdf5', 'parquet', 'fits']
_dict_filtered = None
_dict_inverse = None

class ImageTypeConverter( EnumConverter ):
_dict = {
1: 'Sci',
2: 'ComSci',
3: 'Diff',
4: 'ComDiff',
5: 'Bias',
6: 'ComBias',
7: 'Dark',
8: 'ComDark',
9: 'DomeFlat',
10: 'ComDomeFlat',
11: 'SkyFlat',
12: 'ComSkyFlat',
13: 'TwiFlat',
14: 'ComTwiFlat',
}
_allowed_values = None
_dict_filtered = None
_dict_inverse = None


def bitflag_to_string(value, dictionary):

"""
Takes a 64 bit integer bit-flag and converts it to a comma separated string,
using the given dictionary.
Expand Down Expand Up @@ -237,7 +242,7 @@ def string_to_bitflag(value, dictionary):
output = 0
for keyword in value.split(','):
original_keyword = keyword
keyword = c(keyword)
keyword = EnumConverter.c(keyword)
if keyword not in dictionary:
raise ValueError(f'Keyword "{original_keyword}" not recognized in dictionary')
output += 2 ** dictionary[keyword]
Expand All @@ -252,7 +257,7 @@ def string_to_bitflag(value, dictionary):
4: 'Bad Subtraction',
5: 'Bright Sky',
}
image_badness_inverse = {c(v): k for k, v in image_badness_dict.items()}
image_badness_inverse = {EnumConverter.c(v): k for k, v in image_badness_dict.items()}

# these are the ways a Cutouts object is allowed to be bad
cutouts_badness_dict = {
Expand All @@ -263,7 +268,7 @@ def string_to_bitflag(value, dictionary):
25: 'Bad Pixel',
26: 'Bleed Trail',
}
cutouts_badness_inverse = {c(v): k for k, v in cutouts_badness_dict.items()}
cutouts_badness_inverse = {EnumConverter.c(v): k for k, v in cutouts_badness_dict.items()}

# these are the ways a SourceList object is allowed to be bad
source_list_badness_dict = {
Expand All @@ -272,13 +277,13 @@ def string_to_bitflag(value, dictionary):
43: 'Few Sources',
44: 'Many Sources',
}
source_list_badness_inverse = {c(v): k for k, v in source_list_badness_dict.items()}
source_list_badness_inverse = {EnumConverter.c(v): k for k, v in source_list_badness_dict.items()}

# join the badness:
data_badness_dict = {0: 'Good'}
data_badness_dict.update(image_badness_dict)
data_badness_dict.update(cutouts_badness_dict)
data_badness_dict.update(source_list_badness_dict)
data_badness_inverse = {c(v): k for k, v in data_badness_dict.items()}
data_badness_inverse = {EnumConverter.c(v): k for k, v in data_badness_dict.items()}
if 0 in data_badness_inverse:
raise ValueError('Cannot have a badness bitflag of zero. This is reserved for good data.')
Loading

0 comments on commit e17d4c4

Please sign in to comment.