Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 10, 2023
1 parent fe12765 commit 49185e3
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 124 deletions.
167 changes: 81 additions & 86 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ast
import math

from astropy.io import fits
from astropy.table import Table
from numpy import result_type
from pathlib import Path
import pickle
from yaml import dump, safe_load


Expand Down Expand Up @@ -123,7 +123,19 @@ def set(self, param, value, strict=True):
else:
self._params[param] = value

def set_from_dict(self, d, strict=True):
def validate(self):
"""Check that the configuration has the necessary parameters.
Raises
------
Raises a ``ValueError`` if a parameter is missing.
"""
for p in self._required_params:
if self._params.get(p, None) is None:
raise ValueError(f"Required configuration parameter {p} missing.")

@classmethod
def from_dict(cls, d, strict=True):
"""Sets multiple values from a dictionary.
Parameters
Expand All @@ -138,10 +150,13 @@ def set_from_dict(self, d, strict=True):
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
config = SearchConfiguration()
for key, value in d.items():
self.set(key, value, strict)
config.set(key, value, strict)
return config

def set_from_table(self, t, strict=True):
@classmethod
def from_table(cls, t, strict=True):
"""Sets multiple values from an astropy Table with a single row and
one column for each parameter.
Expand All @@ -161,110 +176,102 @@ def set_from_table(self, t, strict=True):
"""
if len(t) > 1:
raise ValueError(f"More than one row in the configuration table ({len(t)}).")

config = SearchConfiguration()
for key in t.colnames:
# We use a special indicator for serializing certain types (including
# None and dict) to FITS.
if key.startswith("__PICKLED_"):
val = pickle.loads(t[key].value[0])
key = key[10:]
if key.startswith("__NONE__"):
val = None
key = key[8:]
elif key.startswith("__DICT__"):
val = dict(t[key][0])
key = key[8:]
else:
val = t[key][0]

self.set(key, val, strict)

def to_table(self, make_fits_safe=False):
"""Create an astropy table with all the configuration parameters.
Parameter
---------
make_fits_safe : `bool`
Override Nones and dictionaries so we can write to FITS.
config.set(key, val, strict)
return config

Returns
-------
t: `~astropy.table.Table`
The configuration table.
"""
t = Table()
for col in self._params.keys():
val = self._params[col]
t[col] = [val]

# If Table does not understand the type, pickle it.
if make_fits_safe and t[col].dtype == "O":
t.remove_column(col)
t["__PICKLED_" + col] = pickle.dumps(val)

return t

def validate(self):
"""Check that the configuration has the necessary parameters.
Raises
------
Raises a ``ValueError`` if a parameter is missing.
"""
for p in self._required_params:
if self._params.get(p, None) is None:
raise ValueError(f"Required configuration parameter {p} missing.")

def load_from_yaml_file(self, filename, strict=True):
@classmethod
def from_yaml(cls, config, strict=True):
"""Load a configuration from a YAML file.
Parameters
----------
filename : `str`
The filename, including path, of the configuration file.
config : `str` or `_io.TextIOWrapper`
The serialized YAML data.
strict : `bool`
Raise an exception on unknown parameters.
Raises
------
Raises a ``ValueError`` if the configuration file is not found.
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
if not Path(filename).is_file():
raise ValueError(f"Configuration file {filename} not found.")
yaml_params = safe_load(config)
return SearchConfiguration.from_dict(yaml_params, strict)

# Read the user-specified parameters from the file.
file_params = {}
with open(filename, "r") as config:
file_params = safe_load(config)

# Merge in the new values.
self.set_from_dict(file_params, strict)

if strict:
self.validate()

def load_from_fits_file(self, filename, layer=0, strict=True):
@classmethod
def from_hdu(cls, hdu, strict=True):
"""Load a configuration from a FITS extension file.
Parameters
----------
filename : `str`
The filename, including path, of the configuration file.
layer : `int`
The extension number to use.
hdu : `astropy.io.fits.BinTableHDU`
The HDU from which to parse the configuration information.
strict : `bool`
Raise an exception on unknown parameters.
Raises
------
Raises a ``ValueError`` if the configuration file is not found.
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
if not Path(filename).is_file():
raise ValueError(f"Configuration file {filename} not found.")
config = SearchConfiguration()
for column in hdu.data.columns:
key = column.name
val = hdu.data[key][0]

# Read the user-specified parameters from the file.
t = Table.read(filename, hdu=layer)
self.set_from_table(t)
# We use a special indicator for serializing certain types (including
# None and dict) to FITS.
if type(val) is str and val == "__NONE__":
val = None
elif key.startswith("__DICT__"):
val = ast.literal_eval(val)
key = key[8:]

config.set(key, val, strict)
return config

@classmethod
def from_file(cls, filename, extension=0, strict=True):
if filename.endswith("yaml"):
with open(filename) as ff:
return SearchConfiguration.from_yaml(ff.read())
elif ".fits" in filename:
with fits.open(filename) as ff:
return SearchConfiguration.from_hdu(ff[extension])
raise ValueError("Configuration file suffix unrecognized.")

def to_hdu(self):
"""Create a fits HDU with all the configuration parameters.
if strict:
self.validate()
Returns
-------
hdu : `astropy.io.fits.BinTableHDU`
The HDU with the configuration information.
"""
t = Table()
for col in self._params.keys():
val = self._params[col]
if val is None:
t[col] = ["__NONE__"]
elif type(val) is dict:
t["__DICT__" + col] = [str(val)]
else:
t[col] = [val]
return fits.table_to_hdu(t)

def save_to_yaml_file(self, filename, overwrite=False):
"""Save a configuration file with the parameters.
Expand All @@ -282,15 +289,3 @@ def save_to_yaml_file(self, filename, overwrite=False):

