Skip to content

Commit

Permalink
Merge pull request #315 from carterbox/fresnel-propagator
Browse files Browse the repository at this point in the history
NEW: Add Fresnel propagator for multi-slice
  • Loading branch information
a4894z authored Jun 5, 2024
2 parents b72a20b + 26b4c9e commit 30e2a80
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/tike/operators/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .pad import *
from .patch import *
from .propagation import *
from .fresnelspectprop import *
from .multislice import *
from .ptycho import *
from .rotate import *
Expand Down
131 changes: 131 additions & 0 deletions src/tike/operators/cupy/fresnelspectprop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Defines a free-space propagation operator based on the CuPy FFT module."""

__author__ = "Ashish Tripathi"
__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC."

import typing

import numpy.typing as npt
import numpy as np

from .cache import CachedFFT
from .operator import Operator


class FresnelSpectProp(CachedFFT, Operator):
"""Fresnel spectrum propagation (short range) using CuPy.
Take an (..., W, H) compelx array representing a wavefront and propagate.
Attributes
----------
pixel_size : float
The realspace size of a pixel in meters
distance : float
The realspace propagation distance in meters
wavelength : float
The wavelength of the light in meters
Parameters
----------
nearplane: (..., W, H) complex64
The wavefronts before propagation.
farplane: (..., W, H) complex64
The wavefronts after propagation.
"""

def __init__(
self,
norm: str = "ortho",
pixel_size: float = 1e-7,
distance: float = 1e-6,
wavelength: float = 1e-9,
**kwargs,
):
self.norm = norm
self.pixel_size = pixel_size
self.distance = distance
self.wavelength = wavelength

def fwd(
self,
nearplane: npt.NDArray[np.csingle],
overwrite: bool = False,
**kwargs,
) -> npt.NDArray[np.csingle]:
"""forward (parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(nearplane.shape[-2], nearplane.shape[-1]),
self.pixel_size,
self.distance,
self.wavelength,
)

nearplane_fft2 = self._fft2(
nearplane,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

farplane = self._ifft2(
nearplane_fft2 * propagator,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

return farplane

def adj(
self,
farplane: npt.NDArray[np.csingle],
overwrite: bool = False,
**kwargs,
) -> npt.NDArray[np.csingle]:
"""backward (anti-parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(farplane.shape[-2], farplane.shape[-1]),
self.pixel_size,
self.distance,
self.wavelength,
)

farplane_fft2 = self._fft2(
farplane,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

nearplane = self._ifft2(
# FIXME: IS IT OK TO ALWAYS TAKE CONJ? OR SHOULD WE DO THIS ONCE AND REUSE?
farplane_fft2 * self.xp.conj(propagator),
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

return nearplane

def _create_fresnel_spectrum_propagator(
self,
N: typing.Tuple[int, int],
pixel_size: float = 1.0,
distance: float = 1.0,
wavelength: float = 1.0,
) -> np.ndarray:
# FIXME: Check that dimension ordering is consistent
rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2
cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2

x = -1j * self.xp.pi * wavelength * distance
rr2 = self.xp.exp(x * rr2[..., None] / (pixel_size**2))
cc2 = self.xp.exp(x * cc2[..., None] / (pixel_size**2))

fresnel_spectrum_propagator = self.xp.ndarray.astype(
self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)),
dtype=self.xp.csingle,
)

return fresnel_spectrum_propagator
25 changes: 23 additions & 2 deletions tests/operators/test_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import unittest

import numpy as np
from tike.operators import Propagation
import tike.precision
from tike.operators import Propagation, FresnelSpectProp

from .util import random_complex, OperatorTests

Expand Down Expand Up @@ -37,5 +36,27 @@ def setUp(self, nwaves=13, probe_shape=127):
print(self.operator)


class TestFresnelSpectrumPropagation(unittest.TestCase, OperatorTests):
"""Test the FresnelSpectProp operator."""

def setUp(self, nwaves=13, probe_shape=127):
"""Load a dataset for reconstruction."""
self.operator = FresnelSpectProp(
nwaves=nwaves,
detector_shape=probe_shape,
probe_shape=probe_shape,
)
self.operator.__enter__()
self.xp = self.operator.xp
np.random.seed(0)
self.m = self.xp.asarray(
random_complex(nwaves, probe_shape, probe_shape))
self.m_name = 'nearplane'
self.d = self.xp.asarray(
random_complex(nwaves, probe_shape, probe_shape))
self.d_name = 'farplane'
self.kwargs = {}
print(self.operator)

if __name__ == '__main__':
unittest.main()

0 comments on commit 30e2a80

Please sign in to comment.