diff --git a/tomotools/align.py b/tomotools/align.py index 190d0029..ea730f46 100644 --- a/tomotools/align.py +++ b/tomotools/align.py @@ -19,6 +19,13 @@ from skimage.feature import canny from skimage.filters import sobel import matplotlib.pylab as plt +import astra + +has_cupy = True +try: + import cupy as cp +except ImportError: + has_cupy = False logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -314,7 +321,7 @@ def calculate_shifts_com(stack, nslices): return yshifts -def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor): +def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor, cuda): """ Calculate shifts using the phase correlation algorithm. @@ -330,18 +337,84 @@ def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor): The X- and Y-shifts to be applied to each image """ - shifts = np.zeros((stack.data.shape[0], 2)) - - with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar: - for i in range(start, 0, -1): - shift = pcc(stack.data[i], stack.data[i - 1], upsample_factor=upsample_factor)[0] - shifts[i - 1] = shifts[i] + shift - pbar.update(1) + def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None): + upsampled_region_size = [upsampled_region_size,] * data.ndim + + im2pi = 1j * 2 * cp.pi + + dim_properties = list(zip(data.shape, upsampled_region_size, axis_offsets)) + + for n_items, ups_size, ax_offset in dim_properties[::-1]: + kernel = (cp.arange(ups_size) - ax_offset)[:, None] * cp.fft.fftfreq(n_items, upsample_factor) + kernel = cp.exp(-im2pi * kernel) + # use kernel with same precision as the data + kernel = kernel.astype(data.dtype, copy=False) + data = cp.tensordot(kernel, data, axes=(1, -1)) + return data + + def _cupy_phase_correlate(ref_cp, mov_cp, upsample_factor): + ref_fft = cp.fft.fftn(ref_cp) + mov_fft = cp.fft.fftn(mov_cp) + + cross_power_spectrum = ref_fft * mov_fft.conj() + eps = cp.finfo(cross_power_spectrum.real.dtype).eps + cross_power_spectrum /= cp.maximum(cp.abs(cross_power_spectrum), 100 * eps) + phase_correlation = cp.fft.ifft2(cross_power_spectrum) + + maxima = cp.unravel_index(cp.argmax(cp.abs(phase_correlation)), phase_correlation.shape) + midpoint = cp.array([cp.fix(axis_size / 2) for axis_size in shape]) + + float_dtype = cross_power_spectrum.real.dtype + + shift = cp.stack(maxima).astype(float_dtype, copy=False) + shift[shift > midpoint] -= cp.array(shape)[shift > midpoint] + + if upsample_factor > 1: + upsample_factor = cp.array(upsample_factor, dtype=float_dtype) + upsampled_region_size = cp.ceil(upsample_factor * 1.5) + dftshift = cp.fix(upsampled_region_size / 2.0) + + shift = cp.round(shift * upsample_factor) / upsample_factor + + sample_region_offset = dftshift - shift * upsample_factor + phase_correlation = _upsampled_dft(cross_power_spectrum.conj(), upsampled_region_size, upsample_factor, sample_region_offset).conj() + maxima = np.unravel_index(cp.argmax(np.abs(phase_correlation)), phase_correlation.shape) + + maxima = cp.stack(maxima).astype(float_dtype, copy=False) + maxima -= dftshift + + shift += maxima / upsample_factor + return shift + + if has_cupy and astra.use_cuda() and cuda: + stack_cp = cp.array(stack.data) + shifts = cp.zeros([stack_cp.shape[0], 2]) + ref_cp = stack_cp[0] + ref_fft = cp.fft.fftn(ref_cp) + shape = ref_fft.shape + with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar: + for i in range(start, 0, -1): + shift = _cupy_phase_correlate(stack_cp[i], stack_cp[i - 1], upsample_factor=upsample_factor) + shifts[i - 1] = shifts[i] + shift + pbar.update(1) + for i in range(start, stack.data.shape[0] - 1): + shift = _cupy_phase_correlate(stack_cp[i], stack_cp[i + 1], upsample_factor=upsample_factor) + shifts[i + 1] = shifts[i] + shift + pbar.update(1) + shifts = shifts.get() - for i in range(start, stack.data.shape[0] - 1): - shift = pcc(stack.data[i], stack.data[i + 1], upsample_factor=upsample_factor)[0] - shifts[i + 1] = shifts[i] + shift - pbar.update(1) + else: + shifts = np.zeros((stack.data.shape[0], 2)) + with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar: + for i in range(start, 0, -1): + shift = pcc(stack.data[i], stack.data[i - 1], upsample_factor=upsample_factor)[0] + shifts[i - 1] = shifts[i] + shift + pbar.update(1) + + for i in range(start, stack.data.shape[0] - 1): + shift = pcc(stack.data[i], stack.data[i + 1], upsample_factor=upsample_factor)[0] + shifts[i + 1] = shifts[i] + shift + pbar.update(1) return shifts @@ -510,9 +583,13 @@ def align_stack(stack, method, start, show_progressbar, **kwargs): shifts[:, 1] = calculate_shifts_conservation_of_mass(stack, xrange, p) shifts[:, 0] = calculate_shifts_com(stack, nslices) elif method.lower() == 'pc': + cuda = kwargs.get('cuda', False) upsample_factor = kwargs.get('upsample_factor', 3) - logger.info("Performing stack registration using phase correlation") - shifts = calculate_shifts_pc(stack, start, show_progressbar, upsample_factor) + if cuda: + logger.info("Performing stack registration using CUDA-accelerated phase correlation") + else: + logger.info("Performing stack registration using phase correlation") + shifts = calculate_shifts_pc(stack, start, show_progressbar, upsample_factor, cuda) elif method.lower() in ["stackreg", 'sr']: logger.info("Performing stack registration using PyStackReg") shifts = calculate_shifts_stackreg(stack, start, show_progressbar) diff --git a/tomotools/tests/test_cuda.py b/tomotools/tests/test_cuda.py index cf0338fd..c16fdfc6 100644 --- a/tomotools/tests/test_cuda.py +++ b/tomotools/tests/test_cuda.py @@ -3,6 +3,7 @@ import tomotools import numpy from tomotools import recon +from tomotools.base import TomoStack import astra import pytest @@ -69,3 +70,17 @@ def test_run_sirt_cuda(self): assert rec.data.shape == (1, slices.data.shape[1], slices.data.shape[1]) assert rec.data.shape[0] == slices.data.shape[2] assert type(rec) is numpy.ndarray + + +@pytest.mark.skipif(not astra.use_cuda(), reason="CUDA not detected") +class TestStackRegisterCUDA: + def test_register_pc_cuda(self): + stack = ds.get_needle_data() + stack.metadata.Tomography.shifts = \ + stack.metadata.Tomography.shifts[0:20] + reg = stack.inav[0:20].stack_register('PC', cuda=True) + assert type(reg) is TomoStack + assert reg.axes_manager.signal_shape == \ + stack.inav[0:20].axes_manager.signal_shape + assert reg.axes_manager.navigation_shape == \ + stack.inav[0:20].axes_manager.navigation_shape diff --git a/tomotools/utils.py b/tomotools/utils.py index 3f6e3881..bbd699a5 100644 --- a/tomotools/utils.py +++ b/tomotools/utils.py @@ -337,10 +337,7 @@ def filter_stack(stack, filter_name="shepp-logan", cutoff=0.5): pass elif filter_name == "shepp-logan": filter[1:] = filter[1:] * np.sinc(omega[1:] / (2 * np.pi)) - elif filter_name in [ - "hanning", - "hann", - ]: + elif filter_name in ["hanning", "hann",]: filter[1:] = filter[1:] * (1 + np.cos(omega[1:])) / 2 elif filter_name in [ "cosine",