Skip to content

Commit

Permalink
Merge pull request #371 from dirac-institute/analysis_utils_refactor
Browse files Browse the repository at this point in the history
Analysis utils refactor
  • Loading branch information
maxwest-uw authored Oct 12, 2023
2 parents 5acc9eb + 6b90fae commit 1c468ad
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 184 deletions.
1 change: 1 addition & 0 deletions src/kbmod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from . import (
analysis,
analysis_utils,
data_interface,
file_utils,
filters,
jointfit_functions,
Expand Down
139 changes: 1 addition & 138 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import os
import time

from astropy.io import fits
from astropy.wcs import WCS
import numpy as np
from scipy.special import erfinv # import mpmath
from scipy.special import erfinv

import kbmod.search as kb

Expand All @@ -15,141 +13,6 @@
from .result_list import ResultList, ResultRow


class Interface:
    """Load image data from .fits files and their auxiliary files.

    The auxiliary files are the optional per-visit time stamp file and
    per-visit PSF file accepted by `load_images`.
    """

    def __init__(self):
        """Create a loader. No state is stored on the instance."""

    def load_images(
        self,
        im_filepath,
        time_file,
        psf_file,
        mjd_lims,
        default_psf,
        verbose=False,
    ):
        """Load the FITS images found in a directory into an image stack.

        Directory entries are visited in sorted file-name order. An entry
        is skipped when its name does not contain ".fits", when no visit ID
        can be determined for it, when its time stamp is not positive, or
        when its time stamp falls outside ``mjd_lims``.

        Parameters
        ----------
        im_filepath : string
            Image file path from which to load images.
        time_file : string
            File name containing image times. A time listed here for a
            visit overrides the time stamp read from the image itself.
        psf_file : string
            File name containing the image-specific PSFs.
            If set to None the code will use the provided default psf for
            all images.
        mjd_lims : list of floats or None
            Optional ``[min, max]`` MJD limits (both ends inclusive) on the
            images to search. Pass None to disable the time filter.
        default_psf : `psf`
            The default PSF in case no image-specific PSF is provided.
        verbose : bool
            Use verbose output (mainly for debugging).

        Returns
        -------
        stack : `kbmod.search.ImageStack`
            The stack of images loaded.
        wcs_list : `list`
            A list of `astropy.wcs.WCS` objects for each image.
        visit_times : `list`
            A list of MJD times.
        """
        print("---------------------------------------")
        print("Loading Images")
        print("---------------------------------------")

        # Load a mapping from visit numbers to the visit times. This dictionary
        # stays empty if no time file is specified.
        image_time_dict = FileUtils.load_time_dictionary(time_file)
        if verbose:
            print(f"Loaded {len(image_time_dict)} time stamps.")

        # Load a mapping from visit numbers to PSFs. This dictionary stays
        # empty if no PSF file is specified.
        image_psf_dict = FileUtils.load_psf_dictionary(psf_file)
        if verbose:
            print(f"Loaded {len(image_psf_dict)} image PSFs stamps.")

        images = []
        visit_times = []
        wcs_list = []

        # One sort of the directory listing is enough; it also keeps the
        # file names as plain `str` objects.
        for visit_file in sorted(os.listdir(im_filepath)):
            # Skip non-fits files.
            if ".fits" not in visit_file:
                if verbose:
                    print(f"Skipping non-FITS file {visit_file}")
                continue

            # Compute the full file path for loading.
            full_file_path = os.path.join(im_filepath, visit_file)

            # Read the WCS and the visit ID from the FITS headers.
            with fits.open(full_file_path) as hdu_list:
                curr_wcs = WCS(hdu_list[1].header)

                # If the visit ID is in header (using Rubin tags), use it.
                # Otherwise extract it from the file name; `visit_file` is
                # already the bare name returned by os.listdir.
                if "IDNUM" in hdu_list[0].header:
                    visit_id = str(hdu_list[0].header["IDNUM"])
                else:
                    visit_id = FileUtils.visit_from_file_name(visit_file)

            # Skip files without a valid visit ID.
            if visit_id is None:
                if verbose:
                    print(f"WARNING: Unable to extract visit ID for {visit_file}.")
                continue

            # Use the image-specific PSF when one was provided.
            psf = default_psf
            if visit_id in image_psf_dict:
                psf = kb.PSF(image_psf_dict[visit_id])

            # Load the image file and read its time.
            if verbose:
                print(f"Loading file: {full_file_path}")
            img = kb.LayeredImage(full_file_path, psf)
            time_stamp = img.get_obstime()

            # Override the header's time stamp if the time file lists one.
            if visit_id in image_time_dict:
                time_stamp = image_time_dict[visit_id]
                img.set_obstime(time_stamp)

            if time_stamp <= 0.0:
                if verbose:
                    print(f"WARNING: No valid timestamp provided for {visit_file}.")
                continue

            # Check if we should filter the record based on the time bounds.
            if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]):
                if verbose:
                    print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
                continue

            # Save image, time, and WCS information.
            visit_times.append(time_stamp)
            images.append(img)
            wcs_list.append(curr_wcs)

        print(f"Loaded {len(images)} images")
        stack = kb.ImageStack(images)

        return (stack, wcs_list, visit_times)


