Skip to content

Commit

Permalink
Merge pull request #437 from dirac-institute/output_table
Browse files Browse the repository at this point in the history
Add option to output results to single table
  • Loading branch information
jeremykubica authored Jan 19, 2024
2 parents 31fbc57 + 4336647 commit 86a4456
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 32 deletions.
10 changes: 9 additions & 1 deletion docs/source/user_manual/search_params.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ This document serves to provide a quick overview of the existing parameters and
| | | directory with multiple FITS files |
| | | (one for each exposure). |
+------------------------+-----------------------------+----------------------------------------+
| ``ind_output_files`` | True | Output results to a series of |
| | | individual files (legacy format) |
+------------------------+-----------------------------+----------------------------------------+
| ``known_obj_obs`` | 3 | The minimum number of observations |
| | | needed to count a known object match. |
+------------------------+-----------------------------+----------------------------------------+
Expand Down Expand Up @@ -145,7 +148,12 @@ This document serves to provide a quick overview of the existing parameters and
| | | mask. See :ref:`Masking`. |
+------------------------+-----------------------------+----------------------------------------+
| ``res_filepath`` | None | The path of the directory in which to |
| | | store the results files. |
| | | store the individual results files. |
+------------------------+-----------------------------+----------------------------------------+
| ``result_filename`` | None | Full filename and path for a single |
| | | tabular result saved as ecsv. |
| | | Can be used in addition to |
| | | outputting individual result files. |
+------------------------+-----------------------------+----------------------------------------+
| ``sigmaG_lims`` | [25, 75] | The percentiles to use in sigmaG |
| | | filtering, if |
Expand Down
2 changes: 2 additions & 0 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
"encode_num_bytes": -1,
"flag_keys": default_flag_keys,
"gpu_filter": False,
"ind_output_files": True,
"im_filepath": None,
"known_obj_obs": 3,
"known_obj_thresh": None,
Expand All @@ -72,6 +73,7 @@ def __init__(self):
"psf_file": None,
"repeated_flag_keys": default_repeated_flag_keys,
"res_filepath": None,
"result_filename": None,
"sigmaG_lims": [25, 75],
"stamp_radius": 10,
"stamp_type": "sum",
Expand Down
234 changes: 233 additions & 1 deletion src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,44 @@
import multiprocessing as mp
import numpy as np
import os.path as ospath
from pathlib import Path

from astropy.table import Table
from yaml import dump, safe_load

from kbmod.file_utils import *
from kbmod.trajectory_utils import (
make_trajectory,
trajectory_from_yaml,
trajectory_predict_skypos,
trajectory_to_yaml,
)


def _check_optional_allclose(arr1, arr2):
    """Check whether two optional numpy arrays have the same information.

    Parameters
    ----------
    arr1 : `numpy.ndarray` or `None`
        The first array.
    arr2 : `numpy.ndarray` or `None`
        The second array.

    Returns
    -------
    result : `bool`
        Indicates whether the arrays are the same.
    """
    # If either side is None, they match only when both are None.
    if arr1 is None or arr2 is None:
        return arr1 is None and arr2 is None
    return np.allclose(arr1, arr2)


