diff --git a/xfel/merging/application/errors/error_modifier_mm24.py b/xfel/merging/application/errors/error_modifier_mm24.py index f2ffa2b417..52d8825ff3 100644 --- a/xfel/merging/application/errors/error_modifier_mm24.py +++ b/xfel/merging/application/errors/error_modifier_mm24.py @@ -10,6 +10,7 @@ from xfel.merging.application.worker import worker from xfel.merging.application.reflection_table_utils import reflection_table_utils + class error_modifier_mm24(worker): def __init__(self, params, mpi_helper=None, mpi_logger=None): super(error_modifier_mm24, self).__init__(params=params, mpi_helper=mpi_helper, mpi_logger=mpi_logger) @@ -24,10 +25,13 @@ def __init__(self, params, mpi_helper=None, mpi_logger=None): self.tuning_param = self.params.merging.error.mm24.tuning_param self.number_of_intensity_bins = self.params.merging.error.mm24.number_of_intensity_bins - if self.params.merging.error.mm24.cc_after_pr: - self.cc_key = 'correlation_after_post' + if self.params.merging.error.mm24.constant_sadd: + self.cc_key = None else: - self.cc_key = 'correlation' + if self.params.merging.error.mm24.cc_after_pr: + self.cc_key = 'correlation_after_post' + else: + self.cc_key = 'correlation' def __repr__(self): return 'Adjust intensity errors -- mm24' @@ -54,12 +58,17 @@ def modify_errors(self, reflections): self.run_minimizer() if self.params.merging.error.mm24.do_diagnostics: + self.verify_derivatives() self.plot_diagnostics(reflections) # Finally update the variances of each reflection as per Eq (10) in Brewster et. al (2019) + if self.cc_key: + correlation = reflections[self.cc_key] + else: + correlation = None reflections['intensity.sum.variance'] = self._get_var_mm24( reflections['intensity.sum.variance'], reflections['biased_mean'], - reflections[self.cc_key] + correlation ) del reflections['biased_mean'] return reflections @@ -77,8 +86,9 @@ def pairing(k1, k2): pairwise_differences_normalized = flex.double() counting_stats_var_i = flex.double() counting_stats_var_j = flex.double() - correlation_i = flex.double() - correlation_j = flex.double() + if self.cc_key: + correlation_i = flex.double() + correlation_j = flex.double() number_of_reflections = 0 for refls in reflection_table_utils.get_next_hkl_reflection_table(reflections): @@ -92,7 +102,8 @@ def pairing(k1, k2): if number_of_measurements > self.params.merging.minimum_multiplicity: I = refls['intensity.sum.value'].as_numpy_array() var_cs = refls['intensity.sum.variance'].as_numpy_array() - correlation = refls[self.cc_key].as_numpy_array() + if self.cc_key: + correlation = refls[self.cc_key].as_numpy_array() number_of_reflections += I.size self.biased_mean_count.extend(flex.double(I.size, refls_biased_mean[0])) indices = np.triu_indices(n=I.size, k=1) @@ -111,7 +122,8 @@ def pairing(k1, k2): rng.shuffle(sort_indices) I = I[sort_indices] var_cs = var_cs[sort_indices] - correlation = correlation[sort_indices] + if self.cc_key: + correlation = correlation[sort_indices] # this option is for performance trade-offs if N > 1000: subset_indices = rng.choice( @@ -129,15 +141,17 @@ def pairing(k1, k2): biased_mean.extend(flex.double(N, refls_biased_mean[0])) counting_stats_var_i.extend(flex.double(var_cs[indices[0]])) counting_stats_var_j.extend(flex.double(var_cs[indices[1]])) - correlation_i.extend(flex.double(correlation[indices[0]])) - correlation_j.extend(flex.double(correlation[indices[1]])) + if self.cc_key: + correlation_i.extend(flex.double(correlation[indices[0]])) + correlation_j.extend(flex.double(correlation[indices[1]])) self.work_table['pairwise_differences'] = pairwise_differences self.work_table['biased_mean'] = biased_mean self.work_table['counting_stats_var_i'] = counting_stats_var_i self.work_table['counting_stats_var_j'] = counting_stats_var_j - self.work_table['correlation_i'] = correlation_i - self.work_table['correlation_j'] = correlation_j + if self.cc_key: + self.work_table['correlation_i'] = correlation_i + self.work_table['correlation_j'] = correlation_j reflections['biased_mean'] = biased_mean_to_reflections self.logger.log(f"Number of work reflections selected: {number_of_reflections}") @@ -252,9 +266,12 @@ def target_fun_scalar(sadd, bin_centers, mean_differences): bin_centers = bin_centers[positive_indices] mean_differences = mean_differences[positive_indices] - self.sadd = [0, 0.001, 0.001] - self.sadd[1] = 0.001 - self.sadd[2] = 0.001 + if self.cc_key: + self.sadd = [0, 0.001, 0.001] + self.sadd[1] = 0.001 + self.sadd[2] = 0.001 + else: + self.sadd = [0] if self.expected_sf is None: results = scipy.optimize.minimize( target_fun_bfgs, @@ -294,7 +311,10 @@ def target_fun_scalar(sadd, bin_centers, mean_differences): else: self.sfac = 0 - self.sadd = [0, 0, 0] + if self.cc_key: + self.sadd = [0, 0, 0] + else: + self.sadd = [0] self.sfac = self.mpi_helper.comm.bcast(self.sfac, root=0) self.sadd = self.mpi_helper.comm.bcast(self.sadd, root=0) @@ -304,38 +324,37 @@ def run_minimizer(self): comm = self.mpi_helper.comm MPI = self.mpi_helper.MPI size = self.mpi_helper.size + if self.params.merging.error.mm24.tuning_param_opt: - param_offset = 2 param_shift = 1 self.x = flex.double([self.tuning_param, self.sfac, *self.sadd]) else: - param_offset = 1 param_shift = 0 self.x = flex.double([self.sfac, *self.sadd]) - self.n = 3 + param_offset - if self.mpi_helper.rank == 0: - self.logger.main_log( - 'Initial Parameter Estimates = ' - + f'sfac: {self.sfac} ' - + f'sadd: {self.sadd[0]} ' - + f'nu: {self.tuning_param} ' - ) - l = flex.double(self.n, 1e-8) - u = flex.double(self.n, 0) + n_parameters = len(self.x) + l = flex.double(n_parameters, 1e-8) + u = flex.double(n_parameters, 0) if self.params.merging.error.mm24.tuning_param_opt: # normalization for the truncated t-distribution is numerically unstable for nu < 2 - l[0] = 2 - if self.x[0] < 2: - self.x[0] = 2 - for degree_index in range(3): - l[degree_index + param_offset] = -1000 + l[0] = 2.5 + if self.x[0] < 2.5: + self.x[0] = 2.5 + for degree_index in range(param_shift, n_parameters): + l[degree_index] = -1000 if self.mpi_helper.rank == 0: self.minimizer = lbfgsb.minimizer( - n = self.n, + n = n_parameters, l = l, u = u, - nbd = flex.int(self.n, 1), + nbd = flex.int(n_parameters, 1), + ) + if self.mpi_helper.rank == 0: + self.logger.main_log( + 'Initial Parameter Estimates = ' + + f'sfac: {self.sfac} ' + + f'sadd: {self.sadd[0]} ' + + f'nu: {self.tuning_param} ' ) while True: self.compute_functional_and_gradients() @@ -347,12 +366,10 @@ def run_minimizer(self): tuning_param = f'{self.tuning_param:0.3f}' self.sfac = self.x[0 + param_shift] self.sadd = self.x[1 + param_shift:] - sfac = f'{self.sfac:0.3f}' - sadd = [f'{self.sadd[i]:0.3f}' for i in range(3)] log_out = 'intermediate minimization results = '\ + f'loss: {self.L:.2f} '\ - + f'sfac: {sfac} '\ - + f'sadd: {sadd} ' + + f'sfac: {self.sfac:0.3f} '\ + + f'sadd: {[f"{value:0.3f}" for value in self.sadd]} ' if self.params.merging.error.mm24.tuning_param_opt: log_out += f'nu: {tuning_param}' self.logger.main_log(log_out) @@ -372,12 +389,10 @@ def run_minimizer(self): if self.mpi_helper.rank == 0: tuning_param = f'{self.tuning_param:0.3f}' - sfac = f'{self.sfac:0.3f}' - sadd = [f'{self.sadd[i]:0.3f}' for i in range(3)] log_out = 'FINAL mm24 VALUES = '\ + f'loss: {self.L:.2f} '\ - + f'sfac: {sfac} '\ - + f'sadd: {sadd} ' + + f'sfac: {self.sfac:0.3f} '\ + + f'sadd: {[f"{value:0.3f}" for value in self.sadd]} ' if self.params.merging.error.mm24.tuning_param_opt: log_out += f'nu: {tuning_param}' self.logger.main_log(log_out) @@ -393,10 +408,6 @@ def compute_functional_and_gradients(self): def verify_derivatives(self): shift = 0.000001 import copy - if self.params.merging.error.mm24.tuning_param_opt: - self.n = 6 - else: - self.n = 5 self.calculate_functional() sfac = copy.copy(self.sfac) @@ -439,7 +450,7 @@ def verify_derivatives(self): print(f'der_wrt_sfac numerical: {check_der_wrt_sfac} analytical {der_wrt_sfac}') # sadd: - for degree_index in range(3): + for degree_index in range(len(self.sadd)): if sadd[degree_index] == 0: self.sadd[degree_index] = shift else: @@ -529,13 +540,17 @@ def _loss_function_t_v_opt(self, differences, var_i, var_j): return L, dL_dvar, dL_dv def _get_sadd2(self, correlation): - term1 = flex.exp(-self.sadd[1] * correlation) - sadd2 = self.sadd[0]**2 * term1 + self.sadd[2]**2 - dsadd2_dsaddi = [ - 2 * self.sadd[0] * term1, - -correlation * self.sadd[0]**2 * term1, - 2 * self.sadd[2] * flex.double(len(correlation), 1) - ] + if correlation: + term1 = flex.exp(-self.sadd[1] * correlation) + sadd2 = self.sadd[0]**2 * term1 + self.sadd[2]**2 + dsadd2_dsaddi = [ + 2 * self.sadd[0] * term1, + -correlation * self.sadd[0]**2 * term1, + 2 * self.sadd[2] * flex.double(len(correlation), 1) + ] + else: + sadd2 = self.sadd[0]**2 + dsadd2_dsaddi = [2 * self.sadd[0]] return sadd2, dsadd2_dsaddi def _get_var_mm24(self, counting_err, biased_mean, correlation, return_der=False): @@ -553,22 +568,28 @@ def calculate_functional(self): MPI = self.mpi_helper.MPI L_bin_rank = flex.double(self.number_of_intensity_bins, 0) dL_dsfac_bin_rank = flex.double(self.number_of_intensity_bins, 0) - dL_dsadd_bin_rank = [flex.double(self.number_of_intensity_bins, 0) for i in range(3)] + dL_dsadd_bin_rank = [flex.double(self.number_of_intensity_bins, 0) for i in range(len(self.sadd))] if self.params.merging.error.mm24.tuning_param_opt: dL_dnu_bin_rank = flex.double(self.number_of_intensity_bins, 0) for bin_index, differences in enumerate(self.intensity_bins): if len(differences) > 0: + if self.cc_key: + correlation_i = differences['correlation_i'] + correlation_j = differences['correlation_j'] + else: + correlation_i = None + correlation_j = None var_i, dvar_i_dsfac, dvar_i_dsadd2, dsadd2_i_dsaddi = self._get_var_mm24( differences['counting_stats_var_i'], differences['biased_mean'], - differences['correlation_i'], + correlation_i, return_der=True ) var_j, dvar_j_dsfac, dvar_j_dsadd2, dsadd2_j_dsaddi = self._get_var_mm24( differences['counting_stats_var_j'], differences['biased_mean'], - differences['correlation_j'], + correlation_j, return_der=True ) @@ -589,15 +610,15 @@ def calculate_functional(self): L_bin_rank[bin_index] = flex.sum(L_in_bin) dL_dsfac_bin_rank[bin_index] = flex.sum(dL_dvar_x * (dvar_i_dsfac + dvar_j_dsfac)) - for degree_index in range(3): + for degree_index in range(len(self.sadd)): dL_dsadd_bin_rank[degree_index][bin_index] = flex.sum(dL_dvar_x * ( dvar_i_dsadd2 * dsadd2_i_dsaddi[degree_index] + dvar_j_dsadd2 * dsadd2_j_dsaddi[degree_index] )) L_bin = comm.reduce(L_bin_rank, MPI.SUM, root=0) dL_dsfac_bin = comm.reduce(dL_dsfac_bin_rank, MPI.SUM, root=0) - dL_dsadd_bin = [None for i in range(3)] - for degree_index in range(3): + dL_dsadd_bin = [None for i in range(len(self.sadd))] + for degree_index in range(len(self.sadd)): dL_dsadd_bin[degree_index] = comm.reduce(dL_dsadd_bin_rank[degree_index], MPI.SUM, root=0) if self.params.merging.error.mm24.tuning_param_opt: dL_dnu_bin = comm.reduce(dL_dnu_bin_rank, MPI.SUM, root=0) @@ -605,8 +626,8 @@ def calculate_functional(self): if self.mpi_helper.rank == 0: self.L = flex.sum(self.bin_weighting * L_bin) self.dL_dsfac = flex.sum(self.bin_weighting * dL_dsfac_bin) - self.dL_dsadd = [0 for i in range(3)] - for degree_index in range(3): + self.dL_dsadd = [0 for i in range(len(self.sadd))] + for degree_index in range(len(self.sadd)): self.dL_dsadd[degree_index] = flex.sum(self.bin_weighting * dL_dsadd_bin[degree_index]) if self.params.merging.error.mm24.tuning_param_opt: self.dL_dnu = flex.sum(self.bin_weighting * dL_dnu_bin) @@ -659,10 +680,14 @@ def min_fun_t(c, df): bin_index = np.searchsorted(self.intensity_bin_limits, biased_mean) - 1 if biased_mean > self.intensity_bin_limits[0] and biased_mean < self.intensity_bin_limits[-1]: I = refls['intensity.sum.value'].as_numpy_array() + if self.cc_key: + correlation = refls[self.cc_key] + else: + correlation = None var_mm24 = self._get_var_mm24( refls['intensity.sum.variance'], flex.double(len(refls), biased_mean), - refls[self.cc_key] + correlation ).as_numpy_array() # calculate the median difference for the pairwise differences @@ -686,16 +711,22 @@ def min_fun_t(c, df): pairwise_differences = [] for bin_index, differences in enumerate(self.intensity_bins): if len(differences) > 0: + if self.cc_key: + correlation_i = differences['correlation_i'] + correlation_j = differences['correlation_j'] + else: + correlation_i = None + correlation_j = None var_i = self._get_var_mm24( differences['counting_stats_var_i'], differences['biased_mean'], - differences['correlation_i'], + correlation_i, return_der=False ) var_j = self._get_var_mm24( differences['counting_stats_var_j'], differences['biased_mean'], - differences['correlation_j'], + correlation_j, return_der=False ) normalized_differences = differences['pairwise_differences'] / flex.sqrt(var_i + var_j) @@ -785,32 +816,33 @@ def min_fun_t(c, df): plt.close() # Get the correlations for later plotting - cc_all = self.mpi_helper.gather_variable_length_numpy_arrays( - np.unique(self.work_table['correlation_i'].as_numpy_array()), root=0, dtype=float - ) - if self.mpi_helper.rank == 0: - # CC & sadd plots # - bins = np.linspace(cc_all.min(), cc_all.max(), 101) - dbin = bins[1] - bins[0] - centers = (bins[1:] + bins[:-1]) / 2 - hist_all, _ = np.histogram(cc_all, bins=bins) - - hist_color = np.array([0, 49, 60]) / 256 - line_color = np.array([213, 120, 0]) / 256 - sadd2, _ = self._get_sadd2(flex.double(centers)) - fig, axes_hist = plt.subplots(1, 1, figsize=(5, 3)) - axes_sadd = axes_hist.twinx() - axes_hist.bar(centers, hist_all / 1000, width=dbin, color=hist_color) - axes_sadd.plot(centers, self.sfac**2 * sadd2, color=line_color) - axes_hist.set_xlabel('Correlation Coefficient') - axes_hist.set_ylabel('Lattices (x1,000)') - axes_sadd.set_ylabel('$s_{\mathrm{fac}}^2 \\times s_{\mathrm{add}}^2$') - fig.tight_layout() - fig.savefig(os.path.join( - self.params.output.output_dir, - self.params.output.prefix + '_sadd.png' - )) - plt.close() + if self.cc_key: + cc_all = self.mpi_helper.gather_variable_length_numpy_arrays( + np.unique(self.work_table['correlation_i'].as_numpy_array()), root=0, dtype=float + ) + if self.mpi_helper.rank == 0: + # CC & sadd plots # + bins = np.linspace(cc_all.min(), cc_all.max(), 101) + dbin = bins[1] - bins[0] + centers = (bins[1:] + bins[:-1]) / 2 + hist_all, _ = np.histogram(cc_all, bins=bins) + + hist_color = np.array([0, 49, 60]) / 256 + line_color = np.array([213, 120, 0]) / 256 + sadd2, _ = self._get_sadd2(flex.double(centers)) + fig, axes_hist = plt.subplots(1, 1, figsize=(5, 3)) + axes_sadd = axes_hist.twinx() + axes_hist.bar(centers, hist_all / 1000, width=dbin, color=hist_color) + axes_sadd.plot(centers, self.sfac**2 * sadd2, color=line_color) + axes_hist.set_xlabel('Correlation Coefficient') + axes_hist.set_ylabel('Lattices (x1,000)') + axes_sadd.set_ylabel('$s_{\mathrm{fac}}^2 \\times s_{\mathrm{add}}^2$') + fig.tight_layout() + fig.savefig(os.path.join( + self.params.output.output_dir, + self.params.output.prefix + '_sadd.png' + )) + plt.close() if __name__ == '__main__': from xfel.merging.application.worker import exercise_worker