class PostProcess:
"""This class manages the post-processing utilities used to filter out and
otherwise remove false positives from the KBMOD search. This includes,
Expand Down
145 changes: 145 additions & 0 deletions src/kbmod/data_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import os

from astropy.io import fits
from astropy.wcs import WCS
import numpy as np

import kbmod.search as kb

from .file_utils import *
from .filters.stats_filters import *


class Interface:
    """Load image data from .fits files and their auxiliary files.

    The auxiliary files are the optional per-visit time stamp file and
    per-visit PSF file accepted by `load_images`.
    """

    def __init__(self):
        """Create a loader. No state is stored on the instance."""

    def load_images(
        self,
        im_filepath,
        time_file,
        psf_file,
        mjd_lims,
        default_psf,
        verbose=False,
    ):
        """Load the FITS images found in a directory into an image stack.

        Directory entries are visited in sorted file-name order. An entry
        is skipped when its name does not contain ".fits", when no visit ID
        can be determined for it, when its time stamp is not positive, or
        when its time stamp falls outside ``mjd_lims``.

        Parameters
        ----------
        im_filepath : string
            Image file path from which to load images.
        time_file : string
            File name containing image times. A time listed here for a
            visit overrides the time stamp read from the image itself.
        psf_file : string
            File name containing the image-specific PSFs.
            If set to None the code will use the provided default psf for
            all images.
        mjd_lims : list of floats or None
            Optional ``[min, max]`` MJD limits (both ends inclusive) on the
            images to search. Pass None to disable the time filter.
        default_psf : `psf`
            The default PSF in case no image-specific PSF is provided.
        verbose : bool
            Use verbose output (mainly for debugging).

        Returns
        -------
        stack : `kbmod.search.ImageStack`
            The stack of images loaded.
        wcs_list : `list`
            A list of `astropy.wcs.WCS` objects for each image.
        visit_times : `list`
            A list of MJD times.
        """
        print("---------------------------------------")
        print("Loading Images")
        print("---------------------------------------")

        # Load a mapping from visit numbers to the visit times. This dictionary
        # stays empty if no time file is specified.
        image_time_dict = FileUtils.load_time_dictionary(time_file)
        if verbose:
            print(f"Loaded {len(image_time_dict)} time stamps.")

        # Load a mapping from visit numbers to PSFs. This dictionary stays
        # empty if no PSF file is specified.
        image_psf_dict = FileUtils.load_psf_dictionary(psf_file)
        if verbose:
            print(f"Loaded {len(image_psf_dict)} image PSFs stamps.")

        images = []
        visit_times = []
        wcs_list = []

        # One sort of the directory listing is enough; it also keeps the
        # file names as plain `str` objects.
        for visit_file in sorted(os.listdir(im_filepath)):
            # Skip non-fits files.
            if ".fits" not in visit_file:
                if verbose:
                    print(f"Skipping non-FITS file {visit_file}")
                continue

            # Compute the full file path for loading.
            full_file_path = os.path.join(im_filepath, visit_file)

            # Read the WCS and the visit ID from the FITS headers.
            with fits.open(full_file_path) as hdu_list:
                curr_wcs = WCS(hdu_list[1].header)

                # If the visit ID is in header (using Rubin tags), use it.
                # Otherwise extract it from the file name; `visit_file` is
                # already the bare name returned by os.listdir.
                if "IDNUM" in hdu_list[0].header:
                    visit_id = str(hdu_list[0].header["IDNUM"])
                else:
                    visit_id = FileUtils.visit_from_file_name(visit_file)

            # Skip files without a valid visit ID.
            if visit_id is None:
                if verbose:
                    print(f"WARNING: Unable to extract visit ID for {visit_file}.")
                continue

            # Use the image-specific PSF when one was provided.
            psf = default_psf
            if visit_id in image_psf_dict:
                psf = kb.PSF(image_psf_dict[visit_id])

            # Load the image file and read its time.
            if verbose:
                print(f"Loading file: {full_file_path}")
            img = kb.LayeredImage(full_file_path, psf)
            time_stamp = img.get_obstime()

            # Override the header's time stamp if the time file lists one.
            if visit_id in image_time_dict:
                time_stamp = image_time_dict[visit_id]
                img.set_obstime(time_stamp)

            if time_stamp <= 0.0:
                if verbose:
                    print(f"WARNING: No valid timestamp provided for {visit_file}.")
                continue

            # Check if we should filter the record based on the time bounds.
            if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]):
                if verbose:
                    print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
                continue

            # Save image, time, and WCS information.
            visit_times.append(time_stamp)
            images.append(img)
            wcs_list.append(curr_wcs)

        print(f"Loaded {len(images)} images")
        stack = kb.ImageStack(images)

        return (stack, wcs_list, visit_times)
3 changes: 2 additions & 1 deletion src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import kbmod.search as kb

from .analysis_utils import Interface, PostProcess
from .analysis_utils import PostProcess
from .data_interface import Interface
from .configuration import SearchConfiguration
from .masking import (
BitVectorMasker,
Expand Down
46 changes: 1 addition & 45 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from kbmod.analysis_utils import *
from kbmod.data_interface import Interface
from kbmod.fake_data_creator import add_fake_object
from kbmod.result_list import *
from kbmod.search import *
Expand Down Expand Up @@ -376,51 +377,6 @@ def test_load_and_filter_results_lh(self):
self.assertEqual(results.results[0].trajectory.y, 30)
self.assertEqual(results.results[1].trajectory.y, 40)

def test_file_load_basic(self):
    """Load the fake images without any auxiliary time or PSF file."""
    image_dir = get_absolute_data_path("fake_images")
    mjd_window = [0, 157130.2]
    stack, _wcs_list, _mjds = Interface().load_images(
        image_dir,
        None,
        None,
        mjd_window,
        PSF(1.0),
        verbose=False,
    )
    self.assertEqual(stack.img_count(), 4)

    # Check that each image loaded correctly: 64x64 pixels, the expected
    # observation time, and the default PSF.
    expected_times = [57130.2, 57130.21, 57130.22, 57131.2]
    for index in range(stack.img_count()):
        layered = stack.get_single_image(index)
        self.assertEqual(layered.get_width(), 64)
        self.assertEqual(layered.get_height(), 64)
        self.assertAlmostEqual(layered.get_obstime(), expected_times[index], delta=0.005)
        self.assertAlmostEqual(1.0, layered.get_psf().get_std())

def test_file_load_extra(self):
    """Load the fake images together with the time and PSF override files."""
    fallback_psf = PSF(1.0)
    reader = Interface()

    stack, _wcs_list, _mjds = reader.load_images(
        get_absolute_data_path("fake_images"),
        get_absolute_data_path("fake_times.dat"),
        get_absolute_data_path("fake_psfs.dat"),
        [0, 157130.2],
        fallback_psf,
        verbose=False,
    )
    self.assertEqual(stack.img_count(), 4)

    # Check that each image loaded correctly: 64x64 pixels, with the
    # expected observation time and PSF width for its position.
    expected_times = [57130.2, 57130.21, 57130.22, 57162.0]
    expected_psf_std = [1.0, 1.0, 1.3, 1.0]
    for index in range(stack.img_count()):
        layered = stack.get_single_image(index)
        self.assertEqual(layered.get_width(), 64)
        self.assertEqual(layered.get_height(), 64)
        self.assertAlmostEqual(layered.get_obstime(), expected_times[index], delta=0.005)
        self.assertAlmostEqual(expected_psf_std[index], layered.get_psf().get_std())


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 1c468ad

Please sign in to comment.