Skip to content

Commit

Permalink
Add WCS to YAML
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Dec 14, 2023
1 parent ab17022 commit 50a5551
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 60 deletions.
98 changes: 98 additions & 0 deletions src/kbmod/analysis/wcs_utils.py → src/kbmod/wcs_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""A collection of utility functions for working with WCS in KBMOD."""

import astropy.coordinates
import astropy.units
import astropy.wcs
Expand Down Expand Up @@ -276,3 +278,99 @@ def _calc_actual_image_fov(wcs, ref_pixel, image_size):
)
refsep2 = [skyvallist[0].separation(skyvallist[1]), skyvallist[2].separation(skyvallist[3])]
return refsep2


def extract_wcs_from_hdu_header(header):
"""Read an WCS from the an HDU header and do basic validity checking.
Parameters
----------
header : `astropy.io.fits.Header`
The header from which to read the data.
Returns
--------
curr_wcs : `astropy.wcs.WCS`
The WCS or None if it does not exist.
"""
# Check that we have (at minimum) the CRVAL and CRPIX keywords.
# These are necessary (but not sufficient) requirements for the WCS.
if "CRVAL1" not in header or "CRVAL2" not in header:
return None
if "CRPIX1" not in header or "CRPIX2" not in header:
return None

curr_wcs = astropy.wcs.WCS(header)
if curr_wcs is None:
return None
if curr_wcs.naxis != 2:
return None

return curr_wcs


def wcs_from_dict(data):
"""Extract a WCS from a fictionary of the HDU header keys/values.
Performs very basic validity checking.
Parameters
----------
data : `dict`
A dictionary containing the WCS header information.
Returns
-------
wcs : `astropy.wcs.WCS`
The WCS to convert.
"""
# Check that we have (at minimum) the CRVAL and CRPIX keywords.
# These are necessary (but not sufficient) requirements for the WCS.
if "CRVAL1" not in data or "CRVAL2" not in data:
return None
if "CRPIX1" not in data or "CRPIX2" not in data:
return None

curr_wcs = astropy.wcs.WCS(data)
if curr_wcs is None:
return None
if curr_wcs.naxis != 2:
return None

return curr_wcs


def append_wcs_to_hdu_header(wcs, header):
"""Append the WCS fields to an existing HDU header.
Parameters
----------
wcs : `astropy.wcs.WCS`
The WCS to use.
header : `astropy.io.fits.Header`
The header to which to append the data.
"""
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
header[key] = wcs_header[key]


def wcs_to_dict(wcs):
"""Convert a WCS to a dictionary (via a FITS header).
Parameters
----------
wcs : `astropy.wcs.WCS`
The WCS to convert.
Returns
-------
result : `dict`
A dictionary containing the WCS header information.
"""
result = {}
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
result[key] = wcs_header[key]
return result
106 changes: 66 additions & 40 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@

from kbmod.configuration import SearchConfiguration
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage
from kbmod.wcs_utils import (
append_wcs_to_hdu_header,
extract_wcs_from_hdu_header,
wcs_from_dict,
wcs_to_dict,
)


class WorkUnit:
Expand All @@ -25,10 +31,11 @@ class WorkUnit:
config : `kbmod.configuration.SearchConfiguration`
The configuration for the KBMOD run.
wcs : `astropy.wcs.WCS`
A gloabl WCS for all images in the WorkUnit.
A global WCS for all images in the WorkUnit. Only exists
if all images have been projected to same pixel space.
per_image_wcs : `list`
A list with one WCS for each image in the WorkUnit. Used for when
the images have not been standardized to the same pixel space.
the images have *not* been standardized to the same pixel space.
"""

def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None):
Expand All @@ -43,6 +50,38 @@ def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None):
raise ValueError("Incorrect number of WCS provided.")
self.per_image_wcs = per_image_wcs

def __len__(self):
"""Returns the size of the WorkUnit in number of images."""
return self.im_stack.img_count()

def get_wcs(self, img_num):
"""Return the WCS for the a given image. Alway prioritizes
a global WCS if one exits.
Parameters
----------
img_num : `int`
The number of the image.
Returns
-------
wcs : `astropy.wcs.WCS`
The image's WCS if one exists. Otherwise None.
Raises
------
IndexError if an invalid index is given.
"""
if img_num < 0 or img_num >= self.im_stack.img_count():
raise IndexError(f"Invalid image number {img_num}")

if self.wcs is not None:
if self.per_image_wcs[img_num] is not None:
warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning)
return self.wcs

