Support serializing WorkUnit to YAML string #391

Merged · 10 commits · Dec 14, 2023
98 changes: 98 additions & 0 deletions src/kbmod/analysis/wcs_utils.py → src/kbmod/wcs_utils.py
@@ -1,3 +1,5 @@
"""A collection of utility functions for working with WCS in KBMOD."""

import astropy.coordinates
import astropy.units
import astropy.wcs
@@ -276,3 +278,99 @@ def _calc_actual_image_fov(wcs, ref_pixel, image_size):
)
refsep2 = [skyvallist[0].separation(skyvallist[1]), skyvallist[2].separation(skyvallist[3])]
return refsep2


def extract_wcs_from_hdu_header(header):
"""Read an WCS from the an HDU header and do basic validity checking.

Parameters
----------
header : `astropy.io.fits.Header`
The header from which to read the data.

Returns
-------
curr_wcs : `astropy.wcs.WCS`
The WCS or None if it does not exist.
"""
# Check that we have (at minimum) the CRVAL and CRPIX keywords.
# These are necessary (but not sufficient) requirements for the WCS.
if "CRVAL1" not in header or "CRVAL2" not in header:
return None
if "CRPIX1" not in header or "CRPIX2" not in header:
return None

curr_wcs = astropy.wcs.WCS(header)
if curr_wcs is None:
return None
if curr_wcs.naxis != 2:
return None

return curr_wcs
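
As a quick illustration of how this helper could be used (a sketch, not part of the diff; the file name is hypothetical), a validated WCS can be pulled directly from an opened HDU list:

import astropy.io.fits
from kbmod.wcs_utils import extract_wcs_from_hdu_header

# Hypothetical file; any FITS image whose header carries CRVAL1/2 and CRPIX1/2 works.
with astropy.io.fits.open("example_image.fits") as hdul:
    wcs = extract_wcs_from_hdu_header(hdul[0].header)
    if wcs is None:
        print("Header did not contain a usable 2-dimensional WCS.")
    else:
        print(wcs.pixel_to_world(0.0, 0.0))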


def wcs_from_dict(data):
"""Extract a WCS from a fictionary of the HDU header keys/values.
Performs very basic validity checking.

Parameters
----------
data : `dict`
A dictionary containing the WCS header information.

Returns
-------
wcs : `astropy.wcs.WCS`
The WCS extracted from the dictionary, or None if it cannot be constructed.
"""
# Check that we have (at minimum) the CRVAL and CRPIX keywords.
# These are necessary (but not sufficient) requirements for the WCS.
if "CRVAL1" not in data or "CRVAL2" not in data:
return None
if "CRPIX1" not in data or "CRPIX2" not in data:
return None

curr_wcs = astropy.wcs.WCS(data)
if curr_wcs is None:
return None
if curr_wcs.naxis != 2:
return None

return curr_wcs


def append_wcs_to_hdu_header(wcs, header):
"""Append the WCS fields to an existing HDU header.

Parameters
----------
wcs : `astropy.wcs.WCS`
The WCS to use.
header : `astropy.io.fits.Header`
The header to which to append the data.
"""
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
header[key] = wcs_header[key]


def wcs_to_dict(wcs):
"""Convert a WCS to a dictionary (via a FITS header).

Parameters
----------
wcs : `astropy.wcs.WCS`
The WCS to convert.

Returns
-------
result : `dict`
A dictionary containing the WCS header information.
"""
result = {}
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
result[key] = wcs_header[key]
return result
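
The dictionary helpers above are what the YAML changes in work_unit.py below lean on. A minimal round-trip sketch (not part of the diff; the reference values and pixel scale are arbitrary):

import astropy.wcs
from kbmod.wcs_utils import wcs_from_dict, wcs_to_dict

# Build a small 2-D WCS by hand with arbitrary reference values.
original = astropy.wcs.WCS(naxis=2)
original.wcs.crpix = [50.0, 50.0]
original.wcs.crval = [200.5, -7.5]
original.wcs.cdelt = [-0.000278, 0.000278]
original.wcs.ctype = ["RA---TAN", "DEC--TAN"]

header_dict = wcs_to_dict(original)   # plain dict of FITS-style keyword/value pairs
recovered = wcs_from_dict(header_dict)
assert recovered is not None and recovered.naxis == 2
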
215 changes: 182 additions & 33 deletions src/kbmod/work_unit.py
@@ -7,9 +7,16 @@
import numpy as np
from pathlib import Path
import warnings
from yaml import dump, safe_load

from kbmod.configuration import SearchConfiguration
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage
from kbmod.wcs_utils import (
append_wcs_to_hdu_header,
extract_wcs_from_hdu_header,
wcs_from_dict,
wcs_to_dict,
)


class WorkUnit:
@@ -24,10 +31,11 @@ class WorkUnit:
config : `kbmod.configuration.SearchConfiguration`
The configuration for the KBMOD run.
wcs : `astropy.wcs.WCS`
A gloabl WCS for all images in the WorkUnit.
A global WCS for all images in the WorkUnit. Only exists
if all images have been projected to the same pixel space.
per_image_wcs : `list`
A list with one WCS for each image in the WorkUnit. Used for when
the images have not been standardized to the same pixel space.
the images have *not* been standardized to the same pixel space.
"""

def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None):
@@ -42,6 +50,38 @@ def __init__(self, im_stack=None, config=None, wcs=None, per_image_wcs=None):
raise ValueError("Incorrect number of WCS provided.")
self.per_image_wcs = per_image_wcs

def __len__(self):
"""Returns the size of the WorkUnit in number of images."""
return self.im_stack.img_count()

def get_wcs(self, img_num):
"""Return the WCS for the a given image. Alway prioritizes
a global WCS if one exits.

Parameters
----------
img_num : `int`
The number of the image.

Returns
-------
wcs : `astropy.wcs.WCS`
The image's WCS if one exists. Otherwise None.

Raises
------
IndexError if an invalid index is given.
"""
if img_num < 0 or img_num >= self.im_stack.img_count():
raise IndexError(f"Invalid image number {img_num}")

if self.wcs is not None:
if self.per_image_wcs[img_num] is not None:
warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning)
return self.wcs

return self.per_image_wcs[img_num]
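
A short sketch of the expected behavior, assuming `wu` is an already-constructed WorkUnit that carries only per-image WCS objects (so the global-WCS warning path is not triggered):

first_wcs = wu.get_wcs(0)             # per-image WCS of image 0, or None
last_wcs = wu.get_wcs(len(wu) - 1)    # __len__ reports the number of images
try:
    wu.get_wcs(len(wu))               # out of range, so this raises
except IndexError:
    pass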

@classmethod
def from_fits(cls, filename):
"""Create a WorkUnit from a single FITS file.
@@ -84,7 +124,7 @@ def from_fits(cls, filename):
# since the primary header does not have an image.
with warnings.catch_warnings():
warnings.simplefilter("ignore", AstropyWarning)
global_wcs = extract_wcs(hdul[0])
global_wcs = extract_wcs_from_hdu_header(hdul[0].header)

# Read the size and order information from the primary header.
num_images = hdul[0].header["NUMIMG"]
@@ -98,7 +138,7 @@ def from_fits(cls, filename):
per_image_wcs = []
for i in range(num_images):
# Extract the per-image WCS if one exists.
per_image_wcs.append(extract_wcs(hdul[f"SCI_{i}"]))
per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"SCI_{i}"].header))

# Read in science, variance, and mask layers.
sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
@@ -114,6 +154,108 @@ def from_fits(cls, filename):
result = WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs)
return result

@classmethod
def from_dict(cls, workunit_dict):
"""Create a WorkUnit from a combined dictionary.

Parameters
----------
workunit_dict : `dict`
The dictionary of information.

Returns
-------
`WorkUnit`

Raises
------
Raises a ``ValueError`` for any invalid parameters.
"""
num_images = workunit_dict["num_images"]
width = workunit_dict["width"]
height = workunit_dict["height"]
if width <= 0 or height <= 0:
raise ValueError(f"Illegal image dimensions width={width}, height={height}")

# Load the configuration supporting both dictionary and SearchConfiguration.
if type(workunit_dict["config"]) is dict:
config = SearchConfiguration.from_dict(workunit_dict["config"])
elif type(workunit_dict["config"]) is SearchConfiguration:
config = workunit_dict["config"]
else:
raise ValueError("Unrecognized type for WorkUnit config parameter.")

# Load the global WCS if one exists.
if "wcs" in workunit_dict:
if type(workunit_dict["wcs"]) is dict:
global_wcs = wcs_from_dict(workunit_dict["wcs"])
else:
global_wcs = workunit_dict["wcs"]
else:
global_wcs = None

imgs = []
per_image_wcs = []
for i in range(num_images):
obs_time = workunit_dict["times"][i]

if type(workunit_dict["sci_imgs"][i]) is RawImage:
sci_img = workunit_dict["sci_imgs"][i]
else:
sci_arr = np.array(workunit_dict["sci_imgs"][i], dtype=np.float32).reshape(height, width)
sci_img = RawImage(img=sci_arr, obs_time=obs_time)

if type(workunit_dict["var_imgs"][i]) is RawImage:
var_img = workunit_dict["var_imgs"][i]
else:
var_arr = np.array(workunit_dict["var_imgs"][i], dtype=np.float32).reshape(height, width)
var_img = RawImage(img=var_arr, obs_time=obs_time)

# Masks are optional.
if workunit_dict["msk_imgs"][i] is None:
msk_arr = np.zeros((height, width))
msk_img = RawImage(img=msk_arr, obs_time=obs_time)
elif type(workunit_dict["msk_imgs"][i]) is RawImage:
msk_img = workunit_dict["msk_imgs"][i]
else:
msk_arr = np.array(workunit_dict["msk_imgs"][i], dtype=np.float32).reshape(height, width)
msk_img = RawImage(img=msk_arr, obs_time=obs_time)

# PSFs are optional.
if workunit_dict["psfs"][i] is None:
p = PSF()
elif type(workunit_dict["psfs"][i]) is PSF:
p = workunit_dict["psfs"][i]
else:
p = PSF(np.array(workunit_dict["psfs"][i], dtype=np.float32))

imgs.append(LayeredImage(sci_img, var_img, msk_img, p))

# Read a per_image_wcs if one exists.
current_wcs = workunit_dict["per_image_wcs"][i]
if type(current_wcs) is dict:
current_wcs = wcs_from_dict(current_wcs)
per_image_wcs.append(current_wcs)

im_stack = ImageStack(imgs)
return WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs)

@classmethod
def from_yaml(cls, work_unit):
"""Load a configuration from a YAML string.

Parameters
----------
work_unit : `str` or `_io.TextIOWrapper`
The serialized YAML data.

Returns
-------
`WorkUnit`

Raises
------
Raises a ``ValueError`` for any invalid parameters.
"""
yaml_dict = safe_load(work_unit)
return WorkUnit.from_dict(yaml_dict)

def to_fits(self, filename, overwrite=False):
"""Write the WorkUnit to a single FITS file.

@@ -144,9 +286,7 @@ def to_fits(self, filename, overwrite=False):

# If the global WCS exists, append the corresponding keys.
if self.wcs is not None:
wcs_header = self.wcs.to_header()
for key in wcs_header:
pri.header[key] = wcs_header[key]
append_wcs_to_hdu_header(self.wcs, pri.header)

hdul.append(pri)

@@ -185,34 +325,45 @@ def to_fits(self, filename, overwrite=False):

hdul.writeto(filename)

def to_yaml(self):
    """Serialize the WorkUnit as a YAML string.

    Returns
    -------
    result : `str`
        The serialized YAML string.
    """
    workunit_dict = {
        "num_images": self.im_stack.img_count(),
        "width": self.im_stack.get_width(),
        "height": self.im_stack.get_height(),
        "config": self.config._params,
        "wcs": wcs_to_dict(self.wcs),
        # Per image data
        "times": [],
        "sci_imgs": [],
        "var_imgs": [],
        "msk_imgs": [],
        "psfs": [],
        "per_image_wcs": [],
    }

    # Fill in the per-image data.
    for i in range(self.im_stack.img_count()):
        layered = self.im_stack.get_single_image(i)
        workunit_dict["times"].append(layered.get_obstime())
        p = layered.get_psf()

        workunit_dict["sci_imgs"].append(layered.get_science().image.tolist())
        workunit_dict["var_imgs"].append(layered.get_variance().image.tolist())
        workunit_dict["msk_imgs"].append(layered.get_mask().image.tolist())

        psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
        workunit_dict["psfs"].append(psf_array.tolist())

        workunit_dict["per_image_wcs"].append(wcs_to_dict(self.per_image_wcs[i]))

    return dump(workunit_dict)


# Removed in this PR: the module-level extract_wcs helper, superseded by
# extract_wcs_from_hdu_header in src/kbmod/wcs_utils.py.
def extract_wcs(hdu):
    """Read an WCS from the header and does basic validity checking.

    Parameters
    ----------
    hdu : An astropy HDU (Image or Primary)
        The extension

    Returns
    --------
    curr_wcs : `astropy.wcs.WCS`
        The WCS or None if it does not exist.
    """
    # Check that we have (at minimum) the CRVAL and CRPIX keywords.
    # These are necessary (but not sufficient) requirements for the WCS.
    if "CRVAL1" not in hdu.header or "CRVAL2" not in hdu.header:
        return None
    if "CRPIX1" not in hdu.header or "CRPIX2" not in hdu.header:
        return None

    curr_wcs = WCS(hdu.header)
    if curr_wcs is None:
        return None
    if curr_wcs.naxis != 2:
        return None

    return curr_wcs


def raw_image_to_hdu(img, wcs=None):
@@ -234,9 +385,7 @@ def raw_image_to_hdu(img, wcs=None):

# If the WCS is given, copy each entry into the header.
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
hdu.header[key] = wcs_header[key]
append_wcs_to_hdu_header(wcs, hdu.header)

# Set the time stamp.
hdu.header["MJD"] = img.obstime
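
Putting the new pieces together, a WorkUnit round trip through YAML is a plain string exchange. A sketch, assuming `wu` is an existing WorkUnit (for example the one built from the dictionary sketch above; the output path is hypothetical):

yaml_str = wu.to_yaml()
with open("work_unit.yaml", "w") as f:
    f.write(yaml_str)

# from_yaml accepts either the raw string or an open text file handle.
wu_copy = WorkUnit.from_yaml(yaml_str)
assert len(wu_copy) == len(wu)
assert wu_copy.im_stack.get_width() == wu.im_stack.get_width()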