diff --git a/src/kbmod/data_interface.py b/src/kbmod/data_interface.py index 6de27cef..416965a8 100644 --- a/src/kbmod/data_interface.py +++ b/src/kbmod/data_interface.py @@ -8,6 +8,7 @@ from kbmod.configuration import SearchConfiguration from kbmod.file_utils import * +from kbmod.work_unit import WorkUnit def load_input_from_individual_files( @@ -137,7 +138,7 @@ def load_input_from_individual_files( def load_input_from_config(config, verbose=False): - """This function loads images and ingests them into an ImageStack. + """This function loads images and ingests them into a WorkUnit. Parameters ---------- @@ -148,14 +149,10 @@ def load_input_from_config(config, verbose=False): Returns ------- - stack : `kbmod.ImageStack` - The stack of images loaded. - wcs_list : `list` - A list of `astropy.wcs.WCS` objects for each image. - visit_times : `list` - A list of MJD times. + result : `kbmod.WorkUnit` + The input data as a ``WorkUnit``. """ - return load_input_from_individual_files( + stack, wcs_list, _ = load_input_from_individual_files( config["im_filepath"], config["time_file"], config["psf_file"], @@ -163,3 +160,56 @@ def load_input_from_config(config, verbose=False): kb.PSF(config["psf_val"]), # Default PSF. verbose=verbose, ) + return WorkUnit(stack, config, None, wcs_list) + + +def load_input_from_file(filename, overrides=None): + """Build a WorkUnit from a single filename which could point to a WorkUnit + or configuration file. + + Parameters + ---------- + filename : `str` + The path and file name of the data to load. + overrides : `dict`, optional + A dictionary of configuration parameters to override. For testing. + + Returns + ------- + result : `kbmod.WorkUnit` + The input data as a ``WorkUnit``. + + Raises + ------ + ``ValueError`` if unable to read the data. + """ + path_var = Path(filename) + if not path_var.is_file(): + raise ValueError(f"File {filename} not found.") + + work = None + + path_suffix = path_var.suffix + if path_suffix == ".yml" or path_suffix == ".yaml": + # Try loading as a WorkUnit first. + with open(filename) as ff: + work = WorkUnit.from_yaml(ff.read(), strict=False) + + # If that load did not work, try loading the file as a configuration + # and then using that to load the data files. + if work is None: + config = SearchConfiguration.from_file(filename, strict=False) + if overrides is not None: + config.set_multiple(overrides) + if config["im_filepath"] is not None: + return load_input_from_config(config) + elif ".fits" in filename: + work = WorkUnit.from_fits(filename) + + # None of the load paths worked. + if work is None: + raise ValueError(f"Could not interprete {filename}.") + + if overrides is not None: + work.config.set_multiple(overrides) + return work diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 42d6cff2..ef4d25e0 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -13,7 +13,7 @@ import kbmod.search as kb from .analysis_utils import PostProcess -from .data_interface import load_input_from_config +from .data_interface import load_input_from_config, load_input_from_file from .configuration import SearchConfiguration from .masking import ( BitVectorMasker, @@ -255,81 +255,69 @@ def run_search(self, config, stack): return keep - def run_search_from_config(self, config): - """Run a KBMOD search from a SearchConfiguration object. + def run_search_from_work_unit(self, work): + """Run a KBMOD search from a WorkUnit object. Parameters ---------- - config : `SearchConfiguration` or `dict` - The configuration object with all the information for the run. + work : `WorkUnit` + The input data and configuration. Returns ------- keep : ResultList The results. """ - if type(config) is dict: - config = SearchConfiguration.from_dict(config) - - # Load the image files. - stack, wcs_list, _ = load_input_from_config(config, verbose=config["debug"]) - - # Compute the suggested search angle from the images. This is a 12 arcsecond - # segment parallel to the ecliptic is seen under from the image origin. - if config["average_angle"] == None: - center_pixel = (stack.get_width() / 2, stack.get_height() / 2) - config.set("average_angle", self._calc_suggested_angle(wcs_list[0], center_pixel)) - - return self.run_search(config, stack) + # Set the average angle if it is not set. + if work.config["average_angle"] is None: + center_pixel = (work.im_stack.get_width() / 2, work.im_stack.get_height() / 2) + if work.get_wcs(0) is not None: + work.config.set("average_angle", self._calc_suggested_angle(work.get_wcs(0), center_pixel)) + else: + print("WARNING: average_angle is unset and no WCS provided. Using 0.0.") + work.config.set("average_angle", 0.0) + + # Run the search. + return self.run_search(work.config, work.im_stack) - def run_search_from_config_file(self, filename, overrides=None): - """Run a KBMOD search from a configuration file. + def run_search_from_config(self, config): + """Run a KBMOD search from a SearchConfiguration object + (or corresponding dictionary). Parameters ---------- - filename : `str` - The name of the configuration file. - overrides : `dict`, optional - A dictionary of configuration parameters to override. + config : `SearchConfiguration` or `dict` + The configuration object with all the information for the run. Returns ------- keep : ResultList The results. """ - config = SearchConfiguration.from_file(filename) - if overrides is not None: - config.set_multiple(overrides) + if type(config) is dict: + config = SearchConfiguration.from_dict(config) - return self.run_search_from_config(config) + # Load the data. + work = load_input_from_config(config) + return self.run_search_from_work_unit(work) - def run_search_from_work_unit_file(self, filename, overrides=None): - """Run a KBMOD search from a WorkUnit file. + def run_search_from_file(self, filename, overrides=None): + """Run a KBMOD search from a configuration or WorkUnit file. Parameters ---------- filename : `str` - The name of the WorkUnit file. + The name of the input file. overrides : `dict`, optional - A dictionary of configuration parameters to override. + A dictionary of configuration parameters to override. For testing. Returns ------- keep : ResultList The results. """ - work = WorkUnit.from_fits(filename) - - if overrides is not None: - work.config.set_multiple(overrides) - - if work.config["average_angle"] == None: - print("WARNING: average_angle is unset. WorkUnit currently uses a default of 0.0") - - # TODO: Support the correct setting of the angle. - work.config.set("average_angle", 0.0) - - return self.run_search(work.config, work.im_stack) + work = load_input_from_file(filename, overrides) + return self.run_search_from_work_unit(work) def _count_known_matches(self, result_list, search): """Look up the known objects that overlap the images and count how many diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 6682bd66..0bda4549 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -241,19 +241,37 @@ def from_dict(cls, workunit_dict): return WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs) @classmethod - def from_yaml(cls, work_unit): + def from_yaml(cls, work_unit, strict=False): """Load a configuration from a YAML string. Parameters ---------- work_unit : `str` or `_io.TextIOWrapper` The serialized YAML data. + strict : `bool` + Raise an error if the file is not a WorkUnit. + + Returns + ------- + result : `WorkUnit` or `None` + Returns the extracted WorkUnit. If the file did not contain a WorkUnit and + strict=False the function will return None. Raises ------ Raises a ``ValueError`` for any invalid parameters. """ yaml_dict = safe_load(work_unit) + + # Check if this a WorkUnit yaml file by checking it has the required fields. + required_fields = ["config", "height", "num_images", "sci_imgs", "times", "var_imgs", "width"] + for name in required_fields: + if name not in yaml_dict: + if strict: + raise ValueError(f"Missing required field {name}") + else: + return None + return WorkUnit.from_dict(yaml_dict) def to_fits(self, filename, overwrite=False): diff --git a/tests/test_data_interface.py b/tests/test_data_interface.py index 34f4edc4..332c2601 100644 --- a/tests/test_data_interface.py +++ b/tests/test_data_interface.py @@ -1,11 +1,18 @@ +from astropy.wcs import WCS +import os +import tempfile import unittest +from yaml import dump from kbmod.configuration import SearchConfiguration from kbmod.data_interface import ( - load_input_from_individual_files, load_input_from_config, + load_input_from_file, + load_input_from_individual_files, ) +from kbmod.fake_data_creator import FakeDataSet from kbmod.search import * +from kbmod.work_unit import WorkUnit from utils.utils_for_tests import get_absolute_data_path @@ -60,19 +67,86 @@ def test_file_load_config(self): config.set("psf_file", get_absolute_data_path("fake_psfs.dat")), config.set("psf_val", 1.0) - stack, wcs_list, mjds = load_input_from_config(config, verbose=False) - self.assertEqual(stack.img_count(), 4) + worku = load_input_from_config(config, verbose=False) # Check that each image loaded corrected. true_times = [57130.2, 57130.21, 57130.22, 57162.0] psfs_std = [1.0, 1.0, 1.3, 1.0] - for i in range(stack.img_count()): - img = stack.get_single_image(i) + for i in range(worku.im_stack.img_count()): + img = worku.im_stack.get_single_image(i) self.assertEqual(img.get_width(), 64) self.assertEqual(img.get_height(), 64) self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std()) + # Try writing the configuration to a YAML file and loading. + with tempfile.TemporaryDirectory() as dir_name: + yaml_file_path = os.path.join(dir_name, "test_config.yml") + + with self.assertRaises(ValueError): + work_fits = load_input_from_file(yaml_file_path) + + config.to_file(yaml_file_path) + + work_yml = load_input_from_file(yaml_file_path) + self.assertIsNotNone(work_yml) + self.assertEqual(work_yml.im_stack.img_count(), 4) + + def test_file_load_workunit(self): + # Create a fake WCS + fake_wcs = WCS( + { + "WCSAXES": 2, + "CTYPE1": "RA---TAN-SIP", + "CTYPE2": "DEC--TAN-SIP", + "CRVAL1": 200.614997245422, + "CRVAL2": -7.78878863332778, + "CRPIX1": 1033.934327, + "CRPIX2": 2043.548284, + "CTYPE1A": "LINEAR ", + "CTYPE2A": "LINEAR ", + "CUNIT1A": "PIXEL ", + "CUNIT2A": "PIXEL ", + } + ) + fake_config = SearchConfiguration() + fake_data = FakeDataSet(64, 64, 11, obs_per_day=10, use_seed=True) + work = WorkUnit(fake_data.stack, fake_config, fake_wcs, None) + + with tempfile.TemporaryDirectory() as dir_name: + # Save and load as FITS + fits_file_path = os.path.join(dir_name, "test_workunit.fits") + + with self.assertRaises(ValueError): + work_fits = load_input_from_file(fits_file_path) + + work.to_fits(fits_file_path) + + work_fits = load_input_from_file(fits_file_path) + self.assertIsNotNone(work_fits) + self.assertEqual(work_fits.im_stack.img_count(), 11) + + # Save and load as YAML + yaml_file_path = os.path.join(dir_name, "test_workunit.yml") + with open(yaml_file_path, "w") as file: + file.write(work.to_yaml()) + + work_yml = load_input_from_file(yaml_file_path) + self.assertIsNotNone(work_yml) + self.assertEqual(work_yml.im_stack.img_count(), 11) + + def test_file_load_invalid(self): + # Create a YAML file that is neither a configuration nor a WorkUnit. + yaml_str = dump({"Field1": 1, "Field2": False}) + + with tempfile.TemporaryDirectory() as dir_name: + yaml_file_path = os.path.join(dir_name, "test_invalid.yml") + with open(yaml_file_path, "w") as file: + file.write(yaml_str) + + with self.assertRaises(ValueError): + work = load_input_from_file(yaml_file_path) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 200c6de9..16a130bc 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -71,7 +71,7 @@ def test_demo_config_file(self): im_filepath = get_absolute_demo_data_path("demo") config_file = get_absolute_demo_data_path("demo_config.yml") rs = SearchRunner() - keep = rs.run_search_from_config_file( + keep = rs.run_search_from_file( config_file, overrides={"im_filepath": im_filepath}, ) @@ -120,7 +120,7 @@ def test_e2e_work_unit(self): work.to_fits(file_path) rs = SearchRunner() - keep = rs.run_search_from_work_unit_file(file_path) + keep = rs.run_search_from_file(file_path) self.assertGreaterEqual(keep.num_results(), 1)