return self.per_image_wcs[img_num]

@classmethod
def from_fits(cls, filename):
"""Create a WorkUnit from a single FITS file.
Expand Down Expand Up @@ -85,7 +124,7 @@ def from_fits(cls, filename):
# since the primary header does not have an image.
with warnings.catch_warnings():
warnings.simplefilter("ignore", AstropyWarning)
global_wcs = extract_wcs(hdul[0])
global_wcs = extract_wcs_from_hdu_header(hdul[0].header)

# Read the size and order information from the primary header.
num_images = hdul[0].header["NUMIMG"]
Expand All @@ -99,7 +138,7 @@ def from_fits(cls, filename):
per_image_wcs = []
for i in range(num_images):
# Extract the per-image WCS if one exists.
per_image_wcs.append(extract_wcs(hdul[f"SCI_{i}"]))
per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"SCI_{i}"].header))

# Read in science, variance, and mask layers.
sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
Expand Down Expand Up @@ -146,7 +185,17 @@ def from_dict(cls, workunit_dict):
else:
raise ValueError("Unrecognized type for WorkUnit config parameter.")

# Load the global WCS if one exists.
if "wcs" in workunit_dict:
if type(workunit_dict["wcs"]) is dict:
global_wcs = wcs_from_dict(workunit_dict["wcs"])
else:
global_wcs = workunit_dict["wcs"]
else:
global_wcs = None

imgs = []
per_image_wcs = []
for i in range(num_images):
obs_time = workunit_dict["times"][i]

Expand Down Expand Up @@ -182,8 +231,14 @@ def from_dict(cls, workunit_dict):

imgs.append(LayeredImage(sci_img, var_img, msk_img, p))

# Read a per_image_wcs if one exists.
current_wcs = workunit_dict["per_image_wcs"][i]
if type(current_wcs) is dict:
current_wcs = wcs_from_dict(current_wcs)
per_image_wcs.append(current_wcs)

im_stack = ImageStack(imgs)
return WorkUnit(im_stack=im_stack, config=config)
return WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs)

@classmethod
def from_yaml(cls, work_unit):
Expand Down Expand Up @@ -231,9 +286,7 @@ def to_fits(self, filename, overwrite=False):

# If the global WCS exists, append the corresponding keys.
if self.wcs is not None:
wcs_header = self.wcs.to_header()
for key in wcs_header:
pri.header[key] = wcs_header[key]
append_wcs_to_hdu_header(self.wcs, pri.header)

hdul.append(pri)

Expand Down Expand Up @@ -285,12 +338,14 @@ def to_yaml(self):
"width": self.im_stack.get_width(),
"height": self.im_stack.get_height(),
"config": self.config._params,
"wcs": wcs_to_dict(self.wcs),
# Per image data
"times": [],
"sci_imgs": [],
"var_imgs": [],
"msk_imgs": [],
"psfs": [],
"per_image_wcs": [],
}

# Fill in the per-image data.
Expand All @@ -306,36 +361,9 @@ def to_yaml(self):
psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
workunit_dict["psfs"].append(psf_array.tolist())

return dump(workunit_dict)

workunit_dict["per_image_wcs"].append(wcs_to_dict(self.per_image_wcs[i]))

def extract_wcs(hdu):
"""Read an WCS from the header and does basic validity checking.
Parameters
----------
hdu : An astropy HDU (Image or Primary)
The extension
Returns
--------
curr_wcs : `astropy.wcs.WCS`
The WCS or None if it does not exist.
"""
# Check that we have (at minimum) the CRVAL and CRPIX keywords.
# These are necessary (but not sufficient) requirements for the WCS.
if "CRVAL1" not in hdu.header or "CRVAL2" not in hdu.header:
return None
if "CRPIX1" not in hdu.header or "CRPIX2" not in hdu.header:
return None

curr_wcs = WCS(hdu.header)
if curr_wcs is None:
return None
if curr_wcs.naxis != 2:
return None

return curr_wcs
return dump(workunit_dict)


def raw_image_to_hdu(img, wcs=None):
Expand All @@ -357,9 +385,7 @@ def raw_image_to_hdu(img, wcs=None):

# If the WCS is given, copy each entry into the header.
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
hdu.header[key] = wcs_header[key]
append_wcs_to_hdu_header(wcs, hdu.header)

# Set the time stamp.
hdu.header["MJD"] = img.obstime
Expand Down
Loading

0 comments on commit 50a5551

Please sign in to comment.