Skip to content

Commit

Permalink
NEW: Add Fresnel Propagation Operator
Browse files Browse the repository at this point in the history
Co-authored-by: Ashish Tripathi <[email protected]>
  • Loading branch information
2 people authored and carterbox committed May 29, 2024
1 parent 54895ec commit 4802e50
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/tike/operators/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .pad import *
from .patch import *
from .propagation import *
from .fresnelspectprop import *
from .ptycho import *
from .rotate import *
from .shift import *
143 changes: 143 additions & 0 deletions src/tike/operators/cupy/fresnelspectprop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Defines a free-space propagation operator based on the CuPy FFT module."""

__author__ = "Ashish Tripathi"
__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC."

import typing

import numpy.typing as npt
import numpy as np

from .cache import CachedFFT
from .operator import Operator


class FresnelSpectProp(CachedFFT, Operator):
"""Fresnel spectrum propagation (short range) using CuPy.
Take an (..., N, N) array and apply the Fourier transform to the last two
dimensions.
Attributes
----------
pixel_size : float
The realspace size of a pixel in meters
delta_z : float
The realspace propagation distance in meters
wavelength : float
The wavelength of the light in meters
Parameters
----------
nearplane: (..., detector_shape, detector_shape) complex64
The wavefronts before propagation.
farplane: (..., detector_shape, detector_shape) complex64
The wavefronts after propagation.
"""

def __init__(
self,
norm: str = "ortho",
pixel_size: float = 1.0,
delta_z: float = 1.0,
wavelength: float = 1.0,
**kwargs,
):
self.norm = norm
self.pixel_size = pixel_size
self.delta_z = delta_z
self.wavelength = wavelength

def fwd(
self,
nearplane: npt.NDArray[np.csingle],
overwrite: bool = False,
**kwargs,
) -> npt.NDArray[np.csingle]:
"""forward (parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(nearplane.shape[-2], nearplane.shape[-1]),
self.pixel_size,
self.delta_z,
self.wavelength,
)

nearplane_fft2 = self._fft2(
nearplane,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

farplane = self._ifft2(
nearplane_fft2 * propagator,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

return farplane

def adj(
self,
farplane: npt.NDArray[np.csingle],
overwrite: bool = False,
**kwargs,
) -> npt.NDArray[np.csingle]:
"""backward (anti-parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(farplane.shape[-2], farplane.shape[-1]),
self.pixel_size,
self.delta_z,
self.wavelength,
)

farplane_fft2 = self._fft2(
farplane,
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

nearplane = self._ifft2(
farplane_fft2
* self.xp.conj(
propagator,
), # IS IT OK TO ALWAYS TAKE CONJ? OR SHOULD WE DO THIS ONCE AND REUSE?
norm=self.norm,
axes=(-2, -1),
overwrite_x=overwrite,
)

return nearplane

def _create_fresnel_spectrum_propagator(
self,
N: typing.Tuple[int, int],
pixel_size: float = 1.0,
delta_z: float = 1.0,
wavelength: float = 1.0,
) -> np.ndarray:
"""
Parameters
----------
pixel_size : real width of pixel in meters
delta_z: propagation distance in meters
wavelength: wavelength of light in meters
"""
# FIXME: Check that dimension ordering is consistent
rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2
cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2

prb_FOV = self.xp.asarray([pixel_size, pixel_size], dtype=self.xp.float32)

x = -1j * self.xp.pi * wavelength * delta_z
rr2 = self.xp.exp(x * rr2[..., None] / (prb_FOV[0] ** 2))
cc2 = self.xp.exp(x * cc2[..., None] / (prb_FOV[1] ** 2))

fresnel_spectrum_propagator = self.xp.ndarray.astype(
self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)),
dtype=self.xp.csingle,
)

return fresnel_spectrum_propagator
25 changes: 23 additions & 2 deletions tests/operators/test_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import unittest

import numpy as np
from tike.operators import Propagation
import tike.precision
from tike.operators import Propagation, FresnelSpectProp

from .util import random_complex, OperatorTests

Expand Down Expand Up @@ -37,5 +36,27 @@ def setUp(self, nwaves=13, probe_shape=127):
print(self.operator)


class TestFresnelSpectrumPropagation(unittest.TestCase, OperatorTests):
"""Test the FresnelSpectProp operator."""

def setUp(self, nwaves=13, probe_shape=127):
"""Load a dataset for reconstruction."""
self.operator = FresnelSpectProp(
nwaves=nwaves,
detector_shape=probe_shape,
probe_shape=probe_shape,
)
self.operator.__enter__()
self.xp = self.operator.xp
np.random.seed(0)
self.m = self.xp.asarray(
random_complex(nwaves, probe_shape, probe_shape))
self.m_name = 'nearplane'
self.d = self.xp.asarray(
random_complex(nwaves, probe_shape, probe_shape))
self.d_name = 'farplane'
self.kwargs = {}
print(self.operator)

if __name__ == '__main__':
unittest.main()

0 comments on commit 4802e50

Please sign in to comment.