Skip to content

Commit

Permalink
Merge pull request #507 from DiamondLightSource/distortion-correction…
Browse files Browse the repository at this point in the history
…-wrapper

Add distortion correction wrapper
  • Loading branch information
dkazanc authored Oct 24, 2024
2 parents 5149ca4 + 743ad8f commit 75acffc
Show file tree
Hide file tree
Showing 18 changed files with 349 additions and 31 deletions.
36 changes: 27 additions & 9 deletions httomo/method_wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from httomo.preview import PreviewConfig
from httomo.runner.method_wrapper import MethodWrapper
from httomo.runner.methods_repository_interface import MethodRepository

Expand All @@ -7,6 +8,7 @@
# (add imports here when createing new wrappers)
import httomo.method_wrappers.datareducer
import httomo.method_wrappers.dezinging
import httomo.method_wrappers.distortion_correction
import httomo.method_wrappers.images
import httomo.method_wrappers.reconstruction
import httomo.method_wrappers.rotation
Expand All @@ -22,6 +24,7 @@ def make_method_wrapper(
module_path: str,
method_name: str,
comm: Comm,
preview_config: PreviewConfig,
save_result: Optional[bool] = None,
output_mapping: Dict[str, str] = {},
**kwargs,
Expand All @@ -39,6 +42,8 @@ def make_method_wrapper(
Path to the module where the method is in python notation, e.g. "httomolibgpu.prep.normalize"
method_name: str
Name of the method (function within the given module)
preview_config : PreviewConfig
Config for preview value from loader
comm: Comm
MPI communicator object
save_result: Optional[bool]
Expand Down Expand Up @@ -67,12 +72,25 @@ def make_method_wrapper(
+ f" are ambigious between {c.__name__} and {cls.__name__}"
)
cls = c
return cls(
method_repository=method_repository,
module_path=module_path,
method_name=method_name,
comm=comm,
save_result=save_result,
output_mapping=output_mapping,
**kwargs,
)

if cls.requires_preview():
return cls(
method_repository=method_repository,
module_path=module_path,
method_name=method_name,
comm=comm,
preview_config=preview_config,
save_result=save_result,
output_mapping=output_mapping,
**kwargs,
)
else:
return cls(
method_repository=method_repository,
module_path=module_path,
method_name=method_name,
comm=comm,
save_result=save_result,
output_mapping=output_mapping,
**kwargs,
)
62 changes: 62 additions & 0 deletions httomo/method_wrappers/distortion_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Dict, Optional

from mpi4py.MPI import Comm

from httomo.method_wrappers.generic import GenericMethodWrapper
from httomo.preview import PreviewConfig
from httomo.runner.methods_repository_interface import MethodRepository


class DistortionCorrectionWrapper(GenericMethodWrapper):
"""
Wrapper for distortion correction methods.
"""

@classmethod
def should_select_this_class(cls, module_path: str, method_name: str) -> bool:
return "distortion_correction" in method_name

@classmethod
def requires_preview(cls) -> bool:
return True

def __init__(
self,
method_repository: MethodRepository,
module_path: str,
method_name: str,
comm: Comm,
preview_config: PreviewConfig,
save_result: Optional[bool] = None,
output_mapping: Dict[str, str] = {},
**kwargs,
):
super().__init__(
method_repository,
module_path,
method_name,
comm,
save_result,
output_mapping,
**kwargs,
)
self._update_params_from_preview(preview_config)

def _update_params_from_preview(self, preview_config: PreviewConfig) -> None:
"""
Extract information from preview config to define the parameter values required for
distortion correction methods, and update `self._config_params`.
"""
SHIFT_PARAM_NAME = "shift_xy"
STEP_PARAM_NAME = "step_xy"
shift_param_value = [
preview_config.detector_x.start,
preview_config.detector_y.start,
]
step_param_value = [1, 1]
self.append_config_params(
{
SHIFT_PARAM_NAME: shift_param_value,
STEP_PARAM_NAME: step_param_value,
}
)
8 changes: 8 additions & 0 deletions httomo/method_wrappers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ def should_select_this_class(cls, module_path: str, method_name: str) -> bool:
"""
return False # pragma: no cover

@classmethod
def requires_preview(cls) -> bool:
"""
Whether the wrapper class needs the preview information from the loader to execute the
methods it wraps or not.
"""
return False

def __init__(
self,
method_repository: MethodRepository,
Expand Down
12 changes: 8 additions & 4 deletions httomo/transform_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from pathlib import Path
from typing import Optional
from httomo.method_wrappers import make_method_wrapper
from httomo.method_wrappers.datareducer import DatareducerWrapper
from httomo.method_wrappers.generic import GenericMethodWrapper
from httomo.method_wrappers.images import ImagesWrapper
from httomo.method_wrappers.save_intermediate import SaveIntermediateFilesWrapper
from httomo.methods_database.query import MethodDatabaseRepository
from httomo.runner.pipeline import Pipeline
from mpi4py import MPI
Expand Down Expand Up @@ -40,7 +44,7 @@ def insert_save_methods(self, pipeline: Pipeline) -> Pipeline:
and "center" not in m.method_name
):
methods.append(
make_method_wrapper(
SaveIntermediateFilesWrapper(
self._repo,
"httomo.methods",
"save_intermediate_data",
Expand All @@ -59,7 +63,7 @@ def insert_data_reducer(self, pipeline: Pipeline) -> Pipeline:
loader = pipeline.loader
methods = []
methods.append(
make_method_wrapper(
DatareducerWrapper(
self._repo,
"httomolib.misc.morph",
"data_reducer",
Expand All @@ -83,7 +87,7 @@ def insert_save_images_after_sweep(self, pipeline: Pipeline) -> Pipeline:
methods.append(m)
if m.sweep or "recon" in m.module_path and sweep_before:
methods.append(
make_method_wrapper(
GenericMethodWrapper(
self._repo,
"httomolibgpu.misc.rescale",
"rescale_to_int",
Expand All @@ -96,7 +100,7 @@ def insert_save_images_after_sweep(self, pipeline: Pipeline) -> Pipeline:
)
)
methods.append(
make_method_wrapper(
ImagesWrapper(
self._repo,
"httomolib.misc.images",
"save_to_images",
Expand Down
7 changes: 7 additions & 0 deletions httomo/ui_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from httomo.darks_flats import DarksFlatsFileConfig

from httomo.methods_database.query import MethodDatabaseRepository
from httomo.preview import PreviewConfig
from httomo.runner.method_wrapper import MethodWrapper
from httomo.runner.pipeline import Pipeline

Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
self.tasks_file_path = tasks_file_path
self.in_data_file = in_data_file_path
self.comm = comm
self._preview_config: PreviewConfig | None = None

root, ext = os.path.splitext(self.tasks_file_path)
if ext.upper() in [".YAML", ".YML"]:
Expand Down Expand Up @@ -94,11 +96,15 @@ def _append_methods_list(
method_id_map
map of methods and ids
"""
assert (
self._preview_config is not None
), "Preview config should have been stored prior to method wrapper creation"
method = make_method_wrapper(
method_repository=self.repo,
module_path=task_conf["module_path"],
method_name=task_conf["method"],
comm=self.comm,
preview_config=self._preview_config,
save_result=task_conf.get("save_result", None),
output_mapping=task_conf.get("side_outputs", dict()),
task_id=task_conf.get("id", f"task_{i + 1}"),
Expand Down Expand Up @@ -139,6 +145,7 @@ def _setup_loader(self) -> LoaderInterface:
with h5py.File(in_file, "r") as f:
data_shape = f[data_path].shape
preview = parse_preview(parameters.get("preview", None), data_shape)
self._preview_config = preview

loader = make_loader(
repo=self.repo,
Expand Down
3 changes: 2 additions & 1 deletion tests/method_wrappers/test_data_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from httomo.method_wrappers.datareducer import DatareducerWrapper
from httomo.runner.auxiliary_data import AuxiliaryData
from httomo.runner.dataset import DataSetBlock
from ..testing_utils import make_mock_repo
from ..testing_utils import make_mock_preview_config, make_mock_repo


import numpy as np
Expand All @@ -24,6 +24,7 @@ def data_reducer(x):
"mocked_module_path.morph",
"data_reducer",
MPI.COMM_WORLD,
make_mock_preview_config(mocker),
)
assert isinstance(wrp, DatareducerWrapper)

Expand Down
3 changes: 2 additions & 1 deletion tests/method_wrappers/test_dezinging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from httomo.method_wrappers.dezinging import DezingingWrapper
from httomo.runner.auxiliary_data import AuxiliaryData
from httomo.runner.dataset import DataSetBlock
from ..testing_utils import make_mock_repo
from ..testing_utils import make_mock_preview_config, make_mock_repo


import numpy as np
Expand All @@ -24,6 +24,7 @@ def remove_outlier(x, axis="auto"):
"mocked_module_path.prep",
"remove_outlier",
MPI.COMM_WORLD,
make_mock_preview_config(mocker),
)
assert isinstance(wrp, DezingingWrapper)

Expand Down
89 changes: 89 additions & 0 deletions tests/method_wrappers/test_distortion_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from mpi4py import MPI
from pytest_mock import MockerFixture

from httomo.method_wrappers.distortion_correction import DistortionCorrectionWrapper
from httomo.preview import PreviewConfig, PreviewDimConfig
from tests.testing_utils import make_mock_repo


@pytest.mark.parametrize(
"method_name, expected_result",
[("distortion_correction", True), ("other_method", False)],
ids=["should-select", "shouldn't-select"],
)
def test_class_only_selected_for_methods_with_distortion_correction_in_name(
method_name: str, expected_result: bool
):
assert (
DistortionCorrectionWrapper.should_select_this_class(
"dummy.module.path", method_name
)
is expected_result
)


def test_requires_preview_is_true():
assert DistortionCorrectionWrapper.requires_preview() is True


@pytest.mark.parametrize(
"preview_config",
[
PreviewConfig(
angles=PreviewDimConfig(start=0, stop=180),
detector_y=PreviewDimConfig(start=0, stop=128),
detector_x=PreviewDimConfig(start=0, stop=160),
),
PreviewConfig(
angles=PreviewDimConfig(start=0, stop=180),
detector_y=PreviewDimConfig(start=5, stop=123),
detector_x=PreviewDimConfig(start=0, stop=160),
),
PreviewConfig(
angles=PreviewDimConfig(start=0, stop=180),
detector_y=PreviewDimConfig(start=0, stop=128),
detector_x=PreviewDimConfig(start=5, stop=155),
),
],
ids=["no_cropping", "crop_det_y_both_ends", "crop_det_x_both_ends"],
)
def test_sets_shiftxy_and_stepxy_params_correctly(
preview_config: PreviewConfig, mocker: MockerFixture
):
MODULE_PATH = "dummy.module.path"
METHOD_NAME = "distortion_correction_dummy"
COMM = MPI.COMM_WORLD

# Patch method function import that occurs when the wrapper object is created, to instead
# import the below dummy method function
class FakeModule:
def distortion_correction_dummy(shift_xy, step_xy): # type: ignore
return shift_xy + step_xy

mocker.patch(
"httomo.method_wrappers.generic.import_module", return_value=FakeModule
)

wrapper = DistortionCorrectionWrapper(
method_repository=make_mock_repo(mocker),
module_path=MODULE_PATH,
method_name=METHOD_NAME,
comm=COMM,
preview_config=preview_config,
)

# Check that the shift and step parameter values are present in the wrapper's parameter
# config, and have the expected values based on the preview config
expected_shift_values = [
preview_config.detector_x.start,
preview_config.detector_y.start,
]
expected_step_values = [1, 1]

SHIFT_PARAM_NAME = "shift_xy"
STEP_PARAM_NAME = "step_xy"
assert SHIFT_PARAM_NAME in wrapper.config_params
assert wrapper.config_params[SHIFT_PARAM_NAME] == expected_shift_values
assert STEP_PARAM_NAME in wrapper.config_params
assert wrapper.config_params[STEP_PARAM_NAME] == expected_step_values
Loading

0 comments on commit 75acffc

Please sign in to comment.