Implement toLayeredImage, toImageStack and run method. Cleanup.
Implement the methods required to run KBMOD on ImageCollection.
Clean up ImageCollection behaviour:
*) return an ImageCollection when indexed by lists, arrays and slices
*) return a Row when indexed by an integer
*) return a Table when sliced by columns
*) Rename exts to processable. Alias it to a property so that
   each Standardizer can implement its own internal structure the
   way it wants (but also because I was too lazy to rename everything)
*) Fix documentation
*) Move WCS and BBOX to the Standardizer as properties - if that's
   where we need to explain why they are special, that's where they
   need to live. Make them abstract properties and require that the
   Standardizers return a list of Nones if need be.
*) Fix the forceStandardizer keyword (again).
*) Add toLayeredImage as an abstract method to the Standardizers.
   Implement it for the three example Standardizers we have.
*) Add toImageStack as an abstract method to the standardizers.
   Implement it in ImageCollection.
*) Add a run method prototype to ImageCollection to showcase how
   KBMOD runs can be executed neatly from an ImageCollection.

Write an example Python script showcasing most of this functionality; a short usage sketch follows below.

TODO: tests, unit tests, integration tests, all the tests.
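
A minimal usage sketch of the behaviour above (illustrative, not part of
the commit; the file paths are placeholders and KBMODConfig() is assumed
to build a default configuration):

    from kbmod import ImageCollection, Standardizer
    from kbmod.configuration import KBMODConfig

    # Hypothetical FITS files; anything a Standardizer recognizes works.
    stds = [Standardizer.fromFile(path=p) for p in ("a.fits", "b.fits", "c.fits")]
    ic = ImageCollection.fromStandardizers(stds)

    row = ic[0]          # integer index -> a single Row
    subset = ic[[0, 2]]  # list/array/slice -> a new ImageCollection
    mjds = ic["mjd"]     # column name -> column data

    # Standardizers are instantiated lazily and cached on first access.
    std_ext = ic.get_standardizer(0)  # {"std": Standardizer, "ext": ext index}

    stack = ic.toImageStack()                # kbmod.search.image_stack
    search, params = ic.run(KBMODConfig())  # prototype KBMOD run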
DinoBektesevic committed Jul 7, 2023
1 parent 088f66c commit 7332543
Showing 8 changed files with 479 additions and 96 deletions.
6 changes: 3 additions & 3 deletions src/kbmod/__init__.py
@@ -5,7 +5,6 @@
except ImportError:
warnings.warn("Unable to determine the package version. " "This is likely a broken installation.")

from .standardizers import *
from . import (
analysis,
analysis_utils,
@@ -15,7 +14,8 @@
jointfit_functions,
result_list,
run_search,
standardizer,
)


from .standardizers import *
from .standardizer import Standardizer
from .image_collection import ImageCollection
247 changes: 225 additions & 22 deletions src/kbmod/image_collection.py
@@ -6,18 +6,24 @@
import os
import glob
import json
import time

import astropy.units as u
from astropy.io import fits
from astropy.time import Time
from astropy.table import Table
from astropy.wcs import WCS
from astropy.utils import isiterable
from astropy.coordinates import SkyCoord

import numpy as np

from kbmod.file_utils import FileUtils
from kbmod.search import pixel_pos, layered_image
from kbmod.search import image_stack, stack_search
from kbmod.standardizer import Standardizer
from kbmod.analysis_utils import PostProcess


__all__ = ["ImageCollection", ]


class ImageCollection:
@@ -134,14 +140,15 @@ def __init__(self, metadata, standardizers=None):

if standardizers is not None:
self.data.meta["n_entries"] = len(standardizers)
self._standardizers = standardizers
self._standardizers = np.array(standardizers)
elif metadata.meta and "n_entries" in metadata.meta:
n_entries = metadata.meta["n_entries"]
self._standardizers = [None]*n_entries
self._standardizers = np.full((n_entries, ), None)
else:
n_entries = len(np.unique(metadata["location"]))
self.data.meta["n_entries"] = n_entries
self._standardizers = [None]*n_entries
self._standardizers = np.full((n_entries, ), None)

# hidden indices that track the unravelled lookup to standardizer
# extension index. I should imagine there's a better than double-loop
@@ -157,7 +164,9 @@
no_std_map = True
self.data["ext_idx"] = [None]*len(self.data)

if standardizers and no_std_map:
# guard against falsy standardizers: None, empty lists, empty tuples,
# empty arrays etc.
if (standardizers is not None and any(standardizers)) and no_std_map:
std_idxs, ext_idxs = [], []
for i, stdFits in enumerate(standardizers):
for j, ext in enumerate(stdFits.exts):
@@ -195,7 +204,7 @@ def read(cls, *args, format=None, units=None, descriptions=None, **kwargs):
return cls(metadata)

@classmethod
def _fromStandardizers(cls, standardizers, meta=None):
def fromStandardizers(cls, standardizers, meta=None):
"""Create ImageCollection from a collection `Standardizers`.
The `Standardizer` is "unravelled", i.e. the shared metadata is
@@ -224,7 +233,7 @@ def _fromStandardizers(cls, standardizers, meta=None):
# ButlerStd.stdMeta. Everything that is an iterable, except for a
# string because that could be a location key?
unravelColumns = [key for key, val in stdMeta.items() if isiterable(val) and not isinstance(val, str)]
for j, ext in enumerate(stdFits.exts):
for j, ext in enumerate(stdFits.processable):
row = {}
for key in stdMeta.keys():
if key in unravelColumns:
@@ -263,7 +272,7 @@ def _fromFilepaths(cls, filepaths, forceStandardizer, **kwargs):
Standardizer.fromFile(path=path, forceStandardizer=forceStandardizer, **kwargs)
for path in filepaths
]
return cls._fromStandardizers(standardizers)
return cls.fromStandardizers(standardizers)

@classmethod
def _fromDir(cls, path, recursive, forceStandardizer, **kwargs):
@@ -357,7 +366,7 @@ def fromDatasetRefs(cls, butler, refs, **kwargs):
standardizer_cls = Standardizer.get(standardizer="ButlerStandardizer")
standardizer = standardizer_cls(butler, refs, **kwargs)
meta = {"root": butler.datastore.root.geturl(), "n_entries": len(list(refs))}
return cls._fromStandardizers([standardizer, ], meta=meta)
return cls.fromStandardizers([standardizer, ], meta=meta)

def fromAQueryTable(self): # ? TBD
pass
@@ -372,7 +381,12 @@ def __repr__(self):
return repr(self.data).replace("Table", "ImageInfoSet")

def __getitem__(self, key):
return self.data[key]
if isinstance(key, (int, str, np.integer)):
return self.data[self._userColumns][key]
elif isinstance(key, (list, np.ndarray, slice)):
return self.__class__(self.data[key], standardizers=self._standardizers[key])
else:
return self.data[key]

def __setitem__(self, key, val):
self.data[key] = val
@@ -417,7 +431,7 @@ def standardizers(self):
"""A list of used standardizer names."""
return self._standardizer_names

def _get_standardizer(self, index):
def get_standardizer(self, index, **kwargs):
"""Get the standardizer and extension index for the selected row of the
unravelled metadata table.
@@ -429,6 +443,8 @@ def _get_standardizer(self, index):
index : `int`
Index, as it appears in the unravelled table of metadata
properties.
**kwargs : `dict`
Keyword arguments are passed onto the Standardizer constructor.

Returns
-------
@@ -441,20 +457,23 @@ def _get_standardizer(self, index):
ext_idx = row["ext_idx"]
if self._standardizers[std_idx] is None:
std_cls = Standardizer.registry[row["std_name"]]
self._standardizers[std_idx] = std_cls(row["location"])
self._standardizers[std_idx] = std_cls(**kwargs, **row)

# maybe a clever dataclass to shortcut the idx lookups on the user end?
return {"std": self._standardizers[std_idx],
"ext": self.data[index]["ext_idx"]}

def get_standardizers(self, idxs):
def get_standardizers(self, idxs, **kwargs):
""" Get the standardizers used to extract metadata of the selected
rows.
Parameters
----------
idx : `int` or `iterable`
Index of the row for which to retrieve the Standardizer.
**kwargs : `dict`
Keyword arguments are passed onto the constructors of the retrieved
Standardizer.

Returns
-------
@@ -463,19 +482,13 @@ def get_standardizers(self, idxs):
the extension (``ext``) that maps to the given metadata row index.
"""
if isinstance(idxs, int):
return self._get_standardizer(idxs)
return [self.get_standardizer(idxs, **kwargs), ]
else:
return [self._get_standardizer(idx) for idx in idxs]
return [self.get_standardizer(idx, **kwargs) for idx in idxs]

########################
# FUNCTIONALITY (object operations, transformative functionality)
########################
def toImageStack(self):
pass

def plot_onsky(self):
pass

def write(self, *args, format=None, serialize_method=None, **kwargs):
tmpdata = self.data.copy()

@@ -522,3 +535,193 @@ def get_duration(self):
"""
# maybe timespan?
return self.data["mjd"][-1] - self.data["mjd"][0]

def toImageStack(self):
"""Return an `~kbmod.search.image_stack` object for processing with
KBMOD.

Returns
-------
imageStack : `~kbmod.search.image_stack`
Image stack for processing with KBMOD.
"""
# unpack the layered image lists to flatten the array.
# This is needlessly costly: our internal array representation does not
# interface with numpy via ndarray, so it makes a copy every time.
layeredImages = [img for std in self._standardizers for img in std.toLayeredImage()]
return image_stack(layeredImages)

def _calc_suggested_angle(self, wcs, center_pixel=(1000, 2000), step=12):
"""Projects an unit-vector parallel with the ecliptic onto the image
and calculates the angle of the projected unit-vector in the pixel
space.
Parameters
----------
wcs : ``astropy.wcs.WCS``
World Coordinate System object.
center_pixel : tuple, array-like
Pixel coordinates of image center.
step : ``float`` or ``int``
Size of step, in arcseconds, used to find the pixel coordinates of
the second pixel in the image parallel to the ecliptic.

Returns
-------
suggested_angle : ``float``
Angle the projected unit vector parallel to the ecliptic
makes with the image axes. Used to transform the specified
search angles, with respect to the ecliptic, to search angles
within the image.

Note
----
It is not necessary to calculate this angle for each image in an
image set if they have all been warped to a common WCS.

See Also
--------
run_search.do_gpu_search
"""
# pick a starting pixel approximately near the center of the image
# convert it to ecliptic coordinates
start_pixel = np.array(center_pixel)
start_pixel_coord = SkyCoord.from_pixel(start_pixel[0], start_pixel[1], wcs)
start_ecliptic_coord = start_pixel_coord.geocentrictrueecliptic

# pick a guess pixel by moving parallel to the ecliptic
# convert it to pixel coordinates for the given WCS
guess_ecliptic_coord = SkyCoord(
start_ecliptic_coord.lon + step * u.arcsec,
start_ecliptic_coord.lat,
frame="geocentrictrueecliptic",
)
guess_pixel_coord = guess_ecliptic_coord.to_pixel(wcs)

# calculate the distance, in pixel coordinates, between the guess and
# the start pixel. Calculate the angle that represents in the image.
x_dist, y_dist = np.array(guess_pixel_coord) - start_pixel
return np.arctan2(y_dist, x_dist)

def run(self, config):
"""Run KBMOD on the images in collection.
Parameters
----------
config : `~kbmod.configuration.KBMODConfig`
Processing configuration.

Returns
-------
results : `kbmod.results.ResultList`
KBMOD search results.

Notes
-----
Requires WCS.
"""
imageStack = self.toImageStack()

# Compute the ecliptic angle for the images. Assume they are all the
# same size? Technically that is currently a requirement, although it's
# not made explicit (could this check live in the C++ code?)
center_pixel = (imageStack.get_width()/2, imageStack.get_height()/2)
suggested_angle = self._calc_suggested_angle(self.wcs[0], center_pixel)

# Set up the post processing data structure.
kb_post_process = PostProcess(config, self.data["mjd"].data)

# Perform the actual search.
search = stack_search(imageStack)
# search, search_params = self.do_gpu_search(search, img_info,
#                                            suggested_angle, kb_post_process)
# Not sure why these were separated; I guess it made things look neater?
# It definitely doesn't feel like everything is in place if a config
# needs this many ifs - feels like that should be the config's job.
# Anyhow, I'll be lazy and just unravel this here.
search_params = {}

# Run the grid search
# Set min and max values for angle and velocity
if config["average_angle"] == None:
average_angle = suggested_angle
else:
average_angle = config["average_angle"]
ang_min = average_angle - config["ang_arr"][0]
ang_max = average_angle + config["ang_arr"][1]
vel_min = config["v_arr"][0]
vel_max = config["v_arr"][1]
search_params["ang_lims"] = [ang_min, ang_max]
search_params["vel_lims"] = [vel_min, vel_max]

# Set the search bounds.
if config["x_pixel_bounds"] and len(config["x_pixel_bounds"]) == 2:
search.set_start_bounds_x(config["x_pixel_bounds"][0], config["x_pixel_bounds"][1])
elif config["x_pixel_buffer"] and config["x_pixel_buffer"] > 0:
width = search.get_image_stack().get_width()
search.set_start_bounds_x(-config["x_pixel_buffer"], width + config["x_pixel_buffer"])

if config["y_pixel_bounds"] and len(config["y_pixel_bounds"]) == 2:
search.set_start_bounds_y(config["y_pixel_bounds"][0], config["y_pixel_bounds"][1])
elif config["y_pixel_buffer"] and config["y_pixel_buffer"] > 0:
height = search.get_image_stack().get_height()
search.set_start_bounds_y(-config["y_pixel_buffer"], height + config["y_pixel_buffer"])

# If we are using barycentric corrections, compute the parameters and
# enable them in the search. NB: "bary_dist" can't be not-None atm
# because _calc_barycentric_corr hasn't been copied over yet.
if config["bary_dist"] is not None:
bary_corr = self._calc_barycentric_corr(img_info, config["bary_dist"])
# print average barycentric velocity for debugging

mjd_range = self.get_duration()
bary_vx = bary_corr[-1, 0] / mjd_range
bary_vy = bary_corr[-1, 3] / mjd_range
bary_v = np.sqrt(bary_vx * bary_vx + bary_vy * bary_vy)
bary_ang = np.arctan2(bary_vy, bary_vx)
print("Average Velocity from Barycentric Correction", bary_v, "pix/day", bary_ang, "angle")
search.enable_corr(bary_corr.flatten())

search_start = time.time()
print("Starting Search")
print("---------------------------------------")
param_headers = (
"Ecliptic Angle",
"Min. Search Angle",
"Max Search Angle",
"Min Velocity",
"Max Velocity",
)
param_values = (suggested_angle, *search_params["ang_lims"], *search_params["vel_lims"])
for header, val in zip(param_headers, param_values):
print("%s = %.4f" % (header, val))

# If we are using gpu_filtering, enable it and set the parameters.
if config["gpu_filter"]:
print("Using in-line GPU sigmaG filtering methods", flush=True)
coeff = kb_post_process._find_sigmaG_coeff(config["sigmaG_lims"])
search.enable_gpu_sigmag_filter(
np.array(config["sigmaG_lims"]) / 100.0,
coeff,
config["lh_level"],
)

# If we are using an encoded image representation on GPU, enable it and
# set the parameters.
if config["encode_psi_bytes"] > 0 or config["encode_phi_bytes"] > 0:
search.enable_gpu_encoding(config["encode_psi_bytes"], config["encode_phi_bytes"])

# Enable debugging.
if config["debug"]:
search.set_debug(config["debug"])

search.search(
int(config["ang_arr"][2]),
int(config["v_arr"][2]),
*search_params["ang_lims"],
*search_params["vel_lims"],
int(config["num_obs"]),
)
print("Search finished in {0:.3f}s".format(time.time() - search_start), flush=True)
return search, search_params
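
For reference, a standalone sketch of the suggested-angle computation
above, run against a synthetic TAN WCS (all values illustrative):

    import numpy as np
    import astropy.units as u
    from astropy.wcs import WCS
    from astropy.coordinates import SkyCoord

    # Synthetic ~1 arcsec/pixel TAN WCS centred on (RA, Dec) = (45, 0) deg.
    wcs = WCS(naxis=2)
    wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    wcs.wcs.crpix = [1000.0, 1000.0]
    wcs.wcs.crval = [45.0, 0.0]
    wcs.wcs.cdelt = [-1.0 / 3600, 1.0 / 3600]

    # Same logic as _calc_suggested_angle: step 12 arcsec along the
    # ecliptic from the image centre, then measure the pixel-space angle.
    start = SkyCoord.from_pixel(1000, 1000, wcs)
    ecl = start.geocentrictrueecliptic
    guess = SkyCoord(ecl.lon + 12 * u.arcsec, ecl.lat,
                     frame="geocentrictrueecliptic")
    dx, dy = np.array(guess.to_pixel(wcs)) - np.array([1000.0, 1000.0])
    print(np.arctan2(dy, dx))  # ecliptic angle in pixel coordinates, radians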
