Allow scale on batch objects
jbeilstenedmands committed Aug 17, 2023
1 parent d711133 commit c1fcb33
Showing 3 changed files with 108 additions and 102 deletions.
3 changes: 2 additions & 1 deletion src/xia2/Modules/SSX/data_reduction_base.py
@@ -133,6 +133,7 @@ def __init__(
self._integrated_data: List[FilePair] = []
self._filtered_batches_to_process: List[Batch] = []
self._files_to_scale: List[FilePair] = []
self._batches_to_scale: List[Batch] = []
self._files_to_merge: List[FilePair] = []

if not data:
@@ -258,7 +259,7 @@ def _reindex(self) -> None:
raise NotImplementedError

def _prepare_for_scaling(self, good_crystals_data) -> None:
self._files_to_scale = split_integrated_data(
self._batches_to_scale = split_integrated_data(
self._filter_wd,
good_crystals_data,
self._integrated_data,
194 changes: 96 additions & 98 deletions src/xia2/Modules/SSX/data_reduction_programs.py
@@ -3,7 +3,6 @@

import concurrent.futures
import copy
import functools
import json
import logging
import math
@@ -101,7 +100,7 @@ def filter_(
def split_integrated_data(
working_directory, good_crystals_data, integrated_data, reduction_params
) -> List[Batch]:
new_batches_to_process = split_filtered_data_2(
new_batches_to_process = split_filtered_data(
working_directory,
integrated_data,
good_crystals_data,
@@ -565,7 +564,68 @@ def scale_against_reference(
)


def scale(
def scale_on_batches(
working_directory: Path,
batches_to_scale: List[Batch],
reduction_params: ReductionParams,
name="",
) -> ProgramResult:
logfile = "dials.scale.log"
if name:
logfile = f"dials.scale.{name}.log"
with run_in_directory(working_directory), log_to_file(
logfile
) as dials_logger, record_step("dials.scale"):
# Set up scaling
input_ = ""

all_expts = ExperimentList([])
tables = []
for batch in batches_to_scale:
for fp, ids in batch.file_to_identifiers.items():
table = flex.reflection_table.from_file(fp.refl)
expts = load.experiment_list(fp.expt, check_format=False)
if len(ids) < len(expts):
expts.select_on_experiment_identifiers(list(ids))
table = table.select_on_experiment_identifiers(list(ids))
table.reset_ids()
all_expts.extend(expts)
tables.append(table)
expts = all_expts

params, diff_phil = _extract_scaling_params(reduction_params)
dials_logger.info(
"The following parameters have been modified:\n"
+ input_
+ f"{diff_phil.as_str()}"
)

# Run the scaling using the algorithm class to give access to scaler
scaler = ScalingAlgorithm(params, expts, tables)
scaler.run()
scaled_expts, scaled_table = scaler.finish()
if name:
out_expt = f"scaled.{name}.expt"
out_refl = f"scaled.{name}.refl"
else:
out_expt = "scaled.expt"
out_refl = "scaled.refl"

dials_logger.info(f"Saving scaled experiments to {out_expt}")
scaled_expts.as_file(out_expt)
dials_logger.info(f"Saving scaled reflections to {out_refl}")
scaled_table.as_file(out_refl)

return ProgramResult(
working_directory / out_expt,
working_directory / out_refl,
working_directory / logfile,
None,
None,
)
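
The loop above keeps only the experiment identifiers assigned to each batch before pooling everything into a single ExperimentList and a list of reflection tables for one scaling run. A minimal sketch of that selection pattern, assuming hypothetical file names and that the batch owns only a subset of the crystals in the file:

from dials.array_family import flex
from dxtbx.serialize import load

# Hypothetical inputs: one integrated file pair and the identifier
# subset owned by the current batch.
expts = load.experiment_list("integrated_0.expt", check_format=False)
table = flex.reflection_table.from_file("integrated_0.refl")
ids = list(expts.identifiers())[:2]  # assume the batch owns the first two crystals

if len(ids) < len(expts):
    expts.select_on_experiment_identifiers(ids)  # trims the ExperimentList in place
    table = table.select_on_experiment_identifiers(ids)  # returns the trimmed table
    table.reset_ids()  # renumber "id" 0...n-1 to match the trimmed experiments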


def scale_on_files(
working_directory: Path,
files_to_scale: List[FilePair],
reduction_params: ReductionParams,
@@ -719,6 +779,27 @@ def cosym_against_reference(
)


def combined_files_for_batch(batch):
all_expts = ExperimentList([])
tables = []
for fp, ids in batch.file_to_identifiers.items():
table = flex.reflection_table.from_file(fp.refl)
expts = load.experiment_list(fp.expt, check_format=False)
if len(ids) < len(expts):
expts.select_on_experiment_identifiers(list(ids))
table = table.select_on_experiment_identifiers(list(ids))
table.reset_ids()
all_expts.extend(expts)
tables.append(table)
if len(tables) > 1:
table = flex.reflection_table.concat(tables)
table.reset_ids()
else:
table = tables[0]
expts = all_expts
return expts, table
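
combined_files_for_batch is the same pooling logic factored out of individual_cosym below; flex.reflection_table.concat renumbers ids 0...n-1, so a single reset_ids pass suffices in the multi-file case. A hedged usage sketch, assuming batch is a populated Batch (the class is defined further down this file):

# Assumes `batch` maps existing FilePairs to the identifiers it owns.
expts, table = combined_files_for_batch(batch)
assert len(expts) == sum(len(ids) for ids in batch.file_to_identifiers.values())
tables = table.split_by_experiment_id()  # per-crystal tables, as individual_cosym uses next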


def individual_cosym(
working_directory: Path,
batch: Batch,
@@ -739,23 +820,7 @@ def individual_cosym(
)
# cosym_params.cc_star_threshold = 0.1
# cosym_params.angular_separation_threshold = 5
all_expts = ExperimentList([])
tables = []
for fp, ids in batch.file_to_identifiers.items():
table = flex.reflection_table.from_file(fp.refl)
expts = load.experiment_list(fp.expt, check_format=False)
if len(ids) < len(expts):
expts.select_on_experiment_identifiers(list(ids))
table = table.select_on_experiment_identifiers(list(ids))
table.reset_ids()
all_expts.extend(expts)
tables.append(table)
if len(tables) > 1:
table = flex.reflection_table.concat(tables)
table.reset_ids()
else:
table = tables[0]
expts = all_expts
expts, table = combined_files_for_batch(batch)

tables = table.split_by_experiment_id()
# now run cosym
@@ -976,87 +1041,17 @@ def select_crystals_close_to(
return good_crystals_data


def split_filtered_data(
working_directory: Path,
new_data: List[FilePair],
good_crystals_data: CrystalsDict,
min_batch_size: int,
) -> List[FilePair]:
if not Path.is_dir(working_directory):
Path.mkdir(working_directory)
with record_step("splitting"):
data_to_reindex: List = []
n_cryst = sum(len(v.identifiers) for v in good_crystals_data.values())
n_batches = max(math.floor(n_cryst / min_batch_size), 1)
stride = n_cryst / n_batches
# make sure last batch has at least the batch size
splits = [int(math.floor(i * stride)) for i in range(n_batches)]
splits.append(n_cryst)
template = functools.partial(
"split_{index:0{fmt:d}d}".format, fmt=len(str(n_batches))
)
leftover_expts = ExperimentList([])
leftover_refls: List[flex.reflection_table] = []
n_batch_output = 0
n_required = splits[1] - splits[0]
for file_pair in new_data:
expts = load.experiment_list(file_pair.expt, check_format=False)
refls = flex.reflection_table.from_file(file_pair.refl)
good_crystals_this = good_crystals_data[str(file_pair.expt)]
if not good_crystals_this.crystals:
continue
good_identifiers = good_crystals_this.identifiers
if not good_crystals_this.keep_all_original:
expts.select_on_experiment_identifiers(good_identifiers)
refls = refls.select_on_experiment_identifiers(good_identifiers)
refls.reset_ids()
leftover_expts.extend(expts)
leftover_refls.append(refls)
while len(leftover_expts) >= n_required:
sub_expt = leftover_expts[0:n_required]
if len(leftover_refls) > 1:
leftover_refls = [flex.reflection_table.concat(leftover_refls)]
# concat guarantees that ids are ordered 0...n-1
sub_refl = leftover_refls[0].select(
leftover_refls[0]["id"] < n_required
)
leftover_refls = [
leftover_refls[0].select(leftover_refls[0]["id"] >= n_required)
]
leftover_refls[0].reset_ids()
sub_refl.reset_ids() # necessary?
leftover_expts = leftover_expts[n_required:]
out_expt = working_directory / (
template(index=n_batch_output) + ".expt"
)
out_refl = working_directory / (
template(index=n_batch_output) + ".refl"
)
sub_expt.as_file(out_expt)
sub_refl.as_file(out_refl)
data_to_reindex.append(FilePair(out_expt, out_refl))
n_batch_output += 1
if n_batch_output == len(splits) - 1:
break
n_required = splits[n_batch_output + 1] - splits[n_batch_output]
assert n_batch_output == len(splits) - 1
assert not len(leftover_expts)
for table in leftover_refls:
assert table.size() == 0
return data_to_reindex


class Batch(object):
def __init__(self):
self.file_to_identifiers = {} # FilePair to identifiers
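
This small container is the pivot of the commit: rather than writing split_N.expt/.refl files to disk as the deleted function above did, the renamed split_filtered_data below only records which identifiers from which original files belong to each batch. A sketch of the resulting structure, with hypothetical paths and identifiers:

from pathlib import Path

fp0 = FilePair(Path("integrated_0.expt"), Path("integrated_0.refl"))
fp1 = FilePair(Path("integrated_1.expt"), Path("integrated_1.refl"))

# Batch 0 takes all of file 0 plus the start of file 1; batch 1 takes the rest.
batch_0, batch_1 = Batch(), Batch()
batch_0.file_to_identifiers[fp0] = ["a", "b", "c"]
batch_0.file_to_identifiers[fp1] = ["d"]
batch_1.file_to_identifiers[fp1] = ["e", "f", "g"]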


def split_filtered_data_2(
def split_filtered_data(
working_directory: Path,
new_data: List[FilePair],
good_crystals_data: CrystalsDict,
min_batch_size: int,
) -> List[FilePair]:
) -> List[Batch]:

n_cryst = sum(len(v.identifiers) for v in good_crystals_data.values())
n_batches = max(math.floor(n_cryst / min_batch_size), 1)
@@ -1066,7 +1061,8 @@ def split_filtered_data_2(
# make sure last batch has at least the batch size
splits = [int(math.floor(i * stride)) for i in range(n_batches)]
splits.append(n_cryst)
leftover_identifiers = flex.std_string([])
n_leftover = 0
n_batch_output = 0
n_required = splits[1] - splits[0]
current_fps = []
@@ -1076,12 +1072,13 @@ def split_filtered_data_2(
if not good_crystals_this.crystals:
continue
good_identifiers = good_crystals_this.identifiers
leftover_identifiers.extend(good_identifiers)
n_leftover += len(good_identifiers)
current_fps.append(file_pair)
current_identifier_lists.append(good_identifiers)

while len(leftover_identifiers) >= n_required:
n_leftover = len(leftover_identifiers)
while n_leftover >= n_required:

last_fp = current_fps.pop()
ids = current_identifier_lists.pop()
@@ -1098,12 +1095,13 @@ def split_filtered_data_2(
current_fps = [last_fp]
current_identifier_lists = [sub_ids_last_leftover]
n_batch_output += 1
leftover_identifiers = leftover_identifiers[n_required:]
n_leftover -= n_required
if n_batch_output == len(splits) - 1:
break
n_required = splits[n_batch_output + 1] - splits[n_batch_output]
assert n_batch_output == len(splits) - 1
assert not len(leftover_identifiers)
assert not n_leftover

return batches
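
Both the deleted and the renamed splitting functions share the same batch-boundary arithmetic: n_batches = max(floor(n_cryst / min_batch_size), 1), with split points at floor(i * stride) and the final batch absorbing the remainder. A worked example of just that arithmetic:

import math

n_cryst, min_batch_size = 10, 3
n_batches = max(math.floor(n_cryst / min_batch_size), 1)  # 3
stride = n_cryst / n_batches  # 3.33...
splits = [int(math.floor(i * stride)) for i in range(n_batches)]
splits.append(n_cryst)
print(splits)  # [0, 3, 6, 10] -> batch sizes 3, 3 and 4, so the last batch is never undersized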

13 changes: 10 additions & 3 deletions src/xia2/Modules/SSX/data_reduction_simple.py
@@ -9,7 +9,8 @@
FilePair,
cosym_reindex,
parallel_cosym,
scale,
scale_on_batches,
scale_on_files,
)

xia2_logger = logging.getLogger(__name__)
@@ -42,8 +43,14 @@ def _scale(self) -> None:

if not Path.is_dir(self._scale_wd):
Path.mkdir(self._scale_wd)

result = scale(self._scale_wd, self._files_to_scale, self._reduction_params)
if self._batches_to_scale:
result = scale_on_batches(
self._scale_wd, self._batches_to_scale, self._reduction_params
)
else:
result = scale_on_files(
self._scale_wd, self._files_to_scale, self._reduction_params
)
xia2_logger.info("Completed scaling of all data")
self._files_to_merge = [FilePair(result.exptfile, result.reflfile)]
FileHandler.record_data_file(result.exptfile)
