From a4b0f09d8bff83621fd1210bc7998463d8a3c0c0 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 31 Jul 2024 13:29:49 -0500 Subject: [PATCH 1/2] 1) use of foldslice fresnel propagator, 2) pass parameters for fresnel spectrum propagation to the function so that the defaults aren't used --- src/tike/operators/cupy/fresnelspectprop.py | 85 ++++++++++++++++++--- src/tike/operators/cupy/multislice.py | 14 +++- src/tike/operators/cupy/ptycho.py | 11 ++- src/tike/ptycho/object.py | 12 ++- src/tike/ptycho/probe.py | 17 +++++ src/tike/ptycho/ptycho.py | 3 + 6 files changed, 126 insertions(+), 16 deletions(-) diff --git a/src/tike/operators/cupy/fresnelspectprop.py b/src/tike/operators/cupy/fresnelspectprop.py index abeb9f85..21a67c36 100644 --- a/src/tike/operators/cupy/fresnelspectprop.py +++ b/src/tike/operators/cupy/fresnelspectprop.py @@ -38,12 +38,14 @@ def __init__( self, norm: str = "ortho", pixel_size: float = 1e-7, + probe_FOV: typing.Tuple[float, float] = ( 1e-6, 1e-6 ), distance: float = 1e-6, wavelength: float = 1e-9, **kwargs, ): self.norm = norm self.pixel_size = pixel_size + self.probe_FOV = probe_FOV self.distance = distance self.wavelength = wavelength @@ -56,7 +58,7 @@ def fwd( """forward (parallel to beam direction) Fresnel spectrum propagtion operator""" propagator = self._create_fresnel_spectrum_propagator( (nearplane.shape[-2], nearplane.shape[-1]), - self.pixel_size, + self.probe_FOV, self.distance, self.wavelength, ) @@ -86,7 +88,7 @@ def adj( """backward (anti-parallel to beam direction) Fresnel spectrum propagtion operator""" propagator = self._create_fresnel_spectrum_propagator( (farplane.shape[-2], farplane.shape[-1]), - self.pixel_size, + self.probe_FOV, self.distance, self.wavelength, ) @@ -111,21 +113,80 @@ def adj( def _create_fresnel_spectrum_propagator( self, N: typing.Tuple[int, int], - pixel_size: float = 1.0, + probe_FOV: float = 1.0, distance: float = 1.0, wavelength: float = 1.0, ) -> np.ndarray: + # FIXME: Check that dimension ordering is consistent - rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2 - cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2 - x = -1j * self.xp.pi * wavelength * distance - rr2 = self.xp.exp(x * rr2[..., None] / (pixel_size**2)) - cc2 = self.xp.exp(x * cc2[..., None] / (pixel_size**2)) + xgrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[1] ), ( 0.5 * N[1] - 1 ), num = N[1] )) / N[1] + ygrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[0] ), ( 0.5 * N[0] - 1 ), num = N[0] )) / N[0] - fresnel_spectrum_propagator = self.xp.ndarray.astype( - self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)), - dtype=self.xp.csingle, - ) + kx = 2 * self.xp.pi * N[0] * xgrid / probe_FOV[ 0 ] + ky = 2 * self.xp.pi * N[1] * ygrid / probe_FOV[ 1 ] + + Kx, Ky = self.xp.meshgrid(kx, ky, indexing='xy') + + fresnel_spectrum_propagator = self.xp.exp( 1j * distance * self.xp.sqrt( ( 2 * self.xp.pi / wavelength ) ** 2 - Kx ** 2 - Ky ** 2 )) + + fresnel_spectrum_propagator = self.xp.ndarray.astype( self.xp.fft.fftshift( fresnel_spectrum_propagator ), dtype = self.xp.csingle ) return fresnel_spectrum_propagator + + +''' +def create_fresnel_spectrum_propagator( + N: np.ndarray, # probe dimensions ( WIDE, HIGH ) + beam_energy: float, # x-ray energy ( eV ) + delta_z: float, # meters + detector_dist: float, # meters + detector_pixel_width: float ) -> np.ndarray: # meters + + wavelength = ( 1.23984193e-9 / ( beam_energy / 1e3 )) # x-ray energy ( eV ), wavelength ( meters ) + + xgrid = ( 0.5 + np.linspace( ( -0.5 * N[1] ), ( 0.5 * N[1] - 1 ), num = N[1] )) / N[1] + ygrid = ( 0.5 + np.linspace( ( -0.5 * N[0] ), ( 0.5 * N[0] - 1 ), num = N[0] )) / N[0] + + x = wavelength * detector_dist / detector_pixel_width + #z_obj_L = np.asarray( [ x, x ], dtype = np.float32 ) + z_obj_L = np.asarray( [ x, x ], dtype = np.float64 ) + + kx = 2 * np.pi * N[0] * xgrid / z_obj_L[ 0 ] + ky = 2 * np.pi * N[1] * ygrid / z_obj_L[ 1 ] + + Kx, Ky = np.meshgrid(kx, ky, indexing='xy') + + fresnel_spectrum_propagator = np.exp( 1j * delta_z * np.sqrt( ( 2 * np.pi / wavelength ) ** 2 - Kx ** 2 - Ky ** 2 )) + + fresnel_spectrum_propagator = np.ndarray.astype( np.fft.fftshift( fresnel_spectrum_propagator ), dtype = np.csingle ) + + return fresnel_spectrum_propagator + +''' + + + # def _create_fresnel_spectrum_propagator( + # self, + # N: typing.Tuple[int, int], + # pixel_size: float = 1.0, + # distance: float = 1.0, + # wavelength: float = 1.0, + # ) -> np.ndarray: + # # FIXME: Check that dimension ordering is consistent + # rr2 = self.xp.linspace(-0.5 * N[1], 0.5 * N[1] - 1, num=N[1]) ** 2 + # cc2 = self.xp.linspace(-0.5 * N[0], 0.5 * N[0] - 1, num=N[0]) ** 2 + + # x = -1j * self.xp.pi * wavelength * distance + # rr2 = self.xp.exp(x * rr2[..., None] / (pixel_size**2)) + # cc2 = self.xp.exp(x * cc2[..., None] / (pixel_size**2)) + + # fresnel_spectrum_propagator = self.xp.ndarray.astype( + # self.xp.fft.fftshift(self.xp.outer(self.xp.transpose(rr2), cc2)), + # dtype=self.xp.csingle, + # ) + + # return fresnel_spectrum_propagator + + + diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 8da1b6d6..f6775863 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -21,8 +21,11 @@ def __init__( self, detector_shape: int, probe_shape: int, + probe_wavelength: float, + probe_FOV_lengths: typing.Tuple[float, float], nz: int, n: int, + multislice_propagation_distance: float, propagation: typing.Type[Propagation] = FresnelSpectProp, diffraction: typing.Type[Convolution] = Convolution, norm: str = "ortho", @@ -37,7 +40,11 @@ def __init__( **kwargs, ) self.propagation = propagation( - detector_shape=detector_shape, + norm = norm, + probe_shape=probe_shape, + wavelength=probe_wavelength, + probe_FOV=probe_FOV_lengths, + distance=multislice_propagation_distance, **kwargs, ) @@ -46,7 +53,10 @@ def __init__( self.detector_shape = detector_shape self.nz = nz self.n = n - + self.probe_wavelength = probe_wavelength + self.probe_FOV_lengths = probe_FOV_lengths + self.multislice_propagation_distance = multislice_propagation_distance + def __enter__(self): self.propagation.__enter__() self.diffraction.__enter__() diff --git a/src/tike/operators/cupy/ptycho.py b/src/tike/operators/cupy/ptycho.py index 65c7bf77..befd99f4 100644 --- a/src/tike/operators/cupy/ptycho.py +++ b/src/tike/operators/cupy/ptycho.py @@ -67,8 +67,11 @@ def __init__( self, detector_shape: int, probe_shape: int, + probe_wavelength: float, + probe_FOV_lengths: typing.Tuple[float, float], nz: int, n: int, + multislice_propagation_distance: float, propagation: typing.Type[Propagation] = Propagation, diffraction: typing.Type[Multislice] = Multislice, norm: str = 'ortho', @@ -82,9 +85,12 @@ def __init__( ) self.diffraction = diffraction( probe_shape=probe_shape, + probe_wavelength=probe_wavelength, + probe_FOV_lengths=probe_FOV_lengths, detector_shape=detector_shape, nz=nz, n=n, + multislice_propagation_distance=multislice_propagation_distance, **kwargs, ) # TODO: Replace these with @property functions @@ -92,7 +98,10 @@ def __init__( self.detector_shape = detector_shape self.nz = nz self.n = n - + self.probe_wavelength = probe_wavelength + self.probe_FOV_lengths = probe_FOV_lengths + self.multislice_propagation_distance = multislice_propagation_distance + def __enter__(self): self.propagation.__enter__() self.diffraction.__enter__() diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 952b9d2c..ce529efe 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -73,7 +73,13 @@ class ObjectOptions: clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" - + + multislice_propagation_distance: float = 1.0e-9 + """ If we're doing multislice ptychography, then this will be + the slice to slice distance along the probing wavefield propagation + direction. Units are meters. + """ + def copy_to_device(self) -> ObjectOptions: """Copy to the current GPU memory.""" options = ObjectOptions( @@ -84,6 +90,7 @@ def copy_to_device(self) -> ObjectOptions: vdecay=self.vdecay, mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, + multislice_propagation_distance=self.multislice_propagation_distance, ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: @@ -117,6 +124,7 @@ def copy_to_host(self) -> ObjectOptions: vdecay=self.vdecay, mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, + multislice_propagation_distance=self.multislice_propagation_distance, ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: @@ -137,6 +145,7 @@ def resample(self, factor: float, interp) -> ObjectOptions: vdecay=self.vdecay, mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, + multislice_propagation_distance=self.multislice_propagation_distance, ) options.update_mnorm = copy.copy(self.update_mnorm) return options @@ -172,6 +181,7 @@ def join( vdecay=x[0].vdecay, mdecay=x[0].mdecay, clip_magnitude=x[0].clip_magnitude, + multislice_propagation_distance=x[0].multislice_propagation_distance, ) options.update_mnorm = copy.copy(x[0].update_mnorm) if x[0].v is not None: diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 4a365eaf..e51b0ca1 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -72,6 +72,17 @@ class ProbeOptions: it will default to the average of the measurement intensity scaling. """ + probe_wavelength: float = np.nan + """ The wavelength (meters) of the probing wavefield; + we assume a monochomatic (single wavelength) probe. + """ + + probe_FOV_lengths: typing.Tuple[float, float] = ( np.nan, np.nan ) + """ The transverse field of view (FOV) for the probe in + units of length (meters). The first element is vertical FOV, + the second element is horizontal FOV. + """ + force_orthogonality: bool = False """Forces probes to be orthogonal each iteration.""" @@ -164,6 +175,8 @@ def copy_to_device(self) -> ProbeOptions: update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, probe_photons=self.probe_photons, + probe_wavelength=self.probe_wavelength, + probe_FOV_lengths=self.probe_FOV_lengths, force_orthogonality=self.force_orthogonality, force_centered_intensity=self.force_centered_intensity, force_sparsity=self.force_sparsity, @@ -206,6 +219,8 @@ def copy_to_host(self) -> ProbeOptions: update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, probe_photons=self.probe_photons, + probe_wavelength=self.probe_wavelength, + probe_FOV_lengths=self.probe_FOV_lengths, force_orthogonality=self.force_orthogonality, force_centered_intensity=self.force_centered_intensity, force_sparsity=self.force_sparsity, @@ -235,6 +250,8 @@ def resample(self, factor: float, interp) -> ProbeOptions: update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, probe_photons=self.probe_photons, + probe_wavelength=self.probe_wavelength, + probe_FOV_lengths=self.probe_FOV_lengths, force_orthogonality=self.force_orthogonality, force_centered_intensity=self.force_centered_intensity, force_sparsity=self.force_sparsity, diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 409ea9c5..cf419266 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -349,6 +349,9 @@ def __init__( nz=parameters.psi.shape[-2], n=parameters.psi.shape[-1], norm=parameters.exitwave_options.propagation_normalization, + probe_wavelength=parameters.probe_options.probe_wavelength, + probe_FOV_lengths=parameters.probe_options.probe_FOV_lengths, + multislice_propagation_distance=parameters.object_options.multislice_propagation_distance, ) self.comm = tike.communicators.Comm(num_gpu, mpi) From 5c403b3df9129258574ea7793162c8f81090e5d0 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 31 Jul 2024 13:45:27 -0500 Subject: [PATCH 2/2] fixed multislice unit test where we need extra parameters (probe FOV, multislice prop distance, wavelength) --- tests/operators/test_multislice.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/operators/test_multislice.py b/tests/operators/test_multislice.py index f956737b..fb9d73b5 100644 --- a/tests/operators/test_multislice.py +++ b/tests/operators/test_multislice.py @@ -39,9 +39,12 @@ def setUp(self, depth=7, pw=15, nscan=27): self.operator = Multislice( nscan=self.scan_shape[-2], probe_shape=self.probe_shape[-1], + probe_wavelength = 1e-10, + probe_FOV_lengths = (1e-5, 1e-5), detector_shape=self.detector_shape[-1], nz=self.original_shape[-2], n=self.original_shape[-1], + multislice_propagation_distance = 1e-8, ) self.operator.__enter__() self.xp = self.operator.xp