From c1fcb33096f263ed9521aa5a4504cc8d37c6ca87 Mon Sep 17 00:00:00 2001 From: James Beilsten-Edmands <30625594+jbeilstenedmands@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:06:04 +0100 Subject: [PATCH] Allow scale on batch objects --- src/xia2/Modules/SSX/data_reduction_base.py | 3 +- .../Modules/SSX/data_reduction_programs.py | 194 +++++++++--------- src/xia2/Modules/SSX/data_reduction_simple.py | 13 +- 3 files changed, 108 insertions(+), 102 deletions(-) diff --git a/src/xia2/Modules/SSX/data_reduction_base.py b/src/xia2/Modules/SSX/data_reduction_base.py index ec206114c..5b398d1a0 100644 --- a/src/xia2/Modules/SSX/data_reduction_base.py +++ b/src/xia2/Modules/SSX/data_reduction_base.py @@ -133,6 +133,7 @@ def __init__( self._integrated_data: List[FilePair] = [] self._filtered_batches_to_process: List[Batch] = [] self._files_to_scale: List[FilePair] = [] + self._batches_to_scale: List[Batch] = [] self._files_to_merge: List[FilePair] = [] if not data: @@ -258,7 +259,7 @@ def _reindex(self) -> None: raise NotImplementedError def _prepare_for_scaling(self, good_crystals_data) -> None: - self._files_to_scale = split_integrated_data( + self._batches_to_scale = split_integrated_data( self._filter_wd, good_crystals_data, self._integrated_data, diff --git a/src/xia2/Modules/SSX/data_reduction_programs.py b/src/xia2/Modules/SSX/data_reduction_programs.py index 026608147..02cb3f7eb 100644 --- a/src/xia2/Modules/SSX/data_reduction_programs.py +++ b/src/xia2/Modules/SSX/data_reduction_programs.py @@ -3,7 +3,6 @@ import concurrent.futures import copy -import functools import json import logging import math @@ -101,7 +100,7 @@ def filter_( def split_integrated_data( working_directory, good_crystals_data, integrated_data, reduction_params ) -> List[Batch]: - new_batches_to_process = split_filtered_data_2( + new_batches_to_process = split_filtered_data( working_directory, integrated_data, good_crystals_data, @@ -565,7 +564,68 @@ def scale_against_reference( ) -def scale( +def scale_on_batches( + working_directory: Path, + batches_to_scale: List[Batch], + reduction_params: ReductionParams, + name="", +) -> ProgramResult: + logfile = "dials.scale.log" + if name: + logfile = f"dials.scale.{name}.log" + with run_in_directory(working_directory), log_to_file( + logfile + ) as dials_logger, record_step("dials.scale"): + # Setup scaling + input_ = "" + + all_expts = ExperimentList([]) + tables = [] + for batch in batches_to_scale: + for fp, ids in batch.file_to_identifiers.items(): + table = flex.reflection_table.from_file(fp.refl) + expts = load.experiment_list(fp.expt, check_format=False) + if len(ids) < len(expts): + expts.select_on_experiment_identifiers(list(ids)) + table = table.select_on_experiment_identifiers(list(ids)) + table.reset_ids() + all_expts.extend(expts) + tables.append(table) + expts = all_expts + + params, diff_phil = _extract_scaling_params(reduction_params) + dials_logger.info( + "The following parameters have been modified:\n" + + input_ + + f"{diff_phil.as_str()}" + ) + + # Run the scaling using the algorithm class to give access to scaler + scaler = ScalingAlgorithm(params, expts, tables) + scaler.run() + scaled_expts, scaled_table = scaler.finish() + if name: + out_expt = f"scaled.{name}.expt" + out_refl = f"scaled.{name}.refl" + else: + out_expt = "scaled.expt" + out_refl = "scaled.refl" + + dials_logger.info(f"Saving scaled experiments to {out_expt}") + scaled_expts.as_file(out_expt) + dials_logger.info(f"Saving scaled reflections to {out_refl}") + 
scaled_table.as_file(out_refl) + + return ProgramResult( + working_directory / out_expt, + working_directory / out_refl, + working_directory / logfile, + None, + None, + ) + + +def scale_on_files( working_directory: Path, files_to_scale: List[FilePair], reduction_params: ReductionParams, @@ -719,6 +779,27 @@ def cosym_against_reference( ) +def combined_files_for_batch(batch): + all_expts = ExperimentList([]) + tables = [] + for fp, ids in batch.file_to_identifiers.items(): + table = flex.reflection_table.from_file(fp.refl) + expts = load.experiment_list(fp.expt, check_format=False) + if len(ids) < len(expts): + expts.select_on_experiment_identifiers(list(ids)) + table = table.select_on_experiment_identifiers(list(ids)) + table.reset_ids() + all_expts.extend(expts) + tables.append(table) + if len(tables) > 1: + table = flex.reflection_table.concat(tables) + table.reset_ids() + else: + table = tables[0] + expts = all_expts + return expts, table + + def individual_cosym( working_directory: Path, batch: Batch, @@ -739,23 +820,7 @@ def individual_cosym( ) # cosym_params.cc_star_threshold = 0.1 # cosym_params.angular_separation_threshold = 5 - all_expts = ExperimentList([]) - tables = [] - for fp, ids in batch.file_to_identifiers.items(): - table = flex.reflection_table.from_file(fp.refl) - expts = load.experiment_list(fp.expt, check_format=False) - if len(ids) < len(expts): - expts.select_on_experiment_identifiers(list(ids)) - table = table.select_on_experiment_identifiers(list(ids)) - table.reset_ids() - all_expts.extend(expts) - tables.append(table) - if len(tables) > 1: - table = flex.reflection_table.concat(tables) - table.reset_ids() - else: - table = tables[0] - expts = all_expts + expts, table = combined_files_for_batch(batch) tables = table.split_by_experiment_id() # now run cosym @@ -976,87 +1041,17 @@ def select_crystals_close_to( return good_crystals_data -def split_filtered_data( - working_directory: Path, - new_data: List[FilePair], - good_crystals_data: CrystalsDict, - min_batch_size: int, -) -> List[FilePair]: - if not Path.is_dir(working_directory): - Path.mkdir(working_directory) - with record_step("splitting"): - data_to_reindex: List = [] - n_cryst = sum(len(v.identifiers) for v in good_crystals_data.values()) - n_batches = max(math.floor(n_cryst / min_batch_size), 1) - stride = n_cryst / n_batches - # make sure last batch has at least the batch size - splits = [int(math.floor(i * stride)) for i in range(n_batches)] - splits.append(n_cryst) - template = functools.partial( - "split_{index:0{fmt:d}d}".format, fmt=len(str(n_batches)) - ) - leftover_expts = ExperimentList([]) - leftover_refls: List[flex.reflection_table] = [] - n_batch_output = 0 - n_required = splits[1] - splits[0] - for file_pair in new_data: - expts = load.experiment_list(file_pair.expt, check_format=False) - refls = flex.reflection_table.from_file(file_pair.refl) - good_crystals_this = good_crystals_data[str(file_pair.expt)] - if not good_crystals_this.crystals: - continue - good_identifiers = good_crystals_this.identifiers - if not good_crystals_this.keep_all_original: - expts.select_on_experiment_identifiers(good_identifiers) - refls = refls.select_on_experiment_identifiers(good_identifiers) - refls.reset_ids() - leftover_expts.extend(expts) - leftover_refls.append(refls) - while len(leftover_expts) >= n_required: - sub_expt = leftover_expts[0:n_required] - if len(leftover_refls) > 1: - leftover_refls = [flex.reflection_table.concat(leftover_refls)] - # concat guarantees that ids are ordered 0...n-1 - 
sub_refl = leftover_refls[0].select( - leftover_refls[0]["id"] < n_required - ) - leftover_refls = [ - leftover_refls[0].select(leftover_refls[0]["id"] >= n_required) - ] - leftover_refls[0].reset_ids() - sub_refl.reset_ids() # necessary? - leftover_expts = leftover_expts[n_required:] - out_expt = working_directory / ( - template(index=n_batch_output) + ".expt" - ) - out_refl = working_directory / ( - template(index=n_batch_output) + ".refl" - ) - sub_expt.as_file(out_expt) - sub_refl.as_file(out_refl) - data_to_reindex.append(FilePair(out_expt, out_refl)) - n_batch_output += 1 - if n_batch_output == len(splits) - 1: - break - n_required = splits[n_batch_output + 1] - splits[n_batch_output] - assert n_batch_output == len(splits) - 1 - assert not len(leftover_expts) - for table in leftover_refls: - assert table.size() == 0 - return data_to_reindex - - class Batch(object): def __init__(self): self.file_to_identifiers = {} # FilePair to identifiers -def split_filtered_data_2( +def split_filtered_data( working_directory: Path, new_data: List[FilePair], good_crystals_data: CrystalsDict, min_batch_size: int, -) -> List[FilePair]: +) -> List[Batch]: n_cryst = sum(len(v.identifiers) for v in good_crystals_data.values()) n_batches = max(math.floor(n_cryst / min_batch_size), 1) @@ -1066,7 +1061,8 @@ def split_filtered_data_2( # make sure last batch has at least the batch size splits = [int(math.floor(i * stride)) for i in range(n_batches)] splits.append(n_cryst) - leftover_identifiers = flex.std_string([]) + # leftover_identifiers = flex.std_string([]) + n_leftover = 0 n_batch_output = 0 n_required = splits[1] - splits[0] current_fps = [] @@ -1076,12 +1072,13 @@ def split_filtered_data_2( if not good_crystals_this.crystals: continue good_identifiers = good_crystals_this.identifiers - leftover_identifiers.extend(good_identifiers) + # leftover_identifiers.extend(good_identifiers) + n_leftover += len(good_identifiers) current_fps.append(file_pair) current_identifier_lists.append(good_identifiers) - while len(leftover_identifiers) >= n_required: - n_leftover = len(leftover_identifiers) + while n_leftover >= n_required: + # n_leftover = len(leftover_identifiers) last_fp = current_fps.pop() ids = current_identifier_lists.pop() @@ -1098,12 +1095,13 @@ def split_filtered_data_2( current_fps = [last_fp] current_identifier_lists = [sub_ids_last_leftover] n_batch_output += 1 - leftover_identifiers = leftover_identifiers[n_required:] + # leftover_identifiers = leftover_identifiers[n_required:] + n_leftover -= n_required if n_batch_output == len(splits) - 1: break n_required = splits[n_batch_output + 1] - splits[n_batch_output] assert n_batch_output == len(splits) - 1 - assert not len(leftover_identifiers) + assert not n_leftover # len(leftover_identifiers) return batches diff --git a/src/xia2/Modules/SSX/data_reduction_simple.py b/src/xia2/Modules/SSX/data_reduction_simple.py index ccdbec83f..47c223336 100644 --- a/src/xia2/Modules/SSX/data_reduction_simple.py +++ b/src/xia2/Modules/SSX/data_reduction_simple.py @@ -9,7 +9,8 @@ FilePair, cosym_reindex, parallel_cosym, - scale, + scale_on_batches, + scale_on_files, ) xia2_logger = logging.getLogger(__name__) @@ -42,8 +43,14 @@ def _scale(self) -> None: if not Path.is_dir(self._scale_wd): Path.mkdir(self._scale_wd) - - result = scale(self._scale_wd, self._files_to_scale, self._reduction_params) + if self._batches_to_scale: + result = scale_on_batches( + self._scale_wd, self._batches_to_scale, self._reduction_params + ) + else: + result = scale_on_files( + 
                self._scale_wd, self._files_to_scale, self._reduction_params
            )
        xia2_logger.info("Completed scaling of all data")
        self._files_to_merge = [FilePair(result.exptfile, result.reflfile)]
        FileHandler.record_data_file(result.exptfile)
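
The rewritten split_filtered_data keeps the original batch-boundary arithmetic: each batch receives at least min_batch_size crystals, and any remainder is spread across the batches rather than left as an undersized final batch. A minimal standalone sketch of that calculation, using only the standard library (batch_boundaries is an illustrative name, not a function in this patch):

import math

def batch_boundaries(n_cryst: int, min_batch_size: int) -> list:
    # Number of batches such that each holds at least min_batch_size crystals.
    n_batches = max(math.floor(n_cryst / min_batch_size), 1)
    stride = n_cryst / n_batches
    # Cumulative boundaries; consecutive differences give the batch sizes.
    splits = [int(math.floor(i * stride)) for i in range(n_batches)]
    splits.append(n_cryst)
    return splits

# 25 crystals with min_batch_size=10 gives two batches of 12 and 13 crystals,
# not 10 + 10 + an undersized 5.
assert batch_boundaries(25, 10) == [0, 12, 25]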
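
The key change in split_filtered_data itself is that batches are now planned from identifier counts alone (n_leftover) instead of accumulating a flex.std_string of identifiers, so no experiment or reflection data needs to be loaded at this stage. The diff elides the unchanged middle of the while-loop, so the batch-filling step below is reconstructed from the surrounding context (sub_ids_last_leftover); treat this as a pure-Python model, not the committed code. It reuses batch_boundaries from the sketch above:

def split_into_batches(files_to_ids: dict, splits: list) -> list:
    # files_to_ids maps a file name to its list of good identifiers;
    # splits comes from batch_boundaries above.
    batches = []
    n_leftover = 0
    n_batch_output = 0
    n_required = splits[1] - splits[0]
    current_fps, current_ids = [], []
    for fp, ids in files_to_ids.items():
        n_leftover += len(ids)
        current_fps.append(fp)
        current_ids.append(list(ids))
        while n_leftover >= n_required:
            last_fp = current_fps.pop()
            last_ids = current_ids.pop()
            batch = {}
            n_needed = n_required
            for f, i in zip(current_fps, current_ids):
                batch[f] = i  # whole pending files go into the batch as-is
                n_needed -= len(i)
            batch[last_fp] = last_ids[:n_needed]
            batches.append(batch)
            # Carry the unused suffix of the last file into the next batch.
            current_fps = [last_fp]
            current_ids = [last_ids[n_needed:]]
            n_batch_output += 1
            n_leftover -= n_required
            if n_batch_output == len(splits) - 1:
                break
            n_required = splits[n_batch_output + 1] - splits[n_batch_output]
    return batches

files = {
    "f0": [f"a{i}" for i in range(10)],
    "f1": [f"b{i}" for i in range(10)],
    "f2": [f"c{i}" for i in range(5)],
}
batches = split_into_batches(files, batch_boundaries(25, 10))
# The second batch spans f1 and f2: a batch can draw from several files.
assert [sum(map(len, b.values())) for b in batches] == [12, 13]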
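
Once planned, a Batch maps each FilePair to the subset of experiment identifiers to take from that file, and combined_files_for_batch / scale_on_batches materialise that selection only when the data are actually needed, skipping the selection entirely when a batch owns a whole file. A simplified, runnable model of the selection, with plain dicts and lists standing in for the dxtbx ExperimentList and flex reflection_table types (all names here are illustrative):

from typing import Dict, List, Tuple

# Stand-ins: experiments are {identifier: data}; reflections are rows of
# (identifier, data). The real code uses dials/dxtbx objects loaded from files.
Expts = Dict[str, object]
Refls = List[Tuple[str, object]]

def combine_for_batch(
    file_to_identifiers: Dict[str, List[str]],
    loaded: Dict[str, Tuple[Expts, Refls]],
) -> Tuple[Expts, Refls]:
    all_expts: Expts = {}
    all_refls: Refls = []
    for path, wanted in file_to_identifiers.items():
        expts, refls = loaded[path]
        if len(wanted) < len(expts):
            # Mirrors select_on_experiment_identifiers in the patch:
            # keep only the crystals this batch was assigned.
            keep = set(wanted)
            expts = {k: v for k, v in expts.items() if k in keep}
            refls = [r for r in refls if r[0] in keep]
        all_expts.update(expts)
        all_refls.extend(refls)
    return all_expts, all_refls

loaded = {
    "a": ({"id1": "x1", "id2": "x2"}, [("id1", "r1"), ("id2", "r2")]),
    "b": ({"id3": "x3", "id4": "x4"}, [("id3", "r3"), ("id4", "r4")]),
}
# The batch owns all of file "a" but only one crystal from file "b".
expts, refls = combine_for_batch({"a": ["id1", "id2"], "b": ["id3"]}, loaded)
assert set(expts) == {"id1", "id2", "id3"} and len(refls) == 3

This is why SimpleDataReduction._scale now prefers scale_on_batches whenever _prepare_for_scaling populated _batches_to_scale, and falls back to scale_on_files only for callers that still pass plain FilePair lists.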