Skip to content

Commit

Permalink
Merge pull request #331 from a4894z/main
Browse files Browse the repository at this point in the history
probe options and object options now include relevant Fresnel spectrum propagation multislice parameters and are used in the fwd and adj operators
  • Loading branch information
a4894z authored Aug 10, 2024
2 parents 519a331 + 5c403b3 commit 3def223
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 16 deletions.
85 changes: 73 additions & 12 deletions src/tike/operators/cupy/fresnelspectprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def __init__(
self,
norm: str = "ortho",
pixel_size: float = 1e-7,
probe_FOV: typing.Tuple[float, float] = ( 1e-6, 1e-6 ),
distance: float = 1e-6,
wavelength: float = 1e-9,
**kwargs,
):
self.norm = norm
self.pixel_size = pixel_size
self.probe_FOV = probe_FOV
self.distance = distance
self.wavelength = wavelength

Expand All @@ -56,7 +58,7 @@ def fwd(
"""forward (parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(nearplane.shape[-2], nearplane.shape[-1]),
self.pixel_size,
self.probe_FOV,
self.distance,
self.wavelength,
)
Expand Down Expand Up @@ -86,7 +88,7 @@ def adj(
"""backward (anti-parallel to beam direction) Fresnel spectrum propagtion operator"""
propagator = self._create_fresnel_spectrum_propagator(
(farplane.shape[-2], farplane.shape[-1]),
self.pixel_size,
self.probe_FOV,
self.distance,
self.wavelength,
)
Expand All @@ -111,21 +113,80 @@ def adj(
def _create_fresnel_spectrum_propagator(
self,
N: typing.Tuple[int, int],
pixel_size: float = 1.0,
probe_FOV: float = 1.0,
distance: float = 1.0,
wavelength: float = 1.0,
) -> np.ndarray:

# FIXME: Check that dimension ordering is consistent
rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2
cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2

x = -1j * self.xp.pi * wavelength * distance
rr2 = self.xp.exp(x * rr2[..., None] / (pixel_size**2))
cc2 = self.xp.exp(x * cc2[..., None] / (pixel_size**2))
xgrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[1] ), ( 0.5 * N[1] - 1 ), num = N[1] )) / N[1]
ygrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[0] ), ( 0.5 * N[0] - 1 ), num = N[0] )) / N[0]

fresnel_spectrum_propagator = self.xp.ndarray.astype(
self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)),
dtype=self.xp.csingle,
)
kx = 2 * self.xp.pi * N[0] * xgrid / probe_FOV[ 0 ]
ky = 2 * self.xp.pi * N[1] * ygrid / probe_FOV[ 1 ]

Kx, Ky = self.xp.meshgrid(kx, ky, indexing='xy')

fresnel_spectrum_propagator = self.xp.exp( 1j * distance * self.xp.sqrt( ( 2 * self.xp.pi / wavelength ) ** 2 - Kx ** 2 - Ky ** 2 ))

fresnel_spectrum_propagator = self.xp.ndarray.astype( self.xp.fft.fftshift( fresnel_spectrum_propagator ), dtype = self.xp.csingle )

return fresnel_spectrum_propagator


'''
def create_fresnel_spectrum_propagator(
N: np.ndarray, # probe dimensions ( WIDE, HIGH )
beam_energy: float, # x-ray energy ( eV )
delta_z: float, # meters
detector_dist: float, # meters
detector_pixel_width: float ) -> np.ndarray: # meters
wavelength = ( 1.23984193e-9 / ( beam_energy / 1e3 )) # x-ray energy ( eV ), wavelength ( meters )
xgrid = ( 0.5 + np.linspace( ( -0.5 * N[1] ), ( 0.5 * N[1] - 1 ), num = N[1] )) / N[1]
ygrid = ( 0.5 + np.linspace( ( -0.5 * N[0] ), ( 0.5 * N[0] - 1 ), num = N[0] )) / N[0]
x = wavelength * detector_dist / detector_pixel_width
#z_obj_L = np.asarray( [ x, x ], dtype = np.float32 )
z_obj_L = np.asarray( [ x, x ], dtype = np.float64 )
kx = 2 * np.pi * N[0] * xgrid / z_obj_L[ 0 ]
ky = 2 * np.pi * N[1] * ygrid / z_obj_L[ 1 ]
Kx, Ky = np.meshgrid(kx, ky, indexing='xy')
fresnel_spectrum_propagator = np.exp( 1j * delta_z * np.sqrt( ( 2 * np.pi / wavelength ) ** 2 - Kx ** 2 - Ky ** 2 ))
fresnel_spectrum_propagator = np.ndarray.astype( np.fft.fftshift( fresnel_spectrum_propagator ), dtype = np.csingle )
return fresnel_spectrum_propagator
'''


# def _create_fresnel_spectrum_propagator(
# self,
# N: typing.Tuple[int, int],
# pixel_size: float = 1.0,
# distance: float = 1.0,
# wavelength: float = 1.0,
# ) -> np.ndarray:
# # FIXME: Check that dimension ordering is consistent
# rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2
# cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2

# x = -1j * self.xp.pi * wavelength * distance
# rr2 = self.xp.exp(x * rr2[..., None] / (pixel_size**2))
# cc2 = self.xp.exp(x * cc2[..., None] / (pixel_size**2))

# fresnel_spectrum_propagator = self.xp.ndarray.astype(
# self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)),
# dtype=self.xp.csingle,
# )

# return fresnel_spectrum_propagator



14 changes: 12 additions & 2 deletions src/tike/operators/cupy/multislice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ def __init__(
self,
detector_shape: int,
probe_shape: int,
probe_wavelength: float,
probe_FOV_lengths: typing.Tuple[float, float],
nz: int,
n: int,
multislice_propagation_distance: float,
propagation: typing.Type[Propagation] = FresnelSpectProp,
diffraction: typing.Type[Convolution] = Convolution,
norm: str = "ortho",
Expand All @@ -37,7 +40,11 @@ def __init__(
**kwargs,
)
self.propagation = propagation(
detector_shape=detector_shape,
norm = norm,
probe_shape=probe_shape,
wavelength=probe_wavelength,
probe_FOV=probe_FOV_lengths,
distance=multislice_propagation_distance,
**kwargs,
)

Expand All @@ -46,7 +53,10 @@ def __init__(
self.detector_shape = detector_shape
self.nz = nz
self.n = n

self.probe_wavelength = probe_wavelength
self.probe_FOV_lengths = probe_FOV_lengths
self.multislice_propagation_distance = multislice_propagation_distance

def __enter__(self):
self.propagation.__enter__()
self.diffraction.__enter__()
Expand Down
11 changes: 10 additions & 1 deletion src/tike/operators/cupy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def __init__(
self,
detector_shape: int,
probe_shape: int,
probe_wavelength: float,
probe_FOV_lengths: typing.Tuple[float, float],
nz: int,
n: int,
multislice_propagation_distance: float,
propagation: typing.Type[Propagation] = Propagation,
diffraction: typing.Type[Multislice] = Multislice,
norm: str = 'ortho',
Expand All @@ -82,17 +85,23 @@ def __init__(
)
self.diffraction = diffraction(
probe_shape=probe_shape,
probe_wavelength=probe_wavelength,
probe_FOV_lengths=probe_FOV_lengths,
detector_shape=detector_shape,
nz=nz,
n=n,
multislice_propagation_distance=multislice_propagation_distance,
**kwargs,
)
# TODO: Replace these with @property functions
self.probe_shape = probe_shape
self.detector_shape = detector_shape
self.nz = nz
self.n = n

self.probe_wavelength = probe_wavelength
self.probe_FOV_lengths = probe_FOV_lengths
self.multislice_propagation_distance = multislice_propagation_distance

def __enter__(self):
self.propagation.__enter__()
self.diffraction.__enter__()
Expand Down
12 changes: 11 additions & 1 deletion src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ class ObjectOptions:

clip_magnitude: bool = False
"""Whether to force the object magnitude to remain <= 1."""


multislice_propagation_distance: float = 1.0e-9
""" If we're doing multislice ptychography, then this will be
the slice to slice distance along the probing wavefield propagation
direction. Units are meters.
"""

def copy_to_device(self) -> ObjectOptions:
"""Copy to the current GPU memory."""
options = ObjectOptions(
Expand All @@ -84,6 +90,7 @@ def copy_to_device(self) -> ObjectOptions:
vdecay=self.vdecay,
mdecay=self.mdecay,
clip_magnitude=self.clip_magnitude,
multislice_propagation_distance=self.multislice_propagation_distance,
)
options.update_mnorm = copy.copy(self.update_mnorm)
if self.v is not None:
Expand Down Expand Up @@ -117,6 +124,7 @@ def copy_to_host(self) -> ObjectOptions:
vdecay=self.vdecay,
mdecay=self.mdecay,
clip_magnitude=self.clip_magnitude,
multislice_propagation_distance=self.multislice_propagation_distance,
)
options.update_mnorm = copy.copy(self.update_mnorm)
if self.v is not None:
Expand All @@ -137,6 +145,7 @@ def resample(self, factor: float, interp) -> ObjectOptions:
vdecay=self.vdecay,
mdecay=self.mdecay,
clip_magnitude=self.clip_magnitude,
multislice_propagation_distance=self.multislice_propagation_distance,
)
options.update_mnorm = copy.copy(self.update_mnorm)
return options
Expand Down Expand Up @@ -172,6 +181,7 @@ def join(
vdecay=x[0].vdecay,
mdecay=x[0].mdecay,
clip_magnitude=x[0].clip_magnitude,
multislice_propagation_distance=x[0].multislice_propagation_distance,
)
options.update_mnorm = copy.copy(x[0].update_mnorm)
if x[0].v is not None:
Expand Down
17 changes: 17 additions & 0 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ class ProbeOptions:
it will default to the average of the measurement intensity scaling.
"""

probe_wavelength: float = np.nan
""" The wavelength (meters) of the probing wavefield;
we assume a monochomatic (single wavelength) probe.
"""

probe_FOV_lengths: typing.Tuple[float, float] = ( np.nan, np.nan )
""" The transverse field of view (FOV) for the probe in
units of length (meters). The first element is vertical FOV,
the second element is horizontal FOV.
"""

force_orthogonality: bool = False
"""Forces probes to be orthogonal each iteration."""

Expand Down Expand Up @@ -164,6 +175,8 @@ def copy_to_device(self) -> ProbeOptions:
update_period=self.update_period,
init_rescale_from_measurements=self.init_rescale_from_measurements,
probe_photons=self.probe_photons,
probe_wavelength=self.probe_wavelength,
probe_FOV_lengths=self.probe_FOV_lengths,
force_orthogonality=self.force_orthogonality,
force_centered_intensity=self.force_centered_intensity,
force_sparsity=self.force_sparsity,
Expand Down Expand Up @@ -206,6 +219,8 @@ def copy_to_host(self) -> ProbeOptions:
update_period=self.update_period,
init_rescale_from_measurements=self.init_rescale_from_measurements,
probe_photons=self.probe_photons,
probe_wavelength=self.probe_wavelength,
probe_FOV_lengths=self.probe_FOV_lengths,
force_orthogonality=self.force_orthogonality,
force_centered_intensity=self.force_centered_intensity,
force_sparsity=self.force_sparsity,
Expand Down Expand Up @@ -235,6 +250,8 @@ def resample(self, factor: float, interp) -> ProbeOptions:
update_period=self.update_period,
init_rescale_from_measurements=self.init_rescale_from_measurements,
probe_photons=self.probe_photons,
probe_wavelength=self.probe_wavelength,
probe_FOV_lengths=self.probe_FOV_lengths,
force_orthogonality=self.force_orthogonality,
force_centered_intensity=self.force_centered_intensity,
force_sparsity=self.force_sparsity,
Expand Down
3 changes: 3 additions & 0 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ def __init__(
nz=parameters.psi.shape[-2],
n=parameters.psi.shape[-1],
norm=parameters.exitwave_options.propagation_normalization,
probe_wavelength=parameters.probe_options.probe_wavelength,
probe_FOV_lengths=parameters.probe_options.probe_FOV_lengths,
multislice_propagation_distance=parameters.object_options.multislice_propagation_distance,
)
self.comm = tike.communicators.Comm(num_gpu, mpi)

Expand Down
3 changes: 3 additions & 0 deletions tests/operators/test_multislice.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ def setUp(self, depth=7, pw=15, nscan=27):
self.operator = Multislice(
nscan=self.scan_shape[-2],
probe_shape=self.probe_shape[-1],
probe_wavelength = 1e-10,
probe_FOV_lengths = (1e-5, 1e-5),
detector_shape=self.detector_shape[-1],
nz=self.original_shape[-2],
n=self.original_shape[-1],
multislice_propagation_distance = 1e-8,
)
self.operator.__enter__()
self.xp = self.operator.xp
Expand Down

0 comments on commit 3def223

Please sign in to comment.