Move sigmaG filtering to its own file #394

Merged
merged 4 commits on Dec 1, 2023
140 changes: 12 additions & 128 deletions src/kbmod/analysis_utils.py
@@ -9,7 +9,8 @@

from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import *
from .filters.stats_filters import LHFilter, NumObsFilter
from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from .result_list import ResultList, ResultRow


@@ -67,6 +68,13 @@ def load_and_filter_results(
res_num = 0
total_count = 0

# Set up the clipped sigmaG filter.
if self.sigmaG_lims is not None:
bnds = self.sigmaG_lims
else:
bnds = [25, 75]
clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
@@ -97,11 +105,12 @@ def load_and_filter_results(
batch_size = result_batch.num_results()
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
self.apply_clipped_sigmaG(result_batch)
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(NumObsFilter(3))

# Apply the likelihood filter if one is provided.
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(NumObsFilter(3))

# Add the results to the final set.
keep.extend(result_batch)
@@ -132,109 +141,6 @@ def get_all_stamps(self, result_list, search, stamp_radius):
# ref to its private field. That's a fix for another time.
row.all_stamps = np.array([stamp.image for stamp in stamps])

def apply_clipped_sigmaG(self, result_list):
"""This function applies a clipped median filter to the results of a KBMOD
search using sigmaG as a robust estimator of standard deviation.

Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly
by the filtering.
"""
print("Applying Clipped-sigmaG Filtering")
start_time = time.time()

# Compute the coefficients for the filtering.
if self.coeff is None:
if self.sigmaG_lims is not None:
self.percentiles = self.sigmaG_lims
else:
self.percentiles = [25, 75]
self.coeff = find_sigmaG_coeff(self.percentiles)

if self.num_cores > 1:
zipped_curves = result_list.zip_phi_psi_idx()

keep_idx_results = []
print("Starting pooling...")
pool = mp.Pool(processes=self.num_cores)
keep_idx_results = pool.starmap_async(self._clipped_sigmaG, zipped_curves)
pool.close()
pool.join()
keep_idx_results = keep_idx_results.get()

for i, res in enumerate(keep_idx_results):
result_list.results[i].filter_indices(res[1])
else:
for i, row in enumerate(result_list.results):
single_res = self._clipped_sigmaG(row.psi_curve, row.phi_curve, i)
row.filter_indices(single_res[1])

end_time = time.time()
time_elapsed = end_time - start_time
print("{:.2f}s elapsed".format(time_elapsed))
print("Completed filtering.", flush=True)
print("---------------------------------------")

def _clipped_sigmaG(self, psi_curve, phi_curve, index, n_sigma=2):
"""This function applies a clipped median filter to a set of likelihood
values. Points are eliminated if they are more than n_sigma*sigmaG away
from the median.

Parameters
----------
psi_curve : numpy array
A single Psi curve, likely from a `ResultRow`.
phi_curve : numpy array
A single Phi curve, likely from a `ResultRow`.
index : int
The index of the ResultRow being processed. Used to track
multiprocessing.
n_sigma : int
The number of standard deviations away from the median that
the largest likelihood values (N=num_clipped) must be in order
to be eliminated.

Returns
-------
index : int
The index of the ResultRow being processed. Used to track multiprocessing.
good_index: numpy array
The indices that pass the filtering for a given set of curves.
new_lh : float
The new maximum likelihood of the set of curves, after max_lh_index has
been applied.
"""
masked_phi = np.copy(phi_curve)
masked_phi[masked_phi == 0] = 1e9

lh = psi_curve / np.sqrt(masked_phi)
good_index = self._exclude_outliers(lh, n_sigma)
if len(good_index) == 0:
new_lh = 0
good_index = []
else:
new_lh = kb.calculate_likelihood_psi_phi(psi_curve[good_index], phi_curve[good_index])
return (index, good_index, new_lh)

def _exclude_outliers(self, lh, n_sigma):
if self.clip_negative:
lower_per, median, upper_per = np.percentile(
lh[lh > 0], [self.percentiles[0], 50, self.percentiles[1]]
)
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(
np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))
)[0]
else:
lower_per, median, upper_per = np.percentile(lh, [self.percentiles[0], 50, self.percentiles[1]])
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]
return good_index

def apply_stamp_filter(
self,
result_list,
@@ -382,25 +288,3 @@ def apply_clustering(self, result_list, cluster_params):
cluster_params["mjd"],
)
result_list.apply_batch_filter(f)


# Additional math utilities -----------


def invert_Gaussian_CDF(z):
if z < 0.5:
sign = -1
else:
sign = 1
x = sign * np.sqrt(2) * erfinv(sign * (2 * z - 1)) # mpmath.erfinv(sign * (2 * z - 1))
return float(x)


def find_sigmaG_coeff(percentiles):
z1 = percentiles[0] / 100
z2 = percentiles[1] / 100

x1 = invert_Gaussian_CDF(z1)
x2 = invert_Gaussian_CDF(z2)
coeff = 1 / (x2 - x1)
return coeff
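
For reference, here is a minimal sketch of the call pattern this PR moves load_and_filter_results to (the standalone filter_batch wrapper and its default lh_level are hypothetical; SigmaGClipping, apply_clipped_sigma_g, LHFilter, NumObsFilter, and the ResultList filter methods come from the diff above):

from kbmod.filters.sigma_g_filter import SigmaGClipping, apply_clipped_sigma_g
from kbmod.filters.stats_filters import LHFilter, NumObsFilter


def filter_batch(result_batch, sigmaG_lims=None, clip_negative=False, num_cores=1, lh_level=10.0):
    """Apply the same filter sequence as the refactored load_and_filter_results
    to a single ResultList batch. (Hypothetical helper for illustration.)"""
    # Fall back to the [25, 75] percentile bounds when no limits are given.
    bnds = sigmaG_lims if sigmaG_lims is not None else [25, 75]
    clipper = SigmaGClipping(bnds[0], bnds[1], 2, clip_negative)

    # SigmaG clipping first, then a minimum-observation cut, then an optional
    # likelihood cut -- the same order as in the diff above.
    apply_clipped_sigma_g(clipper, result_batch, num_cores)
    result_batch.apply_filter(NumObsFilter(3))
    if lh_level > 0.0:
        result_batch.apply_filter(LHFilter(lh_level, None))
    return result_batch
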
123 changes: 123 additions & 0 deletions src/kbmod/filters/sigma_g_filter.py
@@ -0,0 +1,123 @@
"""Functions to help with the SigmaG clipping.

For more details see:
Sifting Through the Static: Moving Object Detection in Difference Images
by Smotherman et al. 2021
"""

import multiprocessing as mp
import numpy as np
from scipy.special import erfinv

from kbmod.result_list import ResultList, ResultRow


class SigmaGClipping:
"""This class contains the basic information for performing SigmaG clipping.

Attributes
----------
low_bnd : `float`
The lower bound of the interval to use to estimate the standard deviation.
high_bnd : `float`
The upper bound of the interval to use to estimate the standard deviation.
n_sigma : `float`
The number of standard deviations to use for the bound.
clip_negative : `bool`
A Boolean indicating whether to use negative values when computing
standard deviation.
coeff : `float`
The precomputed coefficient based on the given bounds.
"""

def __init__(self, low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False):
if low_bnd > high_bnd or low_bnd <= 0.0 or high_bnd >= 100.0:
raise ValueError(f"Invalid bounds [{low_bnd}, {high_bnd}]")
if n_sigma <= 0.0:
raise ValueError(f"Invalid n_sigma {n_sigma}")

self.low_bnd = low_bnd
self.high_bnd = high_bnd
self.clip_negative = clip_negative
self.n_sigma = n_sigma
self.coeff = SigmaGClipping.find_sigma_g_coeff(low_bnd, high_bnd)

@staticmethod
def find_sigma_g_coeff(low_bnd, high_bnd):
x1 = SigmaGClipping.invert_gauss_cdf(low_bnd / 100.0)
x2 = SigmaGClipping.invert_gauss_cdf(high_bnd / 100.0)
return 1 / (x2 - x1)

@staticmethod
def invert_gauss_cdf(z):
if z < 0.5:
sign = -1
else:
sign = 1
x = sign * np.sqrt(2) * erfinv(sign * (2 * z - 1))
return float(x)

def compute_clipped_sigma_g(self, lh):
"""Compute the SigmaG clipping on the given likelihood curve.
Points are eliminated if they are more than n_sigma*sigmaG away from the median.

Parameters
----------
lh : numpy array
A single likelihood curve.

Returns
-------
good_index: numpy array
The indices that pass the filtering for a given set of curves.
"""
if self.clip_negative:
lower_per, median, upper_per = np.percentile(lh[lh > 0], [self.low_bnd, 50, self.high_bnd])
else:
lower_per, median, upper_per = np.percentile(lh, [self.low_bnd, 50, self.high_bnd])

delta = max(upper_per - lower_per, 1e-8)
sigmaG = self.coeff * delta
nSigmaG = self.n_sigma * sigmaG

# It's unclear why we only filter zeros in one of the two cases, but we keep the logic
# to stay consistent with the original code.
if self.clip_negative:
good_index = np.where(
np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))
)[0]
else:
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]

return good_index


def apply_clipped_sigma_g(params, result_list, num_threads=1):
"""This function applies a clipped median filter to the results of a KBMOD
search using sigmaG as a robust estimator of standard deviation.

Parameters
----------
params : `SigmaGClipping`
The object to apply the SigmaG clipping.
result_list : `ResultList`
The values from trajectories. This data gets modified directly by the filtering.
num_threads : `int`
The number of threads to use.
"""
if num_threads > 1:
lh_list = [[row.likelihood_curve] for row in result_list.results]

keep_idx_results = []
pool = mp.Pool(processes=num_threads)
keep_idx_results = pool.starmap_async(params.compute_clipped_sigma_g, lh_list)
pool.close()
pool.join()
keep_idx_results = keep_idx_results.get()

for i, res in enumerate(keep_idx_results):
result_list.results[i].filter_indices(res)
else:
for i, row in enumerate(result_list.results):
single_res = params.compute_clipped_sigma_g(row.likelihood_curve)
row.filter_indices(single_res)
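
A short usage sketch of the new class on a toy likelihood curve (the array values are made up; the class, its constructor arguments, and compute_clipped_sigma_g come from the file above):

import numpy as np

from kbmod.filters.sigma_g_filter import SigmaGClipping

clipper = SigmaGClipping(low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False)

# For the default [25, 75] bounds the precomputed coefficient is
# 1 / (Phi^-1(0.75) - Phi^-1(0.25)) ~= 0.7413, the usual IQR-to-sigma factor.
print(clipper.coeff)

# Most points sit near 10; the outlier at 80 is more than 2 * sigmaG from the
# median, so its index should be absent from the returned list.
lh = np.array([9.5, 10.0, 10.5, 9.8, 80.0, 10.2])
good_index = clipper.compute_clipped_sigma_g(lh)
print(good_index)  # expected: [0 1 2 3 5]
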
45 changes: 21 additions & 24 deletions src/kbmod/result_list.py
@@ -1,8 +1,7 @@
import math
import multiprocessing as mp
import os.path as ospath

import numpy as np
import os.path as ospath

from kbmod.file_utils import *

@@ -52,18 +51,16 @@ def light_curve(self):

Returns
-------
lc : list
The likelihood curve. This is an empty list if either
lc : `numpy.ndarray`
The light curve. This is an empty array if either
psi or phi are not set.
"""
if self.psi_curve is None or self.phi_curve is None:
return []
return np.array([])

num_elements = len(self.psi_curve)
lc = [0.0] * num_elements
for i in range(num_elements):
if self.phi_curve[i] != 0.0:
lc[i] = self.psi_curve[i] / self.phi_curve[i]
masked_phi = np.copy(self.phi_curve)
masked_phi[masked_phi == 0] = 1e12
lc = np.divide(self.psi_curve, masked_phi)
return lc

@property
@@ -72,20 +69,16 @@ def likelihood_curve(self):

Returns
-------
lh : list
The likelihood curve. This is an empty list if either
lh : `numpy.ndarray`
The likelihood curve. This is an empty array if either
psi or phi are not set.
"""
if self.psi_curve is None:
raise ValueError("Psi curve is None")
if self.phi_curve is None:
raise ValueError("Phi curve is None")

num_elements = len(self.psi_curve)
lh = [0.0] * num_elements
for i in range(num_elements):
if self.phi_curve[i] > 0.0:
lh[i] = self.psi_curve[i] / math.sqrt(self.phi_curve[i])
if self.psi_curve is None or self.phi_curve is None:
return np.array([])

masked_phi = np.copy(self.phi_curve)
masked_phi[masked_phi == 0] = 1e12
lh = np.divide(self.psi_curve, np.sqrt(masked_phi))
return lh

def valid_indices_as_booleans(self):
@@ -172,8 +165,8 @@ def _update_likelihood(self):
self.trajectory.lh = 0.0
self.trajectory.flux = 0.0
else:
self.final_likelihood = psi_sum / math.sqrt(phi_sum)
self.trajectory.lh = psi_sum / math.sqrt(phi_sum)
self.final_likelihood = psi_sum / np.sqrt(phi_sum)
self.trajectory.lh = psi_sum / np.sqrt(phi_sum)
self.trajectory.flux = psi_sum / phi_sum


@@ -208,6 +201,10 @@ def num_results(self):
"""
return len(self.results)

def __len__(self):
"""Return the number of results in the list."""
return len(self.results)

def clear(self):
"""Clear the list of results."""
self.results.clear()
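
Finally, a toy illustration of the vectorized curve computation that light_curve and likelihood_curve above now share (the psi/phi values are hypothetical): zero phi entries are masked with a large constant so the division stays finite and those curve entries come out effectively zero.

import numpy as np

psi_curve = np.array([4.0, 9.0, 2.0])
phi_curve = np.array([2.0, 0.0, 4.0])

# Mask zero phi values the same way light_curve and likelihood_curve do.
masked_phi = np.copy(phi_curve)
masked_phi[masked_phi == 0] = 1e12

light_curve = psi_curve / masked_phi                 # ~[2.0, 0.0, 0.5]
likelihood_curve = psi_curve / np.sqrt(masked_phi)   # ~[2.83, 0.0, 1.0]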