From 13f712ba8b248aad3c9505369e1df0f7bfa38182 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 10 Oct 2023 11:08:35 -0400 Subject: [PATCH] Address remaining PR comments --- src/kbmod/configuration.py | 25 +++++++++++++++---------- src/kbmod/run_search.py | 11 ++++++----- tests/test_configuration.py | 35 +++++++++++++++++++++++++++++------ 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index efe6df5f..f92967e8 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -245,14 +245,9 @@ def from_hdu(cls, hdu, strict=True): return config @classmethod - def from_file(cls, filename, extension=0, strict=True): - if filename.endswith("yaml"): - with open(filename) as ff: - return SearchConfiguration.from_yaml(ff.read()) - elif ".fits" in filename: - with fits.open(filename) as ff: - return SearchConfiguration.from_hdu(ff[extension]) - raise ValueError("Configuration file suffix unrecognized.") + def from_file(cls, filename, strict=True): + with open(filename) as ff: + return SearchConfiguration.from_yaml(ff.read(), strict) def to_hdu(self): """Create a fits HDU with all the configuration parameters. @@ -273,7 +268,17 @@ def to_hdu(self): t[col] = [val] return fits.table_to_hdu(t) - def save_to_yaml_file(self, filename, overwrite=False): + def to_yaml(self): + """Save a configuration file with the parameters. + + Returns + ------- + result : `str` + The serialized YAML string. + """ + return dump(self._params) + + def to_file(self, filename, overwrite=False): """Save a configuration file with the parameters. Parameters @@ -288,4 +293,4 @@ def save_to_yaml_file(self, filename, overwrite=False): return with open(filename, "w") as file: - file.write(dump(self._params)) + file.write(self.to_yaml()) diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index a874c1d4..14af551a 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -44,15 +44,16 @@ class run_search: """ def __init__(self, input_parameters, config_file=None): - self.config = SearchConfiguration() - # Load parameters from a file. if config_file != None: - self.config.load_from_yaml_file(config_file) + self.config = SearchConfiguration.from_file(config_file) + else: + self.config = SearchConfiguration() # Load any additional parameters (overwriting what is there). if len(input_parameters) > 0: - self.config.set_from_dict(input_parameters) + for key, value in input_parameters.items(): + self.config.set(key, value) # Validate the configuration. self.config.validate() @@ -301,7 +302,7 @@ def run_search(self): config_filename = os.path.join( self.config["res_filepath"], f"config_{self.config['output_suffix']}.yml" ) - self.config.save_to_yaml_file(config_filename, overwrite=True) + self.config.to_file(config_filename, overwrite=True) end = time.time() print("Time taken for patch: ", end - start) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index bc2e9b9e..b1a2a0c5 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -3,6 +3,7 @@ import tempfile import unittest from pathlib import Path +from yaml import safe_load from kbmod.configuration import SearchConfiguration @@ -47,7 +48,6 @@ def test_from_hdu(self): self.assertIsNone(config["cluster_type"]) def test_to_hdu(self): - # Everything starts at its default. d = { "im_filepath": "Here2", "num_obs": 5, @@ -69,6 +69,30 @@ def test_to_hdu(self): self.assertEqual(hdu.data["ang_arr"][0][1], 2.0) self.assertEqual(hdu.data["ang_arr"][0][2], 3.0) + def test_to_yaml(self): + d = { + "im_filepath": "Here2", + "num_obs": 5, + "cluster_type": None, + "mask_bits_dict": {"bit1": 1, "bit2": 2}, + "do_clustering": False, + "res_filepath": "There", + "ang_arr": [1.0, 2.0, 3.0], + } + config = SearchConfiguration.from_dict(d) + yaml_str = config.to_yaml() + + yaml_dict = safe_load(yaml_str) + self.assertEqual(yaml_dict["im_filepath"], "Here2") + self.assertEqual(yaml_dict["num_obs"], 5) + self.assertEqual(yaml_dict["cluster_type"], None) + self.assertEqual(yaml_dict["mask_bits_dict"]["bit1"], 1) + self.assertEqual(yaml_dict["mask_bits_dict"]["bit2"], 2) + self.assertEqual(yaml_dict["res_filepath"], "There") + self.assertEqual(yaml_dict["ang_arr"][0], 1.0) + self.assertEqual(yaml_dict["ang_arr"][1], 2.0) + self.assertEqual(yaml_dict["ang_arr"][2], 3.0) + def test_save_and_load_yaml(self): config = SearchConfiguration() num_defaults = len(config._params) @@ -86,7 +110,7 @@ def test_save_and_load_yaml(self): self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path) # Correctly saves file. - config.save_to_yaml_file(file_path) + config.to_file(file_path) self.assertTrue(Path(file_path).is_file()) # Correctly loads file. @@ -125,10 +149,9 @@ def test_save_and_load_fits(self): hdu_list.writeto(file_path) # Correctly loads file. - try: - config2 = SearchConfiguration.from_file(file_path, extension=2) - except ValueError: - self.fail("load_from_fits_file() raised ValueError.") + config2 = SearchConfiguration() + with fits.open(file_path) as ff: + config2 = SearchConfiguration.from_hdu(ff[2]) self.assertEqual(len(config2._params), num_defaults) self.assertEqual(config2["im_filepath"], "Here2")