Skip to content

Commit

Permalink
Don't parse our own types, let yaml do it for us.
Browse files Browse the repository at this point in the history
  • Loading branch information
DinoBektesevic committed Oct 10, 2023
1 parent 13f712b commit 86c648e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 47 deletions.
47 changes: 9 additions & 38 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astropy.table import Table
from numpy import result_type
from pathlib import Path
import yaml
from yaml import dump, safe_load


Expand Down Expand Up @@ -177,21 +178,10 @@ def from_table(cls, t, strict=True):
if len(t) > 1:
raise ValueError(f"More than one row in the configuration table ({len(t)}).")

config = SearchConfiguration()
for key in t.colnames:
# We use a special indicator for serializing certain types (including
# None and dict) to FITS.
if key.startswith("__NONE__"):
val = None
key = key[8:]
elif key.startswith("__DICT__"):
val = dict(t[key][0])
key = key[8:]
else:
val = t[key][0]
# guaranteed to only have 1 element due to check above
params = {col.name: safe_load(col.value[0]) for col in t.values()}
return SearchConfiguration.from_dict(params)

config.set(key, val, strict)
return config

@classmethod
def from_yaml(cls, config, strict=True):
Expand Down Expand Up @@ -228,21 +218,8 @@ def from_hdu(cls, hdu, strict=True):
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
config = SearchConfiguration()
for column in hdu.data.columns:
key = column.name
val = hdu.data[key][0]

# We use a special indicator for serializing certain types (including
# None and dict) to FITS.
if type(val) is str and val == "__NONE__":
val = None
elif key.startswith("__DICT__"):
val = ast.literal_eval(val)
key = key[8:]

config.set(key, val, strict)
return config
t = Table(hdu.data)
return SearchConfiguration.from_table(t)

@classmethod
def from_file(cls, filename, strict=True):
Expand All @@ -257,15 +234,9 @@ def to_hdu(self):
hdu : `astropy.io.fits.BinTableHDU`
The HDU with the configuration information.
"""
t = Table()
for col in self._params.keys():
val = self._params[col]
if val is None:
t[col] = ["__NONE__"]
elif type(val) is dict:
t["__DICT__" + col] = [str(val)]
else:
t[col] = [val]
serialized_dict = {key: dump(val, default_flow_style=True)
for key, val in self._params.items()}
t = Table(rows=[serialized_dict, ])
return fits.table_to_hdu(t)

def to_yaml(self):
Expand Down
18 changes: 9 additions & 9 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def test_from_dict(self):
self.assertEqual(config["num_obs"], 5)

def test_from_hdu(self):
t = Table([["Here3"], [7], ["__NONE__"]], names=("im_filepath", "num_obs", "cluster_type"))
t = Table([["Here3"], ["7"], ["null"], ["[1, 2]", ]],
names=("im_filepath", "num_obs", "cluster_type", "ang_arr"))
hdu = fits.table_to_hdu(t)

config = SearchConfiguration.from_hdu(hdu)
self.assertEqual(config["im_filepath"], "Here3")
self.assertEqual(config["num_obs"], 7)
self.assertEqual(config["ang_arr"], [1, 2])
self.assertIsNone(config["cluster_type"])

def test_to_hdu(self):
Expand All @@ -60,14 +62,12 @@ def test_to_hdu(self):
config = SearchConfiguration.from_dict(d)
hdu = config.to_hdu()

self.assertEqual(hdu.data["im_filepath"][0], "Here2")
self.assertEqual(hdu.data["num_obs"][0], 5)
self.assertEqual(hdu.data["cluster_type"][0], "__NONE__")
self.assertEqual(hdu.data["__DICT__mask_bits_dict"][0], "{'bit1': 1, 'bit2': 2}")
self.assertEqual(hdu.data["res_filepath"][0], "There")
self.assertEqual(hdu.data["ang_arr"][0][0], 1.0)
self.assertEqual(hdu.data["ang_arr"][0][1], 2.0)
self.assertEqual(hdu.data["ang_arr"][0][2], 3.0)
self.assertEqual(hdu.data["im_filepath"][0], "Here2\n...")
self.assertEqual(hdu.data["num_obs"][0], "5\n...")
self.assertEqual(hdu.data["cluster_type"][0], "null\n...")
self.assertEqual(hdu.data["mask_bits_dict"][0], "{bit1: 1, bit2: 2}")
self.assertEqual(hdu.data["res_filepath"][0], "There\n...")
self.assertEqual(hdu.data["ang_arr"][0], "[1.0, 2.0, 3.0]")

def test_to_yaml(self):
d = {
Expand Down

0 comments on commit 86c648e

Please sign in to comment.