Skip to content

Commit

Permalink
Address remaining PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 10, 2023
1 parent 49185e3 commit 13f712b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 21 deletions.
25 changes: 15 additions & 10 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,9 @@ def from_hdu(cls, hdu, strict=True):
return config

@classmethod
def from_file(cls, filename, extension=0, strict=True):
if filename.endswith("yaml"):
with open(filename) as ff:
return SearchConfiguration.from_yaml(ff.read())
elif ".fits" in filename:
with fits.open(filename) as ff:
return SearchConfiguration.from_hdu(ff[extension])
raise ValueError("Configuration file suffix unrecognized.")
def from_file(cls, filename, strict=True):
with open(filename) as ff:
return SearchConfiguration.from_yaml(ff.read(), strict)

def to_hdu(self):
"""Create a fits HDU with all the configuration parameters.
Expand All @@ -273,7 +268,17 @@ def to_hdu(self):
t[col] = [val]
return fits.table_to_hdu(t)

def save_to_yaml_file(self, filename, overwrite=False):
def to_yaml(self):
"""Save a configuration file with the parameters.
Returns
-------
result : `str`
The serialized YAML string.
"""
return dump(self._params)

def to_file(self, filename, overwrite=False):
"""Save a configuration file with the parameters.
Parameters
Expand All @@ -288,4 +293,4 @@ def save_to_yaml_file(self, filename, overwrite=False):
return

with open(filename, "w") as file:
file.write(dump(self._params))
file.write(self.to_yaml())
11 changes: 6 additions & 5 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ class run_search:
"""

def __init__(self, input_parameters, config_file=None):
self.config = SearchConfiguration()

# Load parameters from a file.
if config_file != None:
self.config.load_from_yaml_file(config_file)
self.config = SearchConfiguration.from_file(config_file)
else:
self.config = SearchConfiguration()

# Load any additional parameters (overwriting what is there).
if len(input_parameters) > 0:
self.config.set_from_dict(input_parameters)
for key, value in input_parameters.items():
self.config.set(key, value)

# Validate the configuration.
self.config.validate()
Expand Down Expand Up @@ -301,7 +302,7 @@ def run_search(self):
config_filename = os.path.join(
self.config["res_filepath"], f"config_{self.config['output_suffix']}.yml"
)
self.config.save_to_yaml_file(config_filename, overwrite=True)
self.config.to_file(config_filename, overwrite=True)

end = time.time()
print("Time taken for patch: ", end - start)
Expand Down
35 changes: 29 additions & 6 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import unittest
from pathlib import Path
from yaml import safe_load

from kbmod.configuration import SearchConfiguration

Expand Down Expand Up @@ -47,7 +48,6 @@ def test_from_hdu(self):
self.assertIsNone(config["cluster_type"])

def test_to_hdu(self):
# Everything starts at its default.
d = {
"im_filepath": "Here2",
"num_obs": 5,
Expand All @@ -69,6 +69,30 @@ def test_to_hdu(self):
self.assertEqual(hdu.data["ang_arr"][0][1], 2.0)
self.assertEqual(hdu.data["ang_arr"][0][2], 3.0)

def test_to_yaml(self):
d = {
"im_filepath": "Here2",
"num_obs": 5,
"cluster_type": None,
"mask_bits_dict": {"bit1": 1, "bit2": 2},
"do_clustering": False,
"res_filepath": "There",
"ang_arr": [1.0, 2.0, 3.0],
}
config = SearchConfiguration.from_dict(d)
yaml_str = config.to_yaml()

yaml_dict = safe_load(yaml_str)
self.assertEqual(yaml_dict["im_filepath"], "Here2")
self.assertEqual(yaml_dict["num_obs"], 5)
self.assertEqual(yaml_dict["cluster_type"], None)
self.assertEqual(yaml_dict["mask_bits_dict"]["bit1"], 1)
self.assertEqual(yaml_dict["mask_bits_dict"]["bit2"], 2)
self.assertEqual(yaml_dict["res_filepath"], "There")
self.assertEqual(yaml_dict["ang_arr"][0], 1.0)
self.assertEqual(yaml_dict["ang_arr"][1], 2.0)
self.assertEqual(yaml_dict["ang_arr"][2], 3.0)

def test_save_and_load_yaml(self):
config = SearchConfiguration()
num_defaults = len(config._params)
Expand All @@ -86,7 +110,7 @@ def test_save_and_load_yaml(self):
self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path)

# Correctly saves file.
config.save_to_yaml_file(file_path)
config.to_file(file_path)
self.assertTrue(Path(file_path).is_file())

# Correctly loads file.
Expand Down Expand Up @@ -125,10 +149,9 @@ def test_save_and_load_fits(self):
hdu_list.writeto(file_path)

# Correctly loads file.
try:
config2 = SearchConfiguration.from_file(file_path, extension=2)
except ValueError:
self.fail("load_from_fits_file() raised ValueError.")
config2 = SearchConfiguration()
with fits.open(file_path) as ff:
config2 = SearchConfiguration.from_hdu(ff[2])

self.assertEqual(len(config2._params), num_defaults)
self.assertEqual(config2["im_filepath"], "Here2")
Expand Down

0 comments on commit 13f712b

Please sign in to comment.