with open(filename, "w") as file:
file.write(dump(self._params))

def append_to_fits(self, filename):
"""Append the configuration table as a new extension on a FITS file
(creating a new file if needed).
Parameters
----------
filename : str
The filename, including path, of the configuration file.
"""
t = self.to_table(make_fits_safe=True)
t.write(filename, append=True)
76 changes: 38 additions & 38 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,38 +31,43 @@ def test_set(self):
# The set should fail when using unknown parameters and strict checking.
self.assertRaises(KeyError, config.set, "My_new_param", 100, strict=True)

def test_set_from_dict(self):
# Everything starts at its default.
config = SearchConfiguration()
self.assertIsNone(config["im_filepath"])
self.assertEqual(config["num_obs"], 10)

def test_from_dict(self):
d = {"im_filepath": "Here2", "num_obs": 5}
config.set_from_dict(d)
config = SearchConfiguration.from_dict(d)
self.assertEqual(config["im_filepath"], "Here2")
self.assertEqual(config["num_obs"], 5)

def test_set_from_table(self):
# Everything starts at its default.
config = SearchConfiguration()
self.assertIsNone(config["im_filepath"])
self.assertEqual(config["num_obs"], 10)
def test_from_hdu(self):
t = Table([["Here3"], [7], ["__NONE__"]], names=("im_filepath", "num_obs", "cluster_type"))
hdu = fits.table_to_hdu(t)

t = Table([["Here3"], [7]], names=("im_filepath", "num_obs"))
config.set_from_table(t)
config = SearchConfiguration.from_hdu(hdu)
self.assertEqual(config["im_filepath"], "Here3")
self.assertEqual(config["num_obs"], 7)
self.assertIsNone(config["cluster_type"])

def test_to_table(self):
def test_to_hdu(self):
# Everything starts at its default.
config = SearchConfiguration()
d = {"im_filepath": "Here2", "num_obs": 5}
config.set_from_dict(d)

t = config.to_table()
self.assertEqual(len(t), 1)
self.assertEqual(t["im_filepath"][0], "Here2")
self.assertEqual(t["num_obs"][0], 5)
d = {
"im_filepath": "Here2",
"num_obs": 5,
"cluster_type": None,
"mask_bits_dict": {"bit1": 1, "bit2": 2},
"do_clustering": False,
"res_filepath": "There",
"ang_arr": [1.0, 2.0, 3.0],
}
config = SearchConfiguration.from_dict(d)
hdu = config.to_hdu()

self.assertEqual(hdu.data["im_filepath"][0], "Here2")
self.assertEqual(hdu.data["num_obs"][0], 5)
self.assertEqual(hdu.data["cluster_type"][0], "__NONE__")
self.assertEqual(hdu.data["__DICT__mask_bits_dict"][0], "{'bit1': 1, 'bit2': 2}")
self.assertEqual(hdu.data["res_filepath"][0], "There")
self.assertEqual(hdu.data["ang_arr"][0][0], 1.0)
self.assertEqual(hdu.data["ang_arr"][0][1], 2.0)
self.assertEqual(hdu.data["ang_arr"][0][2], 3.0)

def test_save_and_load_yaml(self):
config = SearchConfiguration()
Expand All @@ -74,20 +79,19 @@ def test_save_and_load_yaml(self):
config.set("mask_grow", 5)

with tempfile.TemporaryDirectory() as dir_name:
file_path = f"{dir_name}/tmp_config_data.cfg"
file_path = f"{dir_name}/tmp_config_data.yaml"
self.assertFalse(Path(file_path).is_file())

# Unable to load non-existent file.
config2 = SearchConfiguration()
self.assertRaises(ValueError, config2.load_from_yaml_file, file_path)
self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path)

# Correctly saves file.
config.save_to_yaml_file(file_path)
self.assertTrue(Path(file_path).is_file())

# Correctly loads file.
try:
config2.load_from_yaml_file(file_path)
config2 = SearchConfiguration.from_file(file_path)
except ValueError:
self.fail("load_configuration() raised ValueError.")

Expand All @@ -112,21 +116,17 @@ def test_save_and_load_fits(self):
self.assertFalse(Path(file_path).is_file())

# Unable to load non-existent file.
config2 = SearchConfiguration()
self.assertRaises(ValueError, config2.load_from_fits_file, file_path)
self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path)

# Generate measningless data for table 0 and the configuration for table 1.
t0 = Table([[1] * 10, [2] * 10, [3] * 10], names=("A", "B", "C"))
t0.write(file_path)
self.assertTrue(Path(file_path).is_file())

# Append the FITS data to extension=1
config.append_to_fits(file_path)
self.assertTrue(Path(file_path).is_file())
# Generate empty data for the first two tables and config for the third.
hdu0 = fits.PrimaryHDU()
hdu1 = fits.ImageHDU()
hdu_list = fits.HDUList([hdu0, hdu1, config.to_hdu()])
hdu_list.writeto(file_path)

# Correctly loads file.
try:
config2.load_from_fits_file(file_path, layer=2)
config2 = SearchConfiguration.from_file(file_path, extension=2)
except ValueError:
self.fail("load_from_fits_file() raised ValueError.")

Expand Down

0 comments on commit 49185e3

Please sign in to comment.