Skip to content

Commit

Permalink
Reflection filter update (#1043)
Browse files Browse the repository at this point in the history
* Implementation of ML reflection filters

* Added random seed to IsolationForest model for reproducibility

* Addtional labeling for the diagnostic plots.

* Consolidated removed refl/expt counts

* Bug fix

* Added reflection intensity diagnostic plots

* Clean clutter

* Added Python 2.9 syntax exception

* Forgot to include the last commit in the pull request

* Implementation of ML reflection filters

* Added random seed to IsolationForest model for reproducibility

* Consolidated removed refl/expt counts

* Bug fix

* Forgot to include the last commit in the pull request
  • Loading branch information
dwmoreau authored Feb 6, 2025
1 parent b909d12 commit cd8183f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 69 deletions.
110 changes: 55 additions & 55 deletions xfel/merging/application/filter/reflection_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def __repr__(self):
def validate(self):
filter_by_significance = 'significance_filter' in self.params.select.algorithm
filter_by_isolation_forest = 'isolation_forest' in self.params.select.algorithm
filter_by_local_outlier_factor = 'local_outlier_factor' in self.params.select.algorithm
if filter_by_isolation_forest or filter_by_local_outlier_factor:
if filter_by_isolation_forest:
check0 = self.params.select.reflection_filter.tail_percentile > 0
check1 = self.params.select.reflection_filter.tail_percentile < 1
assert check0 and check1, \
Expand All @@ -36,16 +35,18 @@ def validate(self):
assert check0 and check1 and check2 and check3, \
'contamination must be between 0 and 1'
if filter_by_isolation_forest:
check0 = self.params.select.reflection_filter.isolation_forest.sampling_fraction > 0
check1 = self.params.select.reflection_filter.isolation_forest.sampling_fraction < 1
check0 = self.params.select.reflection_filter.sampling_fraction > 0
check1 = self.params.select.reflection_filter.sampling_fraction < 1
assert check0 and check1, \
'sampling_fraction must be between 0 and 1'
if filter_by_isolation_forest and filter_by_local_outlier_factor:
assert False, \
'Please only select one algorithm for outlier removal'

def plot_reflections(self, experiments, reflections, tag):
correct_info = 'miller_index' in reflections.keys()
if correct_info == False:
reflections['miller_index'] = reflections['miller_index_asymmetric']
q2_rank = 1 / reflections.compute_d(experiments).as_numpy_array()**2
if correct_info == False:
del reflections['miller_index']
intensity_rank = reflections['intensity.sum.value'].as_numpy_array()
q2 = self.mpi_helper.comm.gather(q2_rank, root=0)
intensity = self.mpi_helper.comm.gather(intensity_rank, root=0)
Expand All @@ -56,9 +57,14 @@ def plot_reflections(self, experiments, reflections, tag):
intensity = np.concatenate(intensity)
fig, axes = plt.subplots(1, 1, figsize=(8, 3))
axes.scatter(q2, intensity, s=1, color=[0, 0, 0], marker='.')
axes.set_xlabel('$q^2$ = 1/$d^2$ (1/$\mathrm{\AA^2}$)')
axes.set_ylabel('Intensity')
axes.set_title(tag)
axes.set_xlabel('Resolution ($\mathrm{\AA}$)')
xticks = axes.get_xticks()
xticks = xticks[xticks > 0]
xticklabels = [f'{l:0.2f}' for l in 1 / np.sqrt(xticks)]
axes.set_xticks(xticks)
axes.set_xticklabels(xticklabels)
fig.tight_layout()
fig.savefig(os.path.join(
self.params.output.output_dir,
Expand All @@ -69,9 +75,8 @@ def plot_reflections(self, experiments, reflections, tag):
def run(self, experiments, reflections):
filter_by_significance = 'significance_filter' in self.params.select.algorithm
filter_by_isolation_forest = 'isolation_forest' in self.params.select.algorithm
filter_by_local_outlier_factor = 'local_outlier_factor' in self.params.select.algorithm
# only "unit_cell" "n_obs" and "resolution" algorithms are supported
if (not filter_by_significance) and (not filter_by_isolation_forest) and (not filter_by_local_outlier_factor):
if (not filter_by_significance) and (not filter_by_isolation_forest):
return experiments, reflections

n_reflections_initial = len(reflections)
Expand Down Expand Up @@ -100,13 +105,8 @@ def run(self, experiments, reflections):
filter_type = 'Isolation Forest'
if self.params.select.reflection_filter.do_diagnostics:
self.plot_reflections(experiments, reflections, 'After Isolation Forest')
elif filter_by_local_outlier_factor:
experiments, reflections = self.apply_local_outlier_factor(experiments, reflections)
filter_type = 'Local Outlier Factor'
if self.params.select.reflection_filter.do_diagnostics:
self.plot_reflections(experiments, reflections, 'After Local Outlier Factor')

if filter_by_isolation_forest or filter_by_local_outlier_factor:
if filter_by_isolation_forest:
removed_reflections_filter_rank = n_reflections_initial - len(reflections)
removed_experiments_filter_rank = n_experiments_initial - len(experiments)
if filter_by_significance:
Expand Down Expand Up @@ -234,7 +234,12 @@ def apply_significance_filter(self, experiments, reflections):
return new_experiments, new_reflections

def _common_initial(self, experiments, reflections):
correct_info = 'miller_index' in reflections.keys()
if correct_info == False:
reflections['miller_index'] = reflections['miller_index_asymmetric']
resolution = reflections.compute_d(experiments)
if correct_info == False:
del reflections['miller_index']
reflections['q2'] = 1 / resolution**2
q2_rank = reflections['q2'].as_numpy_array()
intensity_rank = reflections['intensity.sum.value'].as_numpy_array()
Expand Down Expand Up @@ -345,7 +350,7 @@ def _common_initial(self, experiments, reflections):
return reflections, upper_tail, lower_tail

def do_diagnostics(self, reflections, model_upper, upper_tail, model_lower, lower_tail):
def plot_outliers(I_normalized, q2, Y, tag):
def plot_outliers(I_normalized, q2, Y, tag, model):
inlier_indices = Y == 1
outlier_indices = Y == -1
fig, axes = plt.subplots(1, 1, figsize=(8, 3), sharex=True)
Expand All @@ -357,9 +362,31 @@ def plot_outliers(I_normalized, q2, Y, tag):
q2[outlier_indices], I_normalized[outlier_indices],
s=20, color=[0.8, 0, 0], marker='.', alpha=1, label='Outliers'
)
axes.set_xlabel('$q^2$ = 1/$d^2$ (1/$\mathrm{\AA^2}$)')
xx, yy = np.meshgrid(
np.linspace(I_normalized.min(), I_normalized.max(), 150),
np.linspace(q2.min(), q2.max(), 150)
)
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# https://github.com/matplotlib/matplotlib/issues/23303
# The label for the contour does not appear in the legend. Must add manually.
contour = axes.contour(yy, xx, Z, levels=[0], linewidths=2, colors="green")
contour_handle, _ = contour.legend_elements()
handles, labels = axes.get_legend_handles_labels()
handles += contour_handle
labels += ['Decision Boundary']
axes.set_xlabel('Resolution ($\mathrm{\AA}$)')
xticks = axes.get_xticks()
xticks = xticks[xticks > 0]
xticklabels = [f'{l:0.2f}' for l in 1 / np.sqrt(xticks)]
axes.set_xticks(xticks)
axes.set_xticklabels(xticklabels)
axes.set_ylabel('Normalized Intensity')
axes.legend(frameon=False, loc='upper right')
if tag == 'upper':
loc = 'upper right'
elif tag == 'lower':
loc = 'lower right'
axes.legend(handles, labels, loc=loc, frameon=False)
fig.tight_layout()
fig.savefig(os.path.join(
self.params.output.output_dir,
Expand All @@ -374,20 +401,20 @@ def plot_outliers(I_normalized, q2, Y, tag):
if self.mpi_helper.rank == 0:
import matplotlib.pyplot as plt
import os
plot_outliers(upper_tail[:, 0], upper_tail[:, 1], model_upper.predict(upper_tail), 'upper')
plot_outliers(lower_tail[:, 0], lower_tail[:, 1], model_lower.predict(lower_tail), 'lower')
plot_outliers(upper_tail[:, 0], upper_tail[:, 1], model_upper.predict(upper_tail), 'upper', model_upper)
plot_outliers(lower_tail[:, 0], lower_tail[:, 1], model_lower.predict(lower_tail), 'lower', model_lower)

fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)
axes[0].scatter(
np.concatenate(q2), np.concatenate(intensity_normalized),
s=1, color=[0, 0, 0], marker='.', alpha=0.5
)
axes[1].scatter(
lower_tail[:, 1], lower_tail[:, 0],
upper_tail[:, 1], upper_tail[:, 0],
s=1, color=[0, 0, 0], marker='.', alpha=0.5
)
axes[2].scatter(
upper_tail[:, 1], upper_tail[:, 0],
lower_tail[:, 1], lower_tail[:, 0],
s=1, color=[0, 0, 0], marker='.', alpha=0.5
)
axes[2].set_xlabel('$q^2$ = 1/$d^2$ (1/$\mathrm{\AA^2}$)')
Expand All @@ -406,19 +433,21 @@ def apply_isolation_forest(self, experiments, reflections):

reflections, upper_tail, lower_tail = self._common_initial(experiments, reflections)
if self.mpi_helper.rank == 0:
sampling_fraction = self.params.select.reflection_filter.isolation_forest.sampling_fraction
sampling_fraction = self.params.select.reflection_filter.sampling_fraction
model_lower = IsolationForest(
n_estimators=self.params.select.reflection_filter.n_estimators,
contamination=self.params.select.reflection_filter.contamination_lower,
max_features=2,
max_samples=int(sampling_fraction*lower_tail.shape[0]),
random_state=self.params.select.reflection_filter.isolation_forest.random_seed
random_state=self.params.select.reflection_filter.random_seed
)
model_lower.fit(lower_tail)
model_upper = IsolationForest(
n_estimators=self.params.select.reflection_filter.n_estimators,
contamination=self.params.select.reflection_filter.contamination_upper,
max_features=2,
max_samples=int(sampling_fraction*upper_tail.shape[0]),
random_state=self.params.select.reflection_filter.isolation_forest.random_seed
random_state=self.params.select.reflection_filter.random_seed
)
model_upper.fit(upper_tail)
else:
Expand All @@ -433,35 +462,6 @@ def apply_isolation_forest(self, experiments, reflections):
self.logger.log_step_time("ISOLATION_FOREST", True)
return new_experiments, new_reflections

def apply_local_outlier_factor(self, experiments, reflections):
self.logger.log_step_time("LOCAL_OUTLIER_FACTOR")
from sklearn.neighbors import LocalOutlierFactor

reflections, upper_tail, lower_tail = self._common_initial(experiments, reflections)
if self.mpi_helper.rank == 0:
model_lower = LocalOutlierFactor(
contamination=self.params.select.reflection_filter.contamination_lower,
n_neighbors=self.params.select.reflection_filter.local_outlier_factor.n_neighbors,
novelty=True,
)
model_lower.fit(lower_tail)
model_upper = LocalOutlierFactor(
contamination=self.params.select.reflection_filter.contamination_upper,
n_neighbors=self.params.select.reflection_filter.local_outlier_factor.n_neighbors,
novelty=True
)
model_upper.fit(upper_tail)
else:
model_lower = None
model_upper = None
if self.params.select.reflection_filter.do_diagnostics:
self.do_diagnostics(reflections, model_upper, upper_tail, model_lower, lower_tail)
new_experiments, new_reflections = self._common_final(
experiments, reflections, model_lower, model_upper, 'local outlier factor'
)
self.logger.log_step_time("LOCAL_OUTLIER_FACTOR", True)
return new_experiments, new_reflections

def _common_final(self, experiments, reflections, model_lower, model_upper, filter_type):
model_lower = self.mpi_helper.comm.bcast(model_lower, root=0)
model_upper = self.mpi_helper.comm.bcast(model_upper, root=0)
Expand Down
24 changes: 10 additions & 14 deletions xfel/merging/application/phil/phil.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
.help = The select section accepts or rejects specified reflections
.help = refer to the filter section for filtering of whole experiments
{
algorithm = panel cspad_sensor significance_filter isolation_forest local_outlier_factor
algorithm = panel cspad_sensor significance_filter isolation_forest
.type = choice(multi=True)
cspad_sensor {
number = None
Expand Down Expand Up @@ -375,19 +375,15 @@
contamination_upper = 0.0001
.type = float
.help = Fraction of upper tail reflections that are outliers
local_outlier_factor {
n_neighbors = 200
.type = int
.help = Number of neighbors used to determine local density.
}
isolation_forest {
sampling_fraction = 0.05
.type = float
.help = Fraction of total dataset subsampled to train each decision tree.
random_seed = 0
.type = int
.help = seed for the random forest model
}
n_estimators = 1000
.type = int
.help = Number of decision trees in random forest model
sampling_fraction = 0.1
.type = float
.help = Fraction of total dataset subsampled to train each decision tree.
random_seed = 0
.type = int
.help = seed for the random forest model
}
}
"""
Expand Down

0 comments on commit cd8183f

Please sign in to comment.