Skip to content

Commit

Permalink
Added cuda accelerated PC registration
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Mar 22, 2024
1 parent 9b6889e commit 2a979db
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 18 deletions.
105 changes: 91 additions & 14 deletions tomotools/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from skimage.feature import canny
from skimage.filters import sobel
import matplotlib.pylab as plt
import astra

has_cupy = True
try:
import cupy as cp
except ImportError:
has_cupy = False

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -314,7 +321,7 @@ def calculate_shifts_com(stack, nslices):
return yshifts


def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor):
def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor, cuda):
"""
Calculate shifts using the phase correlation algorithm.
Expand All @@ -330,18 +337,84 @@ def calculate_shifts_pc(stack, start, show_progressbar, upsample_factor):
The X- and Y-shifts to be applied to each image
"""
shifts = np.zeros((stack.data.shape[0], 2))

with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar:
for i in range(start, 0, -1):
shift = pcc(stack.data[i], stack.data[i - 1], upsample_factor=upsample_factor)[0]
shifts[i - 1] = shifts[i] + shift
pbar.update(1)
def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None):
upsampled_region_size = [upsampled_region_size,] * data.ndim

im2pi = 1j * 2 * cp.pi

dim_properties = list(zip(data.shape, upsampled_region_size, axis_offsets))

for n_items, ups_size, ax_offset in dim_properties[::-1]:
kernel = (cp.arange(ups_size) - ax_offset)[:, None] * cp.fft.fftfreq(n_items, upsample_factor)
kernel = cp.exp(-im2pi * kernel)
# use kernel with same precision as the data
kernel = kernel.astype(data.dtype, copy=False)
data = cp.tensordot(kernel, data, axes=(1, -1))
return data

def _cupy_phase_correlate(ref_cp, mov_cp, upsample_factor):
ref_fft = cp.fft.fftn(ref_cp)
mov_fft = cp.fft.fftn(mov_cp)

cross_power_spectrum = ref_fft * mov_fft.conj()
eps = cp.finfo(cross_power_spectrum.real.dtype).eps
cross_power_spectrum /= cp.maximum(cp.abs(cross_power_spectrum), 100 * eps)
phase_correlation = cp.fft.ifft2(cross_power_spectrum)

maxima = cp.unravel_index(cp.argmax(cp.abs(phase_correlation)), phase_correlation.shape)
midpoint = cp.array([cp.fix(axis_size / 2) for axis_size in shape])

float_dtype = cross_power_spectrum.real.dtype

shift = cp.stack(maxima).astype(float_dtype, copy=False)
shift[shift > midpoint] -= cp.array(shape)[shift > midpoint]

if upsample_factor > 1:
upsample_factor = cp.array(upsample_factor, dtype=float_dtype)
upsampled_region_size = cp.ceil(upsample_factor * 1.5)
dftshift = cp.fix(upsampled_region_size / 2.0)

shift = cp.round(shift * upsample_factor) / upsample_factor

sample_region_offset = dftshift - shift * upsample_factor
phase_correlation = _upsampled_dft(cross_power_spectrum.conj(), upsampled_region_size, upsample_factor, sample_region_offset).conj()
maxima = np.unravel_index(cp.argmax(np.abs(phase_correlation)), phase_correlation.shape)

maxima = cp.stack(maxima).astype(float_dtype, copy=False)
maxima -= dftshift

shift += maxima / upsample_factor
return shift

if has_cupy and astra.use_cuda() and cuda:
stack_cp = cp.array(stack.data)
shifts = cp.zeros([stack_cp.shape[0], 2])
ref_cp = stack_cp[0]
ref_fft = cp.fft.fftn(ref_cp)
shape = ref_fft.shape
with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar:
for i in range(start, 0, -1):
shift = _cupy_phase_correlate(stack_cp[i], stack_cp[i - 1], upsample_factor=upsample_factor)
shifts[i - 1] = shifts[i] + shift
pbar.update(1)
for i in range(start, stack.data.shape[0] - 1):
shift = _cupy_phase_correlate(stack_cp[i], stack_cp[i + 1], upsample_factor=upsample_factor)
shifts[i + 1] = shifts[i] + shift
pbar.update(1)
shifts = shifts.get()