class ResultRow:
"""This class stores a collection of related data from a single kbmod result.
In order to maintain a consistent internal state, the class uses private variables
Expand Down Expand Up @@ -86,6 +112,104 @@ def __init__(self, trj, num_times):
self.trajectory = trj
self._valid_indices = [i for i in range(num_times)]

@classmethod
def from_table_row(cls, data, num_times=None):
    """Create a ResultRow object directly from an AstroPy Table row.

    Parameters
    ----------
    data : `astropy.table.row.Row`
        The incoming row.
    num_times : `int`, optional
        The number of total times in the data. If ``None`` tries
        to extract from a "num_times" or "all_times" column.

    Returns
    -------
    row : `ResultRow`
        The reconstructed result row.

    Raises
    ------
    KeyError if a column is missing.
    """
    if num_times is None:
        if "num_times" in data.columns:
            num_times = data["num_times"]
        elif "all_times" in data.columns:
            num_times = len(data["all_times"])
        else:
            raise KeyError("Number of times is not specified.")

    # Create the Trajectory object from the correct fields.
    trj = make_trajectory(
        data["trajectory_x"],
        data["trajectory_y"],
        data["trajectory_vx"],
        data["trajectory_vy"],
        data["flux"],
        data["likelihood"],
        data["obs_count"],
    )

    # Manually fill in all the rest of the values. We let the stamp related columns
    # be empty to save space. Use cls (not ResultRow) so subclasses construct correctly.
    row = cls(trj, num_times)
    row._final_likelihood = data["likelihood"]
    row._phi_curve = data["phi_curve"]
    row.pred_dec = data["pred_dec"]
    row.pred_ra = data["pred_ra"]
    row._psi_curve = data["psi_curve"]
    row._valid_indices = data["valid_indices"]

    # Stamp columns are optional; they may have been dropped to save space.
    row.all_stamps = data["all_stamps"] if "all_stamps" in data.columns else None
    row.stamp = data["stamp"] if "stamp" in data.columns else None

    return row

def __eq__(self, other):
    """Test if two result rows are equal."""
    if not isinstance(other, ResultRow):
        return False

    # Compare every attribute of the trajectories first.
    trj_a = self.trajectory
    trj_b = other.trajectory
    for attr in ("x", "y", "vx", "vy", "lh", "flux", "obs_count"):
        if getattr(trj_a, attr) != getattr(trj_b, attr):
            return False

    # Compare the simple scalar attributes.
    if self._num_times != other._num_times:
        return False
    if self._final_likelihood != other._final_likelihood:
        return False

    # Compare the (optional) curves, stamps, indices, and predicted positions.
    optional_pairs = (
        (self.all_stamps, other.all_stamps),
        (self._phi_curve, other._phi_curve),
        (self._psi_curve, other._psi_curve),
        (self.stamp, other.stamp),
        (self._valid_indices, other._valid_indices),
        (self.pred_dec, other.pred_dec),
        (self.pred_ra, other.pred_ra),
    )
    return all(_check_optional_allclose(a, b) for a, b in optional_pairs)

@property
def final_likelihood(self):
return self._final_likelihood
Expand Down Expand Up @@ -399,6 +523,59 @@ def from_yaml(cls, yaml_str):
result_list.filtered[key] = [ResultRow.from_yaml(row) for row in yaml_dict["filtered"][key]]
return result_list

@classmethod
def from_table(cls, data, all_times=None, track_filtered=False):
    """Extract the ResultList from an astropy Table.

    Parameters
    ----------
    data : `astropy.table.Table`
        The input data.
    all_times : `List` or `numpy.ndarray` or `None`
        The list of all time stamps. Must either be set or there
        must be an all_times column in the Table.
    track_filtered : `bool`
        Indicates whether the ResultList should track future filtered points.

    Returns
    -------
    result_list : `ResultList`
        The reconstructed list of results.

    Raises
    ------
    KeyError if any columns are missing or if ``all_times`` is None and there
    is no all_times column in the data.
    """
    # Check that we have some list of time stamps and place it in all_times.
    if all_times is None:
        if "all_times" not in data.columns:
            raise KeyError("No time stamps provided.")
        # Every row carries the same full list of times, so take the first.
        all_times = data["all_times"][0]
    num_times = len(all_times)

    # Use cls (not ResultList) so subclasses construct correctly.
    result_list = cls(all_times, track_filtered)
    for i in range(len(data)):
        result_list.append_result(ResultRow.from_table_row(data[i], num_times))
    return result_list

@classmethod
def read_table(cls, filename):
    """Read the ResultList from a table file.

    Parameters
    ----------
    filename : `str`
        The name of the file to load.

    Returns
    -------
    result_list : `ResultList`
        The loaded list of results.

    Raises
    ------
    FileNotFoundError if the file is not found.
    KeyError if any of the columns are missing.
    """
    if not Path(filename).is_file():
        # Include the offending path in the error for easier debugging.
        raise FileNotFoundError(f"Result table file not found: {filename}")
    data = Table.read(filename)
    return cls.from_table(data)

def num_results(self):
"""Return the number of results in the list.
Expand All @@ -413,6 +590,37 @@ def __len__(self):
"""Return the number of results in the list."""
return len(self.results)

