From 0efb3644690f70c9f6dc548f8210e7dde9c3ea36 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 24 May 2024 20:26:38 -0500 Subject: [PATCH] API: Add depth dimension to psi --- src/tike/operators/cupy/convolution.py | 70 ---------------------- src/tike/operators/cupy/multislice.py | 15 ++++- src/tike/operators/cupy/ptycho.py | 38 +++--------- src/tike/ptycho/ptycho.py | 2 - src/tike/ptycho/solvers/_preconditioner.py | 6 +- src/tike/ptycho/solvers/dm.py | 16 +++-- src/tike/ptycho/solvers/lstsq.py | 8 +-- src/tike/ptycho/solvers/options.py | 6 +- src/tike/ptycho/solvers/rpie.py | 6 +- tests/operators/test_convolution.py | 36 ----------- tests/operators/test_ptycho.py | 45 ++------------ tests/ptycho/io.py | 4 +- tests/ptycho/templates.py | 2 +- tests/ptycho/test_position.py | 2 +- tests/ptycho/test_ptycho.py | 2 +- 15 files changed, 51 insertions(+), 207 deletions(-) diff --git a/src/tike/operators/cupy/convolution.py b/src/tike/operators/cupy/convolution.py index 0a82bbb7..4e7a1c16 100644 --- a/src/tike/operators/cupy/convolution.py +++ b/src/tike/operators/cupy/convolution.py @@ -132,73 +132,3 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False): patches = patches.conj() patches *= nearplane[..., self.pad:self.end, self.pad:self.end] return patches - - def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False): - """Peform adj and adj_probe at the same time.""" - assert probe.shape[:-4] == scan.shape[:-2] - assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape) - assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2] - assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape, - scan.shape) - - patches = self.patch.fwd( - # Could be xp.empty if scan positions are all in bounds - patches=self.xp.zeros_like( - psi, - shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - self.probe_shape, self.probe_shape), - ), - images=psi, - positions=scan, - patch_width=self.probe_shape, - nrepeat=nearplane.shape[-3], - ) - patches = patches.reshape((*scan.shape[:-1], nearplane.shape[-3], - self.probe_shape, self.probe_shape)) - if rpie: - patches_amp = self.xp.sum( - patches * patches.conj(), - axis=-4, - keepdims=True, - ) - patches = patches.conj() - patches *= nearplane[..., self.pad:self.end, self.pad:self.end] - - if not overwrite: - nearplane = nearplane.copy() - nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj() - if rpie: - probe_amp = probe * probe.conj() - probe_amp = probe_amp.reshape( - (*scan.shape[:-2], -1, *nearplane.shape[-2:]) - # (..., nscan * nprobe, probe_shape, probe_shape) - # (..., nprobe, probe_shape, probe_shape) - ) - probe_amp = self.patch.adj( - patches=probe_amp, - images=self.xp.zeros_like( - psi, - shape=(*scan.shape[:-2], self.nz, self.n), - ), - positions=scan, - patch_width=self.probe_shape, - nrepeat=nearplane.shape[-3], - ) - - apsi = self.patch.adj( - patches=nearplane.reshape( - (*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - *nearplane.shape[-2:])), - images=self.xp.zeros_like( - psi, - shape=(*scan.shape[:-2], self.nz, self.n), - ), - positions=scan, - patch_width=self.probe_shape, - nrepeat=nearplane.shape[-3], - ) - - if rpie: - return apsi, patches, patches_amp, probe_amp - else: - return apsi, patches diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 535d3fdd..a7156806 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -97,8 +97,9 @@ def fwd( **kwargs, ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" + assert (psi.shape[0] == 1 and psi.ndim == 3) return self.diffraction.fwd( - psi=psi, + psi=psi[0], scan=scan, probe=probe, ) @@ -113,12 +114,22 @@ def adj( **kwargs, ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" + assert psi is None or (psi.shape[0] == 1 and psi.ndim == 3) return self.diffraction.adj( nearplane=nearplane, probe=probe, scan=scan, overwrite=True, - psi=psi, + psi=psi[0] if psi is not None else None, + )[None, ...] + + def adj_probe(self, nearplane, scan, psi, overwrite=False): + assert (psi.shape[0] == 1 and psi.ndim == 3) + return self.diffraction.adj_probe( + nearplane=nearplane, + scan=scan, + psi=psi[0], + overwrite=overwrite, ) @property diff --git a/src/tike/operators/cupy/ptycho.py b/src/tike/operators/cupy/ptycho.py index ee330d84..01d004f3 100644 --- a/src/tike/operators/cupy/ptycho.py +++ b/src/tike/operators/cupy/ptycho.py @@ -35,30 +35,31 @@ class Ptycho(Operator): ---------- detector_shape : int The pixel width and height of the (square) detector grid. - nz, n : int - The pixel width and height of the reconstructed grid. + d, nz, n : int + The pixel depth, width, and height of the reconstructed grid. probe_shape : int The pixel width and height of the (square) probe illumination. propagation : :py:class:`Operator` The wave propagation operator being used. diffraction : :py:class:`Operator` The object probe interaction operator being used. - data : (..., FRAME, WIDE, HIGH) float32 + data : (FRAME, WIDE, HIGH) float32 The intensity (square of the absolute value) of the propagated wavefront; i.e. what the detector records. - farplane: (..., POSI, 1, SHARED, detector_shape, detector_shape) complex64 + farplane: (POSI, 1, SHARED, detector_shape, detector_shape) complex64 The wavefronts hitting the detector respectively. - probe : {(..., 1, 1, SHARED, WIDE, HIGH), (..., POSI, 1, SHARED, WIDE, HIGH)} complex64 + probe : {(1, 1, SHARED, WIDE, HIGH), (POSI, 1, SHARED, WIDE, HIGH)} complex64 The complex illumination function. - psi : (..., WIDE, HIGH) complex64 + psi : (DEPTH, WIDE, HIGH) complex64 The wavefront modulation coefficients of the object. - scan : (..., POSI, 2) float32 + scan : (POSI, 2) float32 Coordinates of the minimum corner of the probe grid for each measurement in the coordinate system of psi. Coordinate order consistent with WIDE, HIGH order. .. versionchanged:: 0.25.0 Removed the model and ntheta parameters. + .. versionchanged:: 0.26.0 Added depth dimension to psi array """ @@ -246,26 +247,3 @@ def grad_probe( axis=0, keepdims=True, ) - - def adj_all( - self, - farplane: npt.NDArray[np.csingle], - probe: npt.NDArray[np.csingle], - scan: npt.NDArray[np.single], - psi: npt.NDArray[np.csingle], - overwrite: bool = False, - rpie: bool = False, - ) -> typing.Tuple[npt.NDArray, ...]: - """Please see help(Ptycho) for more info.""" - result = self.diffraction.adj_all( - nearplane=self.propagation.adj( - farplane, - overwrite=overwrite, - )[..., 0, :, :, :], - probe=probe[..., 0, :, :, :], - scan=scan, - overwrite=True, - psi=psi, - rpie=rpie, - ) - return (result[0], result[1][..., None, :, :, :], *result[2:]) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 267a3f92..29f755e5 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -97,7 +97,6 @@ def _compute_intensity( eigen_probe=None, fly=1, ): - leading = psi.shape[:-2] intensity = 0 for m in range(probe.shape[-3]): farplane = operator.fwd( @@ -111,7 +110,6 @@ def _compute_intensity( ) intensity += np.sum( np.square(np.abs(farplane)).reshape( - *leading, scan.shape[-2] // fly, fly, operator.detector_shape, diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index adb9a57b..8e4812ed 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -32,8 +32,9 @@ def _psi_preconditioner( operator: tike.operators.Ptycho, ) -> npt.NDArray: + # FIXME: Generated only one preconditioner for all slices psi_update_denominator = cp.zeros( - shape=psi.shape, + shape=psi.shape[-2:], dtype=psi.dtype, ) @@ -92,8 +93,9 @@ def make_certain_args_constant( ) -> None: nonlocal probe_update_denominator + # FIXME: Only use the first slice for the probe preconditioner patches = operator.diffraction.patch.fwd( - images=psi, + images=psi[0], positions=scan[lo:hi], patch_width=probe.shape[-1], ) diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index c4c43225..14ff3590 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -214,24 +214,22 @@ def keep_some_args_constant( nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end, pad:end] - patches = op.diffraction.patch.fwd( - patches=cp.zeros_like(nearplane[..., 0, 0, :, :]), - images=psi, - positions=scan[lo:hi], - )[..., None, None, :, :] - if object_options: - grad_psi = (cp.conj(varying_probe) * nearplane).reshape( (hi - lo) * probe.shape[-3], *probe.shape[-2:]) - psi_update_numerator = op.diffraction.patch.adj( + psi_update_numerator[0] = op.diffraction.patch.adj( patches=grad_psi, - images=psi_update_numerator, + images=psi_update_numerator[0], positions=scan[lo:hi], nrepeat=probe.shape[-3], ) if probe_options: + patches = op.diffraction.patch.fwd( + patches=cp.zeros_like(nearplane[..., 0, 0, :, :]), + images=psi[0], + positions=scan[lo:hi], + )[..., None, None, :, :] probe_update_numerator += cp.sum( cp.conj(patches) * nearplane, axis=-5, diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 49d09e1a..80df5a63 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -588,10 +588,10 @@ def keep_some_args_constant( # (24b) object_update_proj = cp.conj(bunique_probe[blo:bhi]) * bchi[blo:bhi] # (25b) Common object gradient. - object_upd_sum = op.diffraction.patch.adj( + object_upd_sum[0] = op.diffraction.patch.adj( patches=object_update_proj.reshape( len(scan[lo:hi]) * bchi.shape[-3], *bchi.shape[-2:]), - images=object_upd_sum, + images=object_upd_sum[0], positions=scan[lo:hi], nrepeat=bchi.shape[-3], ) @@ -601,7 +601,7 @@ def keep_some_args_constant( if recover_probe: bpatches[blo:bhi] = op.diffraction.patch.fwd( patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]), - images=psi, + images=psi[0], positions=scan[lo:hi], )[..., None, None, :, :] # (24a) @@ -724,7 +724,7 @@ def _precondition_nearplane_gradients( object_update_proj = op.diffraction.patch.fwd( patches=cp.zeros_like(nearplane[..., 0, 0, :, :]), - images=object_update_precond, + images=object_update_precond[0], positions=scan[lo:hi], ) dOP = object_update_proj[..., None, diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index c1bd46e7..d342ad6e 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -158,12 +158,12 @@ def __post_init__(self): "It should be (1, 1, S, W, H) " "where S >=1 is the number of probes, and " "W, H >= 1 are the square probe grid dimensions.") - if (self.psi.ndim != 2 or np.any( - np.asarray(self.psi.shape) <= np.asarray(self.probe.shape[-2:])) + if (self.psi.ndim != 3 or np.any( + np.asarray(self.psi.shape[-2:]) <= np.asarray(self.probe.shape[-2:])) ): raise ValueError( f"psi shape {self.psi.shape} is incorrect. " - "It should be (W, H) where W, H > probe.shape[-2:].") + "It should be (D, W, H) where W, H > probe.shape[-2:].") check_allowed_positions( self.scan, self.psi, diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index b5a83aa4..ec772a33 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -452,9 +452,9 @@ def keep_some_args_constant( if object_options: grad_psi = (cp.conj(unique_probe) * diff / probe.shape[-3]).reshape( scan[lo:hi].shape[0] * probe.shape[-3], *probe.shape[-2:]) - psi_update_numerator = op.diffraction.patch.adj( + psi_update_numerator[0] = op.diffraction.patch.adj( patches=grad_psi, - images=psi_update_numerator, + images=psi_update_numerator[0], positions=scan[lo:hi], nrepeat=probe.shape[-3], ) @@ -463,7 +463,7 @@ def keep_some_args_constant( patches = op.diffraction.patch.fwd( patches=cp.zeros_like(diff[..., 0, 0, :, :]), - images=psi, + images=psi[0], positions=scan[lo:hi], )[..., None, None, :, :] diff --git a/tests/operators/test_convolution.py b/tests/operators/test_convolution.py index b32caf0b..24bcacf8 100644 --- a/tests/operators/test_convolution.py +++ b/tests/operators/test_convolution.py @@ -97,41 +97,5 @@ def test_adj_probe_time(self): def test_scaled(self): pass - def test_adjoint_all(self): - """Check that the adjoint operator is correct.""" - d = self.operator.fwd( - **{ - self.m_name: self.m, - self.m1_name: self.m1 - }, - **self.kwargs2, - ) - assert d.shape == self.d.shape - m, m1 = self.operator.adj_all( - **{ - self.d_name: self.d, - self.m_name: self.m, - self.m1_name: self.m1 - }, - **self.kwargs2, - ) - assert m.shape == self.m.shape - assert m1.shape == self.m1.shape - a = tike.linalg.inner(d, self.d) - b = tike.linalg.inner(self.m, m) - c = tike.linalg.inner(self.m1, m1) - print() - print('< Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(), - a.imag.item())) - print('< d0, F*d0> = {:.6f}{:+.6f}j'.format(b.real.item(), - b.imag.item())) - print('< d1, F*d1> = {:.6f}{:+.6f}j'.format(c.real.item(), - c.imag.item())) - self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) - - if __name__ == '__main__': unittest.main() diff --git a/tests/operators/test_ptycho.py b/tests/operators/test_ptycho.py index 6adaa4d7..ac5eccf3 100644 --- a/tests/operators/test_ptycho.py +++ b/tests/operators/test_ptycho.py @@ -19,15 +19,14 @@ class TestPtycho(unittest.TestCase, OperatorTests): """Test the ptychography operator.""" - def setUp(self, ntheta=3, pw=15, nscan=27): + def setUp(self, depth=1, pw=15, nscan=27): """Load a dataset for reconstruction.""" self.nscan = nscan - self.ntheta = ntheta self.nprobe = 3 - self.probe_shape = (ntheta, nscan, 1, self.nprobe, pw, pw) + self.probe_shape = (nscan, 1, self.nprobe, pw, pw) self.detector_shape = (pw * 3, pw * 3) - self.original_shape = (ntheta, 128, 128) - self.scan_shape = (ntheta, nscan, 2) + self.original_shape = (depth, 128, 128) + self.scan_shape = (nscan, 2) print(Ptycho) np.random.seed(0) @@ -43,7 +42,6 @@ def setUp(self, ntheta=3, pw=15, nscan=27): detector_shape=self.detector_shape[-1], nz=self.original_shape[-2], n=self.original_shape[-1], - ntheta=self.ntheta, ) self.operator.__enter__() self.xp = self.operator.xp @@ -93,41 +91,6 @@ def test_adj_probe_time(self): def test_scaled(self): pass - def test_adjoint_all(self): - """Check that the adjoint operator is correct.""" - d = self.operator.fwd( - **{ - self.m_name: self.m, - self.m1_name: self.m1 - }, - **self.kwargs2, - ) - assert d.shape == self.d.shape - m, m1 = self.operator.adj_all( - **{ - self.d_name: self.d, - self.m_name: self.m, - self.m1_name: self.m1 - }, - **self.kwargs2, - ) - assert m.shape == self.m.shape - assert m1.shape == self.m1.shape - a = tike.linalg.inner(d, self.d) - b = tike.linalg.inner(self.m, m) - c = tike.linalg.inner(self.m1, m1) - print() - print('< Fm, m> = {:.5g}{:+.5g}j'.format(a.real.item(), - a.imag.item())) - print('< d0, F*d0> = {:.5g}{:+.5g}j'.format(b.real.item(), - b.imag.item())) - print('< d1, F*d1> = {:.5g}{:+.5g}j'.format(c.real.item(), - c.imag.item())) - self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) - if __name__ == '__main__': unittest.main() diff --git a/tests/ptycho/io.py b/tests/ptycho/io.py index 3ecba0cf..3b0f9b14 100644 --- a/tests/ptycho/io.py +++ b/tests/ptycho/io.py @@ -108,7 +108,7 @@ def _save_ptycho_result(result, algorithm): plt.close(fig) plt.imsave( f'{fname}/{0}-phase.png', - np.angle(result.psi).astype('float32'), + np.sum(np.angle(result.psi).astype('float32'), axis=0), # The output of np.angle is locked to (-pi, pi] cmap=plt.cm.twilight, vmin=-np.pi, @@ -116,7 +116,7 @@ def _save_ptycho_result(result, algorithm): ) plt.imsave( f'{fname}/{0}-ampli.png', - np.abs(result.psi).astype('float32'), + np.sum(np.abs(result.psi).astype('float32'), axis=0), cmap=plt.cm.gray, ) _save_probe(fname, result.probe, result.probe_options, algorithm) diff --git a/tests/ptycho/templates.py b/tests/ptycho/templates.py index 3cacec65..3a2cf133 100644 --- a/tests/ptycho/templates.py +++ b/tests/ptycho/templates.py @@ -39,7 +39,7 @@ def setUp(self, filename='siemens-star-small.npz.bz2'): self.data = self.data[mask] self.psi = np.full( - (600, 600), + (1, 600, 600), dtype=np.complex64, fill_value=np.complex64(0.5 + 0j), ) diff --git a/tests/ptycho/test_position.py b/tests/ptycho/test_position.py index ab0da0ac..12482e54 100644 --- a/tests/ptycho/test_position.py +++ b/tests/ptycho/test_position.py @@ -174,7 +174,7 @@ def setUp(self, filename='position-error-247.pickle.bz2'): self.data = self.data[mask] self.psi = np.full( - (600, 600), + (1, 600, 600), dtype=np.complex64, fill_value=np.complex64(0.5 + 0j), ) diff --git a/tests/ptycho/test_ptycho.py b/tests/ptycho/test_ptycho.py index 40672ade..9bc0287a 100644 --- a/tests/ptycho/test_ptycho.py +++ b/tests/ptycho/test_ptycho.py @@ -90,7 +90,7 @@ def test_gaussian(self): np.testing.assert_array_equal(weights, truth) def test_check_allowed_positions(self): - psi = np.empty((4, 9)) + psi = np.empty((1, 4, 9)) probe = np.empty((8, 2, 2)) scan = np.array([[1, 1], [1, 6.9], [1.1, 1], [1.9, 5.5]]) tike.ptycho.check_allowed_positions(scan, psi, probe.shape)