for i in range(start, stack.data.shape[0] - 1):
shift = pcc(stack.data[i], stack.data[i + 1], upsample_factor=upsample_factor)[0]
shifts[i + 1] = shifts[i] + shift
pbar.update(1)
else:
shifts = np.zeros((stack.data.shape[0], 2))
with tqdm.tqdm(total=stack.data.shape[0] - 1, desc="Calculating shifts", disable=not show_progressbar) as pbar:
for i in range(start, 0, -1):
shift = pcc(stack.data[i], stack.data[i - 1], upsample_factor=upsample_factor)[0]
shifts[i - 1] = shifts[i] + shift
pbar.update(1)

for i in range(start, stack.data.shape[0] - 1):
shift = pcc(stack.data[i], stack.data[i + 1], upsample_factor=upsample_factor)[0]
shifts[i + 1] = shifts[i] + shift
pbar.update(1)

return shifts

Expand Down Expand Up @@ -510,9 +583,13 @@ def align_stack(stack, method, start, show_progressbar, **kwargs):
shifts[:, 1] = calculate_shifts_conservation_of_mass(stack, xrange, p)
shifts[:, 0] = calculate_shifts_com(stack, nslices)
elif method.lower() == 'pc':
cuda = kwargs.get('cuda', False)
upsample_factor = kwargs.get('upsample_factor', 3)
logger.info("Performing stack registration using phase correlation")
shifts = calculate_shifts_pc(stack, start, show_progressbar, upsample_factor)
if cuda:
logger.info("Performing stack registration using CUDA-accelerated phase correlation")
else:
logger.info("Performing stack registration using phase correlation")
shifts = calculate_shifts_pc(stack, start, show_progressbar, upsample_factor, cuda)
elif method.lower() in ["stackreg", 'sr']:
logger.info("Performing stack registration using PyStackReg")
shifts = calculate_shifts_stackreg(stack, start, show_progressbar)
Expand Down
15 changes: 15 additions & 0 deletions tomotools/tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tomotools
import numpy
from tomotools import recon
from tomotools.base import TomoStack
import astra
import pytest

Expand Down Expand Up @@ -69,3 +70,17 @@ def test_run_sirt_cuda(self):
assert rec.data.shape == (1, slices.data.shape[1], slices.data.shape[1])
assert rec.data.shape[0] == slices.data.shape[2]
assert type(rec) is numpy.ndarray


@pytest.mark.skipif(not astra.use_cuda(), reason="CUDA not detected")
class TestStackRegisterCUDA:
def test_register_pc_cuda(self):
stack = ds.get_needle_data()
stack.metadata.Tomography.shifts = \
stack.metadata.Tomography.shifts[0:20]
reg = stack.inav[0:20].stack_register('PC', cuda=True)
assert type(reg) is TomoStack
assert reg.axes_manager.signal_shape == \
stack.inav[0:20].axes_manager.signal_shape
assert reg.axes_manager.navigation_shape == \
stack.inav[0:20].axes_manager.navigation_shape
5 changes: 1 addition & 4 deletions tomotools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,7 @@ def filter_stack(stack, filter_name="shepp-logan", cutoff=0.5):
pass
elif filter_name == "shepp-logan":
filter[1:] = filter[1:] * np.sinc(omega[1:] / (2 * np.pi))
elif filter_name in [
"hanning",
"hann",
]:
elif filter_name in ["hanning", "hann",]:
filter[1:] = filter[1:] * (1 + np.cos(omega[1:])) / 2
elif filter_name in [
"cosine",
Expand Down

0 comments on commit 2a979db

Please sign in to comment.