diff --git a/models/cutouts.py b/models/cutouts.py index e9f4c21c..30c8e919 100644 --- a/models/cutouts.py +++ b/models/cutouts.py @@ -4,8 +4,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from models.base import Base, SeeChangeBase, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed -from models.enums_and_bitflags import cutouts_format_dict, cutouts_format_converter - +from models.enums_and_bitflags import CutoutsFormatConverter class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed): @@ -14,23 +13,23 @@ class Cutouts(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed): _format = sa.Column( sa.SMALLINT, nullable=False, - default=cutouts_format_converter('fits'), + default=CutoutsFormatConverter.convert('fits'), doc="Format of the file on disk. Should be fits, hdf5, csv or npy. " "Saved as integer but is converter to string when loaded. " ) @hybrid_property def format(self): - return cutouts_format_converter(self._format) + return CutoutsFormatConverter.convert(self._format) @format.expression def format(cls): # ref: https://stackoverflow.com/a/25272425 - return sa.case(cutouts_format_dict, value=cls._format) + return sa.case(CutoutsFormatConverter.dict, value=cls._format) @format.setter def format(self, value): - self._format = cutouts_format_converter(value) + self._format = CutoutsFormatConverter.convert(value) source_list_id = sa.Column( sa.ForeignKey('source_lists.id', name='cutouts_source_list_id_fkey'), diff --git a/models/enums_and_bitflags.py b/models/enums_and_bitflags.py index 8c4a6f88..5e71e0fb 100644 --- a/models/enums_and_bitflags.py +++ b/models/enums_and_bitflags.py @@ -2,164 +2,169 @@ Here we put all the dictionaries and conversion functions for getting/setting enums and bitflags. """ +from util.classproperty import classproperty -def c(keyword): - """Convert the key to something more compatible. 
""" - return keyword.lower().replace(' ', '') - - -# This is the master format dictionary, that contains all file types for -# all data models. Each model will get a subset of this dictionary. -file_format_dict = { - 1: 'fits', - 2: 'hdf5', - 3: 'csv', - 4: 'json', - 5: 'yaml', - 6: 'xml', - 7: 'pickle', - 8: 'parquet', - 9: 'npy', - 10: 'npz', - 11: 'avro', - 12: 'netcdf', - 13: 'jpg', - 14: 'png', - 15: 'pdf', -} +class EnumConverter: + """Base class for creating an (effective) enum that is saved to the database as an int. -allowed_image_formats = ['fits', 'hdf5'] -image_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_image_formats} -image_format_inverse = {c(v): k for k, v in image_format_dict.items()} + This avoids the pain of dealing with Postgres enums and migrations. + To use this: -def image_format_converter(value): - """ - Convert between an image format string (e.g., "fits" or "hdf5") - to the corresponding integer key (e.g., 1 or 2). If given a string, - will return the integer key, and if given a number (float or int) - will return the corresponding string. - String identification is case insensitive and ignores spaces. + 1. Create a subclass of EnumConverter (called <enum_class> here). - If given None, will return None. - """ - if isinstance(value, str): - if c(value) not in image_format_inverse: - raise ValueError(f'Image format must be one of {image_format_inverse.keys()}, not {value}') - return image_format_inverse[c(value)] - elif isinstance(value, (int, float)): - if value not in image_format_dict: - raise ValueError(f'Image format integer key must be one of {image_format_dict.keys()}, not {value}') - return image_format_dict[value] - elif value is None: - return None - else: - raise ValueError(f'Image format must be integer/float key or string value, not {type(value)}') + 2. Define the _dict property of that subclass to have the mapping from integer to string. + 3. 
If not all of the strings in the _dict are allowed formats for + this class, define a property _allowed_values with a list of + values (strings) that are allowed. (See, for example, + ImageFormatConverter.) If they are all allowed formats, you can + instead just define _allowed_values as None. (See, for example, + ImageTypeConverter.) -allowed_cutout_formats = ['fits', 'hdf5', 'jpg', 'png'] -cutouts_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_cutout_formats} -cutouts_format_inverse = {c(v): k for k, v in cutouts_format_dict.items()} + 4. Make sure that every class has its own initialized values of + _dict_filtered and _dict_inverse, both initialized to None. + (This is necessary because we're using these two class variables + as mutable variables, so we have to make sure that inheritance + doesn't confuse the different classes with each other.) + 5. In the database model that uses the enum, create fields and properties like: -def cutouts_format_converter(value): - """ - Convert between a cutouts format string (e.g., "fits" or "hdf5") - to the corresponding integer key (e.g., 1 or 2). If given a string, - will return the integer key, and if given a number (float or int) - will return the corresponding string. - String identification is case insensitive and ignores spaces. + _format = sa.Column( sa.SMALLINT, nullable=False, default=<enum_class>.convert('<default_value>') ) - If given None, will return None. 
- """ - if isinstance(value, str): - if c(value) not in cutouts_format_inverse: - raise ValueError(f'Cutouts format must be one of {cutouts_format_inverse.keys()}, not {value}') - return cutouts_format_inverse[c(value)] - elif isinstance(value, (int, float)): - if value not in cutouts_format_dict: - raise ValueError(f'Cutouts format integer key must be one of {cutouts_format_dict.keys()}, not {value}') - return cutouts_format_dict[value] - elif value is None: - return None - else: - raise ValueError(f'Cutouts format must be integer/float key or string value, not {type(value)}') + @hybrid_property + def format(self): + return <enum_class>.convert( self._format ) + @format.expression + def format(cls): + return sa.case( <enum_class>.dict, value=cls._format ) -allowed_source_list_formats = ['npy', 'csv', 'hdf5', 'parquet', 'fits'] -source_list_format_dict = {k: v for k, v in file_format_dict.items() if v in allowed_source_list_formats} -source_list_format_inverse = {c(v): k for k, v in source_list_format_dict.items()} + @format.setter + def format( self, value ): + self._format = <enum_class>.convert( value ) + 6. Anywhere in code where you want to convert between the string and + the corresponding integer key (in either direction), just call + <enum_class>.convert( value ) -def source_list_format_converter(value): """ - Convert between a source list format string (e.g., "fits" or "npy") - to the corresponding integer key (e.g., 1 or 9). If given a string, - will return the integer key, and if given a number (float or int) - will return the corresponding string. - String identification is case insensitive and ignores spaces. - If given None, will return None. - """ + _dict = {} + _allowed_values = None + _dict_filtered = None + _dict_inverse = None + + @classmethod + def c( cls, keyword ): + """Convert the key to something more compatible. 
""" + return keyword.lower().replace(' ', '') + + @classproperty + def dict( cls ): + if cls._dict_filtered is None: + if cls._allowed_values is None: + cls._dict_filtered = cls._dict + else: + cls._dict_filtered = { k: v for k, v in cls._dict.items() if v in cls._allowed_values } + return cls._dict_filtered + + @classproperty + def dict_inverse( cls ): + if cls._dict_inverse is None: + cls._dict_inverse = { cls.c(v): k for k, v in cls._dict.items() } + return cls._dict_inverse + + @classmethod + def convert( cls, value ): + """Convert between a string and corresponding integer key. + + If given a string, will return the integer key. If given an + integer key, will return the corresponding string. String + identification is case-insensitive and ignores spaces. - if isinstance(value, str): - if c(value) not in source_list_format_inverse: - raise ValueError(f'Source list format must be one of {source_list_format_inverse.keys()}, not {value}') - return source_list_format_inverse[c(value)] - elif isinstance(value, (int, float)): - if value not in source_list_format_dict: - raise ValueError(f'Source list format integer key must be one of {source_list_format_dict.keys()}, not {value}') - return source_list_format_dict[value] - elif value is None: - return None - else: - raise ValueError(f'Source list format must be integer/float key or string value, not {type(value)}') - - -image_type_dict = { - 1: 'Sci', - 2: 'ComSci', - 3: 'Diff', - 4: 'ComDiff', - 5: 'Bias', - 6: 'ComBias', - 7: 'Dark', - 8: 'ComDark', - 9: 'DomeFlat', - 10: 'ComDomeFlat', - 11: 'SkyFlat', - 12: 'ComSkyFlat', - 13: 'TwiFlat', - 14: 'ComTwiFlat', -} -image_type_inverse = {c(v): k for k, v in image_type_dict.items()} - - -def image_type_converter(value): - """ - Convert between an image type string (e.g., "Sci" or "Diff") - to the corresponding integer key (e.g., 1 or 3). If given a string, - will return the integer key, and if given a number (float or int) - will return the corresponding string. 
- String identification is case insensitive, and ignores spaces. + If given None, will return None. - If given None, will return None. - """ - if isinstance(value, str): - if c(value) not in image_type_inverse: - raise ValueError(f'Image type must be one of {image_type_inverse.keys()}, not {value}') - return image_type_inverse[c(value)] - elif isinstance(value, (int, float)): - if value not in image_type_dict: - raise ValueError(f'Image type integer key must be one of {image_type_dict.keys()}, not {value}') - return image_type_dict[value] - elif value is None: - return None - else: - raise ValueError(f'Image type must be integer/float key or string value, not {type(value)}') + """ + if isinstance(value, str): + if cls.c(value) not in cls.dict_inverse: + raise ValueError(f'{cls.__name__} must be one of {cls.dict_inverse.keys()}, not {value}') + return cls.dict_inverse[cls.c(value)] + elif isinstance(value, (int, float)): + if value not in cls.dict: + raise ValueError(f'{cls.__name__} integer key must be one of {cls.dict.keys()}, not {value}') + return cls.dict[value] + elif value is None: + return None + else: + raise ValueError(f'{cls.__name__} must be integer/float key or string value, not {type(value)}') + + +class FormatConverter( EnumConverter ): + # This is the master format dictionary, that contains all file types for + # all data models. Each model will get a subset of this dictionary. 
+ _dict = { + 1: 'fits', + 2: 'hdf5', + 3: 'csv', + 4: 'json', + 5: 'yaml', + 6: 'xml', + 7: 'pickle', + 8: 'parquet', + 9: 'npy', + 10: 'npz', + 11: 'avro', + 12: 'netcdf', + 13: 'jpg', + 14: 'png', + 15: 'pdf', + 16: 'fitsldac', + } + _allowed_values = None + _dict_filtered = None + _dict_inverse = None + +class ImageFormatConverter( FormatConverter ): + _allowed_values = ['fits', 'hdf5'] + _dict_filtered = None + _dict_inverse = None + +class CutoutsFormatConverter( FormatConverter ): + _dict = ImageFormatConverter._dict + _allowed_values = ['fits', 'hdf5', 'jpg', 'png'] + _dict_filtered = None + _dict_inverse = None + +class SourceListFormatConverter( FormatConverter ): + _allowed_values = ['npy', 'csv', 'hdf5', 'parquet', 'fits'] + _dict_filtered = None + _dict_inverse = None + +class ImageTypeConverter( EnumConverter ): + _dict = { + 1: 'Sci', + 2: 'ComSci', + 3: 'Diff', + 4: 'ComDiff', + 5: 'Bias', + 6: 'ComBias', + 7: 'Dark', + 8: 'ComDark', + 9: 'DomeFlat', + 10: 'ComDomeFlat', + 11: 'SkyFlat', + 12: 'ComSkyFlat', + 13: 'TwiFlat', + 14: 'ComTwiFlat', + } + _allowed_values = None + _dict_filtered = None + _dict_inverse = None def bitflag_to_string(value, dictionary): + """ Takes a 64 bit integer bit-flag and converts it to a comma separated string, using the given dictionary. 
@@ -237,7 +242,7 @@ def string_to_bitflag(value, dictionary): output = 0 for keyword in value.split(','): original_keyword = keyword - keyword = c(keyword) + keyword = EnumConverter.c(keyword) if keyword not in dictionary: raise ValueError(f'Keyword "{original_keyword}" not recognized in dictionary') output += 2 ** dictionary[keyword] @@ -252,7 +257,7 @@ def string_to_bitflag(value, dictionary): 4: 'Bad Subtraction', 5: 'Bright Sky', } -image_badness_inverse = {c(v): k for k, v in image_badness_dict.items()} +image_badness_inverse = {EnumConverter.c(v): k for k, v in image_badness_dict.items()} # these are the ways a Cutouts object is allowed to be bad cutouts_badness_dict = { @@ -263,7 +268,7 @@ def string_to_bitflag(value, dictionary): 25: 'Bad Pixel', 26: 'Bleed Trail', } -cutouts_badness_inverse = {c(v): k for k, v in cutouts_badness_dict.items()} +cutouts_badness_inverse = {EnumConverter.c(v): k for k, v in cutouts_badness_dict.items()} # these are the ways a SourceList object is allowed to be bad source_list_badness_dict = { @@ -272,13 +277,13 @@ def string_to_bitflag(value, dictionary): 43: 'Few Sources', 44: 'Many Sources', } -source_list_badness_inverse = {c(v): k for k, v in source_list_badness_dict.items()} +source_list_badness_inverse = {EnumConverter.c(v): k for k, v in source_list_badness_dict.items()} # join the badness: data_badness_dict = {0: 'Good'} data_badness_dict.update(image_badness_dict) data_badness_dict.update(cutouts_badness_dict) data_badness_dict.update(source_list_badness_dict) -data_badness_inverse = {c(v): k for k, v in data_badness_dict.items()} +data_badness_inverse = {EnumConverter.c(v): k for k, v in data_badness_dict.items()} if 0 in data_badness_inverse: raise ValueError('Cannot have a badness bitflag of zero. 
This is reserved for good data.') diff --git a/models/exposure.py b/models/exposure.py index 6e83329e..9e2068ea 100644 --- a/models/exposure.py +++ b/models/exposure.py @@ -11,10 +11,8 @@ from models.base import Base, SeeChangeBase, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed, SmartSession from models.instrument import Instrument, guess_instrument, get_instrument_instance from models.enums_and_bitflags import ( - image_format_converter, - image_format_dict, - image_type_converter, - image_type_dict, + ImageFormatConverter, + ImageTypeConverter, image_badness_inverse, data_badness_dict, string_to_bitflag, @@ -128,7 +126,7 @@ class Exposure(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed): _type = sa.Column( sa.SMALLINT, nullable=False, - default=image_type_converter('Sci'), + default=ImageTypeConverter.convert('Sci'), index=True, doc=( "Type of image. One of: Sci, Diff, Bias, Dark, DomeFlat, SkyFlat, TwiFlat, " @@ -140,36 +138,36 @@ class Exposure(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed): @hybrid_property def type(self): - return image_type_converter(self._type) + return ImageTypeConverter.convert(self._type) @type.expression def type(cls): - return sa.case(image_type_dict, value=cls._type) + return sa.case(ImageTypeConverter.dict, value=cls._type) @type.setter def type(self, value): - self._type = image_type_converter(value) + self._type = ImageTypeConverter.convert(value) _format = sa.Column( sa.SMALLINT, nullable=False, - default=image_format_converter('fits'), + default=ImageFormatConverter.convert('fits'), doc="Format of the file on disk. Should be fits or hdf5. " "The value is saved as SMALLINT but translated to a string when read. 
" ) @hybrid_property def format(self): - return image_format_converter(self._format) + return ImageFormatConverter.convert(self._format) @format.expression def format(cls): # ref: https://stackoverflow.com/a/25272425 - return sa.case(image_format_dict, value=cls._format) + return sa.case(ImageFormatConverter.dict, value=cls._format) @format.setter def format(self, value): - self._format = image_format_converter(value) + self._format = ImageFormatConverter.convert(value) header = sa.Column( JSONB, diff --git a/models/image.py b/models/image.py index 84256594..e8627a6d 100644 --- a/models/image.py +++ b/models/image.py @@ -18,10 +18,8 @@ from models.exposure import Exposure from models.instrument import get_instrument_instance from models.enums_and_bitflags import ( - image_format_converter, - image_format_dict, - image_type_converter, - image_type_dict, + ImageFormatConverter, + ImageTypeConverter, image_badness_inverse, data_badness_dict, string_to_bitflag, @@ -51,23 +49,23 @@ class Image(Base, AutoIDMixin, FileOnDiskMixin, SpatiallyIndexed, FourCorners): _format = sa.Column( sa.SMALLINT, nullable=False, - default=image_format_converter('fits'), + default=ImageFormatConverter.convert('fits'), doc="Format of the file on disk. Should be fits or hdf5. 
" ) @hybrid_property def format(self): - return image_format_converter(self._format) + return ImageFormatConverter.convert(self._format) @format.inplace.expression @classmethod def format(cls): # ref: https://stackoverflow.com/a/25272425 - return sa.case(image_format_dict, value=cls._format) + return sa.case(ImageFormatConverter.dict, value=cls._format) @format.inplace.setter def format(self, value): - self._format = image_format_converter(value) + self._format = ImageFormatConverter.convert(value) exposure_id = sa.Column( sa.ForeignKey('exposures.id', ondelete='SET NULL', name='images_exposure_id_fkey'), @@ -176,7 +174,7 @@ def is_sub(cls): _type = sa.Column( sa.SMALLINT, nullable=False, - default=image_type_converter('Sci'), + default=ImageTypeConverter.convert('Sci'), index=True, doc=( "Type of image. One of: Sci, Diff, Bias, Dark, DomeFlat, SkyFlat, TwiFlat, " @@ -188,16 +186,16 @@ def is_sub(cls): @hybrid_property def type(self): - return image_type_converter(self._type) + return ImageTypeConverter.convert(self._type) @type.inplace.expression @classmethod def type(cls): - return sa.case(image_type_dict, value=cls._type) + return sa.case(ImageTypeConverter.dict, value=cls._type) @type.inplace.setter def type(self, value): - self._type = image_type_converter(value) + self._type = ImageTypeConverter.convert(value) provenance_id = sa.Column( sa.ForeignKey('provenances.id', ondelete="CASCADE", name='images_provenance_id_fkey'), diff --git a/models/source_list.py b/models/source_list.py index f09a6a6d..d3933ce6 100644 --- a/models/source_list.py +++ b/models/source_list.py @@ -13,8 +13,7 @@ from models.base import Base, AutoIDMixin, FileOnDiskMixin, SeeChangeBase from models.image import Image from models.enums_and_bitflags import ( - source_list_format_dict, - source_list_format_converter, + SourceListFormatConverter, bitflag_to_string, string_to_bitflag, data_badness_dict, @@ -29,23 +28,23 @@ class SourceList(Base, AutoIDMixin, FileOnDiskMixin): _format = 
sa.Column( sa.SMALLINT, nullable=False, - default=source_list_format_converter('npy'), + default=SourceListFormatConverter.convert('npy'), doc="Format of the file on disk. Should be fits, hdf5, csv or npy. " "Saved as integer but is converter to string when loaded. " ) @hybrid_property def format(self): - return source_list_format_converter(self._format) + return SourceListFormatConverter.convert(self._format) @format.expression def format(cls): # ref: https://stackoverflow.com/a/25272425 - return sa.case(source_list_format_dict, value=cls._format) + return sa.case(SourceListFormatConverter.dict, value=cls._format) @format.setter def format(self, value): - self._format = source_list_format_converter(value) + self._format = SourceListFormatConverter.convert(value) image_id = sa.Column( sa.ForeignKey('images.id', name='source_lists_image_id_fkey'), diff --git a/tests/models/test_enums.py b/tests/models/test_enums.py index 9cadaded..816dca06 100644 --- a/tests/models/test_enums.py +++ b/tests/models/test_enums.py @@ -1,11 +1,75 @@ +import pytest + from models.enums_and_bitflags import ( - file_format_dict, - image_type_dict, + FormatConverter, + ImageFormatConverter, + CutoutsFormatConverter, + SourceListFormatConverter, + ImageTypeConverter, data_badness_dict, ) def test_enums_zero_values(): - assert 0 not in file_format_dict - assert 0 not in image_type_dict - assert data_badness_dict[0] == 'Good' \ No newline at end of file + assert 0 not in FormatConverter.dict + assert 0 not in ImageTypeConverter.dict + assert data_badness_dict[0] == 'Good' + +def test_converter_dict(): + # Probably should test them all, but test just these + # three and trust that if it works, then the inheritance + # is working for all of them. 
+ + assert ImageTypeConverter.dict == { + 1: 'Sci', + 2: 'ComSci', + 3: 'Diff', + 4: 'ComDiff', + 5: 'Bias', + 6: 'ComBias', + 7: 'Dark', + 8: 'ComDark', + 9: 'DomeFlat', + 10: 'ComDomeFlat', + 11: 'SkyFlat', + 12: 'ComSkyFlat', + 13: 'TwiFlat', + 14: 'ComTwiFlat', + } + assert FormatConverter.dict == { + 1: 'fits', + 2: 'hdf5', + 3: 'csv', + 4: 'json', + 5: 'yaml', + 6: 'xml', + 7: 'pickle', + 8: 'parquet', + 9: 'npy', + 10: 'npz', + 11: 'avro', + 12: 'netcdf', + 13: 'jpg', + 14: 'png', + 15: 'pdf', + 16: 'fitsldac', + } + assert ImageFormatConverter.dict == { 1: 'fits', 2: 'hdf5' } + +def test_converter_convert(): + for cls in ( ImageFormatConverter, CutoutsFormatConverter, SourceListFormatConverter, ImageTypeConverter ): + for key, val in cls.dict.items(): + assert cls.convert( key ) == val + assert cls.convert( val ) == key + assert cls.convert( val.lower().replace( ' ', '' ) ) == key + + # Check a few bad ones + + with pytest.raises( ValueError, match='.*must be one of' ): + val = ImageFormatConverter.convert( 'non existent format' ) + + with pytest.raises( ValueError, match='.*integer key must be one of' ): + val = ImageFormatConverter.convert( -1 ) + + with pytest.raises( ValueError, match='.*must be integer/float key' ): + val = ImageFormatConverter.convert( [] ) diff --git a/tests/models/test_image.py b/tests/models/test_image.py index 63801e67..31503d41 100644 --- a/tests/models/test_image.py +++ b/tests/models/test_image.py @@ -244,7 +244,7 @@ def test_image_enum_values(demo_image, provenance_base): assert os.path.exists(data_filename) try: - with pytest.raises(ValueError, match='Image type must be one of .* not foo'): + with pytest.raises(ValueError, match='ImageTypeConverter must be one of .* not foo'): demo_image.type = 'foo' session.add(demo_image) session.commit() @@ -268,7 +268,7 @@ def test_image_enum_values(demo_image, provenance_base): assert demo_image.id not in [i.id for i in images] # check the image format enum works as expected: - with 
pytest.raises(ValueError, match='Image format must be one of .* not foo'): + with pytest.raises(ValueError, match='ImageFormatConverter must be one of .* not foo'): demo_image.format = 'foo' session.add(demo_image) session.commit() diff --git a/util/classproperty.py b/util/classproperty.py new file mode 100644 index 00000000..bba9cfaa --- /dev/null +++ b/util/classproperty.py @@ -0,0 +1,5 @@ +class classproperty: + def __init__(self, func): + self.fget = func + def __get__(self, instance, owner): + return self.fget(owner)