def __eq__(self, other):
    """Test if two ResultLists are equal. Includes both ordering and values."""
    if not isinstance(other, ResultList):
        return False
    if not np.allclose(self._all_times, other._all_times):
        return False
    if self.track_filtered != other.track_filtered:
        return False

    # Compare the (ordered) unfiltered results.
    if len(self.results) != len(other.results):
        return False
    for row_self, row_other in zip(self.results, other.results):
        if row_self != row_other:
            return False

    # Compare the filtered rows stored under each label.
    if len(self.filtered) != len(other.filtered):
        return False
    for key in self.filtered:
        if key not in other.filtered:
            return False

        # Bug fix: compare against the other list's rows for THIS key,
        # not against its total number of filter labels.
        if len(self.filtered[key]) != len(other.filtered[key]):
            return False
        for row_self, row_other in zip(self.filtered[key], other.filtered[key]):
            if row_self != row_other:
                return False

    return True

def clear(self):
"""Clear the list of results."""
self.results.clear()
Expand Down Expand Up @@ -648,14 +856,16 @@ def revert_filter(self, label=None):

return self

def to_table(self, filtered_label=None):
def to_table(self, filtered_label=None, append_times=False):
"""Extract the results into an astropy table.
Parameters
----------
filtered_label : `str`, optional
The filtering label to extract. If None then extracts
the unfiltered rows. (default=None)
append_times : `bool`
Append the list of all times as a column in the data.
Returns
-------
Expand Down Expand Up @@ -690,12 +900,34 @@ def to_table(self, filtered_label=None):
"pred_ra": [],
"pred_dec": [],
}
if append_times:
table_dict["all_times"] = []

# Use a (slow) linear scan to do the transformation.
for row in list_ref:
row.append_to_dict(table_dict, True)
if append_times:
table_dict["all_times"].append(self._all_times)

return Table(table_dict)

def write_table(self, filename, overwrite=True):
    """Write the unfiltered results to a single (ecsv) file.

    Parameters
    ----------
    filename : `str`
        The name of the result file.
    overwrite : `bool`
        Overwrite the file if it already exists.
    """
    table_version = self.to_table(append_times=True)

    # Drop the all stamps column as this is often too large to write in a CSV entry.
    table_version.remove_column("all_stamps")

    # Bug fix: honor the caller's overwrite flag instead of always overwriting.
    table_version.write(filename, overwrite=overwrite)

def to_yaml(self, serialize_filtered=False):
"""Serialize the ResultList as a YAML string.
Expand Down
4 changes: 3 additions & 1 deletion src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,13 @@ def run_search(self, config, stack):

# Save the results and the configuration information used.
print(f"Found {keep.num_results()} potential trajectories.")
if config["res_filepath"] is not None:
if config["res_filepath"] is not None and config["ind_output_files"]:
keep.save_to_files(config["res_filepath"], config["output_suffix"])

config_filename = os.path.join(config["res_filepath"], f"config_{config['output_suffix']}.yml")
config.to_file(config_filename, overwrite=True)
if config["result_filename"] is not None:
keep.write_table(config["result_filename"])

end = time.time()
print("Time taken for patch: ", end - start)
Expand Down
Loading

0 comments on commit 86a4456

Please sign in to comment.