diff --git a/src/kbmod/analysis/wcs_utils.py b/src/kbmod/wcs_utils.py similarity index 81% rename from src/kbmod/analysis/wcs_utils.py rename to src/kbmod/wcs_utils.py index 61c463f9..0a59d7ab 100644 --- a/src/kbmod/analysis/wcs_utils.py +++ b/src/kbmod/wcs_utils.py @@ -1,3 +1,5 @@ +"""A collection of utility functions for working with WCS in KBMOD.""" + import astropy.coordinates import astropy.units import astropy.wcs @@ -276,3 +278,99 @@ def _calc_actual_image_fov(wcs, ref_pixel, image_size): ) refsep2 = [skyvallist[0].separation(skyvallist[1]), skyvallist[2].separation(skyvallist[3])] return refsep2 + + +def extract_wcs_from_hdu_header(header): + """Read an WCS from the an HDU header and do basic validity checking. + + Parameters + ---------- + header : `astropy.io.fits.Header` + The header from which to read the data. + + Returns + -------- + curr_wcs : `astropy.wcs.WCS` + The WCS or None if it does not exist. + """ + # Check that we have (at minimum) the CRVAL and CRPIX keywords. + # These are necessary (but not sufficient) requirements for the WCS. + if "CRVAL1" not in header or "CRVAL2" not in header: + return None + if "CRPIX1" not in header or "CRPIX2" not in header: + return None + + curr_wcs = astropy.wcs.WCS(header) + if curr_wcs is None: + return None + if curr_wcs.naxis != 2: + return None + + return curr_wcs + + +def wcs_from_dict(data): + """Extract a WCS from a fictionary of the HDU header keys/values. + Performs very basic validity checking. + + Parameters + ---------- + data : `dict` + A dictionary containing the WCS header information. + + Returns + ------- + wcs : `astropy.wcs.WCS` + The WCS to convert. + """ + # Check that we have (at minimum) the CRVAL and CRPIX keywords. + # These are necessary (but not sufficient) requirements for the WCS. + if "CRVAL1" not in data or "CRVAL2" not in data: + return None + if "CRPIX1" not in data or "CRPIX2" not in data: + return None + + curr_wcs = astropy.wcs.WCS(data) + if curr_wcs is None: + return None + if curr_wcs.naxis != 2: + return None + + return curr_wcs + + +def append_wcs_to_hdu_header(wcs, header): + """Append the WCS fields to an existing HDU header. + + Parameters + ---------- + wcs : `astropy.wcs.WCS` + The WCS to use. + header : `astropy.io.fits.Header` + The header to which to append the data. + """ + if wcs is not None: + wcs_header = wcs.to_header() + for key in wcs_header: + header[key] = wcs_header[key] + + +def wcs_to_dict(wcs): + """Convert a WCS to a dictionary (via a FITS header). + + Parameters + ---------- + wcs : `astropy.wcs.WCS` + The WCS to convert. + + Returns + ------- + result : `dict` + A dictionary containing the WCS header information. + """ + result = {} + if wcs is not None: + wcs_header = wcs.to_header() + for key in wcs_header: + result[key] = wcs_header[key] + return result diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index d7db9b4d..6682bd66 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -11,6 +11,12 @@ from kbmod.configuration import SearchConfiguration from kbmod.search import ImageStack, LayeredImage, PSF, RawImage +from kbmod.wcs_utils import ( + append_wcs_to_hdu_header, + extract_wcs_from_hdu_header, + wcs_from_dict, + wcs_to_dict, +) class WorkUnit: @@ -25,10 +31,11 @@ class WorkUnit: config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. wcs : `astropy.wcs.WCS` - A gloabl WCS for all images in the WorkUnit. + A global WCS for all images in the WorkUnit. Only exists + if all images have been projected to same pixel space. per_image_wcs : `list` A list with one WCS for each image in the WorkUnit. Used for when - the images have not been standardized to the same pixel space. + the images have *not* been standardized to the same pixel space. """ def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None): @@ -43,6 +50,38 @@ def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None): raise ValueError("Incorrect number of WCS provided.") self.per_image_wcs = per_image_wcs + def __len__(self): + """Returns the size of the WorkUnit in number of images.""" + return self.im_stack.img_count() + + def get_wcs(self, img_num): + """Return the WCS for the a given image. Alway prioritizes + a global WCS if one exits. + + Parameters + ---------- + img_num : `int` + The number of the image. + + Returns + ------- + wcs : `astropy.wcs.WCS` + The image's WCS if one exists. Otherwise None. + + Raises + ------ + IndexError if an invalid index is given. + """ + if img_num < 0 or img_num >= self.im_stack.img_count(): + raise IndexError(f"Invalid image number {img_num}") + + if self.wcs is not None: + if self.per_image_wcs[img_num] is not None: + warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning) + return self.wcs + + return self.per_image_wcs[img_num] + @classmethod def from_fits(cls, filename): """Create a WorkUnit from a single FITS file. @@ -85,7 +124,7 @@ def from_fits(cls, filename): # since the primary header does not have an image. with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) - global_wcs = extract_wcs(hdul[0]) + global_wcs = extract_wcs_from_hdu_header(hdul[0].header) # Read the size and order information from the primary header. num_images = hdul[0].header["NUMIMG"] @@ -99,7 +138,7 @@ def from_fits(cls, filename): per_image_wcs = [] for i in range(num_images): # Extract the per-image WCS if one exists. - per_image_wcs.append(extract_wcs(hdul[f"SCI_{i}"])) + per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"SCI_{i}"].header)) # Read in science, variance, and mask layers. sci = hdu_to_raw_image(hdul[f"SCI_{i}"]) @@ -146,7 +185,17 @@ def from_dict(cls, workunit_dict): else: raise ValueError("Unrecognized type for WorkUnit config parameter.") + # Load the global WCS if one exists. + if "wcs" in workunit_dict: + if type(workunit_dict["wcs"]) is dict: + global_wcs = wcs_from_dict(workunit_dict["wcs"]) + else: + global_wcs = workunit_dict["wcs"] + else: + global_wcs = None + imgs = [] + per_image_wcs = [] for i in range(num_images): obs_time = workunit_dict["times"][i] @@ -182,8 +231,14 @@ def from_dict(cls, workunit_dict): imgs.append(LayeredImage(sci_img, var_img, msk_img, p)) + # Read a per_image_wcs if one exists. + current_wcs = workunit_dict["per_image_wcs"][i] + if type(current_wcs) is dict: + current_wcs = wcs_from_dict(current_wcs) + per_image_wcs.append(current_wcs) + im_stack = ImageStack(imgs) - return WorkUnit(im_stack=im_stack, config=config) + return WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs) @classmethod def from_yaml(cls, work_unit): @@ -231,9 +286,7 @@ def to_fits(self, filename, overwrite=False): # If the global WCS exists, append the corresponding keys. if self.wcs is not None: - wcs_header = self.wcs.to_header() - for key in wcs_header: - pri.header[key] = wcs_header[key] + append_wcs_to_hdu_header(self.wcs, pri.header) hdul.append(pri) @@ -285,12 +338,14 @@ def to_yaml(self): "width": self.im_stack.get_width(), "height": self.im_stack.get_height(), "config": self.config._params, + "wcs": wcs_to_dict(self.wcs), # Per image data "times": [], "sci_imgs": [], "var_imgs": [], "msk_imgs": [], "psfs": [], + "per_image_wcs": [], } # Fill in the per-image data. @@ -306,36 +361,9 @@ def to_yaml(self): psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim())) workunit_dict["psfs"].append(psf_array.tolist()) - return dump(workunit_dict) - + workunit_dict["per_image_wcs"].append(wcs_to_dict(self.per_image_wcs[i])) -def extract_wcs(hdu): - """Read an WCS from the header and does basic validity checking. - - Parameters - ---------- - hdu : An astropy HDU (Image or Primary) - The extension - - Returns - -------- - curr_wcs : `astropy.wcs.WCS` - The WCS or None if it does not exist. - """ - # Check that we have (at minimum) the CRVAL and CRPIX keywords. - # These are necessary (but not sufficient) requirements for the WCS. - if "CRVAL1" not in hdu.header or "CRVAL2" not in hdu.header: - return None - if "CRPIX1" not in hdu.header or "CRPIX2" not in hdu.header: - return None - - curr_wcs = WCS(hdu.header) - if curr_wcs is None: - return None - if curr_wcs.naxis != 2: - return None - - return curr_wcs + return dump(workunit_dict) def raw_image_to_hdu(img, wcs=None): @@ -357,9 +385,7 @@ def raw_image_to_hdu(img, wcs=None): # If the WCS is given, copy each entry into the header. if wcs is not None: - wcs_header = wcs.to_header() - for key in wcs_header: - hdu.header[key] = wcs_header[key] + append_wcs_to_hdu_header(wcs, hdu.header) # Set the time stamp. hdu.header["MJD"] = img.obstime diff --git a/tests/test_wcs_utils.py b/tests/test_wcs_utils.py index d105468e..0d0434a3 100644 --- a/tests/test_wcs_utils.py +++ b/tests/test_wcs_utils.py @@ -2,26 +2,75 @@ import astropy.coordinates import astropy.units +from astropy.wcs import WCS +from astropy.io import fits -import kbmod.analysis.wcs_utils +from kbmod.wcs_utils import * + + +class test_wcs_conversion(unittest.TestCase): + def setUp(self): + self.header_dict = { + "WCSAXES": 2, + "CTYPE1": "RA---TAN-SIP", + "CTYPE2": "DEC--TAN-SIP", + "CRVAL1": 200.614997245422, + "CRVAL2": -7.78878863332778, + "CRPIX1": 1033.934327, + "CRPIX2": 2043.548284, + } + self.wcs = WCS(self.header_dict) + self.header = self.wcs.to_header() + + def test_wcs_from_dict(self): + # The base dictionary is good. + self.assertIsNotNone(wcs_from_dict(self.header_dict)) + + # Remove a required word and fail. + del self.header_dict["CRVAL1"] + self.assertIsNone(wcs_from_dict(self.header_dict)) + + def test_extract_wcs_from_hdu_header(self): + # The base dictionary is good. + self.assertIsNotNone(extract_wcs_from_hdu_header(self.header)) + + # Remove a required word and fail. + del self.header["CRVAL1"] + self.assertIsNone(extract_wcs_from_hdu_header(self.header)) + + def test_wcs_to_dict(self): + new_dict = wcs_to_dict(self.wcs) + for key in self.header_dict: + self.assertTrue(key in new_dict) + self.assertAlmostEqual(new_dict[key], self.header_dict[key]) + + def test_append_wcs_to_hdu_header(self): + pri = fits.PrimaryHDU() + self.assertFalse("CRVAL1" in pri.header) + self.assertFalse("CRVAL2" in pri.header) + self.assertFalse("CRPIX1" in pri.header) + self.assertFalse("CRPIX2" in pri.header) + + append_wcs_to_hdu_header(self.wcs, pri.header) + for key in self.header_dict: + self.assertTrue(key in pri.header) + self.assertAlmostEqual(pri.header[key], self.header_dict[key]) class test_construct_wcs_tangent_projection(unittest.TestCase): def test_requires_parameters(self): with self.assertRaises(TypeError): - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection() + wcs = construct_wcs_tangent_projection() def test_only_required_parameter(self): - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection(None) + wcs = construct_wcs_tangent_projection(None) self.assertIsNotNone(wcs) def test_one_pixel(self): ref_val = astropy.coordinates.SkyCoord( ra=0 * astropy.units.deg, dec=0 * astropy.units.deg, frame="icrs" ) - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection( - ref_val, img_shape=[1, 1], image_fov=3.5 * astropy.units.deg - ) + wcs = construct_wcs_tangent_projection(ref_val, img_shape=[1, 1], image_fov=3.5 * astropy.units.deg) self.assertIsNotNone(wcs) skyval = wcs.pixel_to_world(0, 0) refsep = ref_val.separation(skyval).to(astropy.units.deg).value @@ -31,9 +80,7 @@ def test_two_pixel(self): ref_val = astropy.coordinates.SkyCoord( ra=0 * astropy.units.deg, dec=0 * astropy.units.deg, frame="icrs" ) - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection( - ref_val, img_shape=[2, 2], image_fov=3.5 * astropy.units.deg - ) + wcs = construct_wcs_tangent_projection(ref_val, img_shape=[2, 2], image_fov=3.5 * astropy.units.deg) self.assertIsNotNone(wcs) skyval = wcs.pixel_to_world(0.5, 0.5) refsep = ref_val.separation(skyval).to(astropy.units.deg).value @@ -45,11 +92,11 @@ def test_image_field_of_view(self): ref_val = astropy.coordinates.SkyCoord( ra=0 * astropy.units.deg, dec=0 * astropy.units.deg, frame="icrs" ) - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection( + wcs = construct_wcs_tangent_projection( ref_val, img_shape=[16, 16], image_fov=fov_wanted, solve_for_image_fov=True ) self.assertIsNotNone(wcs) - fov_actual = kbmod.analysis.wcs_utils.calc_actual_image_fov(wcs)[0] + fov_actual = calc_actual_image_fov(wcs)[0] self.assertAlmostEqual(fov_wanted.value, fov_actual.value, places=8) def test_image_field_of_view_wide(self): @@ -60,11 +107,11 @@ def test_image_field_of_view_wide(self): ref_val = astropy.coordinates.SkyCoord( ra=0 * astropy.units.deg, dec=0 * astropy.units.deg, frame="icrs" ) - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection( + wcs = construct_wcs_tangent_projection( ref_val, img_shape=[64, 32], image_fov=fov_wanted[0], solve_for_image_fov=True ) self.assertIsNotNone(wcs) - fov_actual = kbmod.analysis.wcs_utils.calc_actual_image_fov(wcs) + fov_actual = calc_actual_image_fov(wcs) self.assertAlmostEqual(fov_wanted[0].value, fov_actual[0].value, places=8) self.assertAlmostEqual(fov_wanted[1].value, fov_actual[1].value, places=8) @@ -76,10 +123,14 @@ def test_image_field_of_view_tall(self): ref_val = astropy.coordinates.SkyCoord( ra=0 * astropy.units.deg, dec=0 * astropy.units.deg, frame="icrs" ) - wcs = kbmod.analysis.wcs_utils.construct_wcs_tangent_projection( + wcs = construct_wcs_tangent_projection( ref_val, img_shape=[32, 64], image_fov=fov_wanted[0], solve_for_image_fov=True ) self.assertIsNotNone(wcs) - fov_actual = kbmod.analysis.wcs_utils.calc_actual_image_fov(wcs) + fov_actual = calc_actual_image_fov(wcs) self.assertAlmostEqual(fov_wanted[0].value, fov_actual[0].value, places=8) self.assertAlmostEqual(fov_wanted[1].value, fov_actual[1].value, places=8) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index 51ba0121..b1c4b7fd 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -5,6 +5,7 @@ from pathlib import Path import tempfile import unittest +import warnings from kbmod.configuration import SearchConfiguration import kbmod.search as kb @@ -61,6 +62,9 @@ def setUp(self): "CUNIT2A": "PIXEL ", } self.wcs = WCS(header_dict) + self.per_image_wcs = per_image_wcs = [ + (self.wcs if i % 2 == 0 else None) for i in range(self.num_images) + ] def test_create(self): work = WorkUnit(self.im_stack, self.config) @@ -68,11 +72,16 @@ def test_create(self): self.assertEqual(work.config["im_filepath"], "Here") self.assertEqual(work.config["num_obs"], 5) self.assertIsNone(work.wcs) + self.assertEqual(len(work), self.num_images) + for i in range(self.num_images): + self.assertIsNone(work.get_wcs(i)) # Create with a global WCS work2 = WorkUnit(self.im_stack, self.config, self.wcs) self.assertEqual(work2.im_stack.img_count(), 5) self.assertIsNotNone(work2.wcs) + for i in range(self.num_images): + self.assertIsNotNone(work2.get_wcs(i)) # Mismatch with the number of WCS. self.assertRaises( @@ -84,6 +93,22 @@ def test_create(self): [self.wcs, self.wcs, self.wcs], ) + # Create with per-image WCS + per_image_wcs = [self.wcs] * self.num_images + work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=per_image_wcs) + self.assertIsNone(work3.wcs) + for i in range(self.num_images): + self.assertIsNotNone(work3.get_wcs(i)) + + # Create with both global and per-image WCS. Check that a get triggers a warning. + work4 = WorkUnit(self.im_stack, self.config, self.wcs, per_image_wcs) + self.assertIsNotNone(work4.wcs) + with warnings.catch_warnings(record=True) as wrn: + warnings.simplefilter("always") + current = work4.get_wcs(0) + self.assertTrue("Both a global and per-image WCS given." in str(wrn[-1].message)) + self.assertIsNotNone(current) + def test_create_from_dict(self): for use_python_types in [True, False]: if use_python_types: @@ -97,6 +122,8 @@ def test_create_from_dict(self): "var_imgs": [self.images[i].get_variance().image for i in range(self.num_images)], "msk_imgs": [self.images[i].get_mask().image for i in range(self.num_images)], "psfs": [np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim())) for p in self.p], + "per_image_wcs": self.per_image_wcs, + "wcs": self.wcs, } else: work_unit_dict = { @@ -109,6 +136,8 @@ def test_create_from_dict(self): "var_imgs": [self.images[i].get_variance() for i in range(self.num_images)], "msk_imgs": [self.images[i].get_mask() for i in range(self.num_images)], "psfs": self.p, + "per_image_wcs": self.per_image_wcs, + "wcs": self.wcs, } with self.subTest(i=use_python_types): @@ -116,6 +145,7 @@ def test_create_from_dict(self): self.assertEqual(work.im_stack.img_count(), self.num_images) self.assertEqual(work.im_stack.get_width(), self.width) self.assertEqual(work.im_stack.get_height(), self.height) + self.assertIsNotNone(work.wcs) for i in range(self.num_images): layered1 = work.im_stack.get_single_image(i) layered2 = self.im_stack.get_single_image(i) @@ -124,6 +154,7 @@ def test_create_from_dict(self): self.assertTrue(layered1.get_variance().l2_allclose(layered2.get_variance(), 0.01)) self.assertTrue(layered1.get_mask().l2_allclose(layered2.get_mask(), 0.01)) self.assertEqual(layered1.get_obstime(), layered2.get_obstime()) + self.assertEqual(work.per_image_wcs[i] is None, i % 2 == 1) self.assertTrue(type(work.config) is SearchConfiguration) self.assertEqual(work.config["im_filepath"], "Here") @@ -137,10 +168,8 @@ def test_save_and_load_fits(self): # Unable to load non-existent file. self.assertRaises(ValueError, WorkUnit.from_fits, file_path) - # Write out the existing WorkUnit with a per image wcs for the - # even entries. - per_image_wcs = [(self.wcs if i % 2 == 0 else None) for i in range(self.num_images)] - work = WorkUnit(self.im_stack, self.config, self.wcs, per_image_wcs) + # Write out the existing WorkUnit with a per image wcs for the even entries. + work = WorkUnit(self.im_stack, self.config, self.wcs, self.per_image_wcs) work.to_fits(file_path) self.assertTrue(Path(file_path).is_file()) @@ -189,13 +218,14 @@ def test_save_and_load_fits(self): self.assertIsNone(work2.config["repeated_flag_keys"]) def test_to_from_yaml(self): - work = WorkUnit(self.im_stack, self.config) + work = WorkUnit(self.im_stack, self.config, self.wcs, self.per_image_wcs) yaml_str = work.to_yaml() work2 = WorkUnit.from_yaml(yaml_str) self.assertEqual(work2.im_stack.img_count(), self.num_images) self.assertEqual(work2.im_stack.get_width(), self.width) self.assertEqual(work2.im_stack.get_height(), self.height) + self.assertIsNotNone(work2.wcs) for i in range(self.num_images): layered1 = work2.im_stack.get_single_image(i) layered2 = self.im_stack.get_single_image(i) @@ -204,6 +234,7 @@ def test_to_from_yaml(self): self.assertTrue(layered1.get_variance().l2_allclose(layered2.get_variance(), 0.01)) self.assertTrue(layered1.get_mask().l2_allclose(layered2.get_mask(), 0.01)) self.assertAlmostEqual(layered1.get_obstime(), layered2.get_obstime()) + self.assertEqual(work2.per_image_wcs[i] is None, i % 2 == 1) # Check that we read in the configuration values correctly. self.assertEqual(work2.config["im_filepath"], "Here")