From 399def6381bf64175ca4eeaf46aae38c478cd48a Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 29 May 2024 15:56:20 -0500 Subject: [PATCH 1/9] NEW: Implement adjoint test for multislice prototype --- src/tike/operators/cupy/multislice.py | 112 +++++++++++++------ tests/operators/test_multislice.py | 148 ++++++++++++++++++++++++++ 2 files changed, 230 insertions(+), 30 deletions(-) create mode 100644 tests/operators/test_multislice.py diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 2714c387..04d0f12a 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -46,6 +46,82 @@ def __init__( self.n = n self.nslices = nslices + def __enter__(self): + self.propagation.__enter__() + self.diffraction.__enter__() + return self + + def __exit__(self, type, value, traceback): + self.propagation.__exit__(type, value, traceback) + self.diffraction.__exit__(type, value, traceback) + + def fwd( + self, + probe: npt.NDArray[np.csingle], + scan: npt.NDArray[np.single], + psi: npt.NDArray[np.csingle], + **kwargs, + ) -> npt.NDArray[np.csingle]: + """Please see help(SingleSlice) for more info.""" + assert psi.shape[0] == self.nslices and psi.ndim == 3 + exitwave = probe + for s in range(self.nslices): + exitwave = self.diffraction.fwd( + psi=psi[s], + scan=scan, + probe=exitwave, + ) + return exitwave + + def adj( + self, + nearplane: npt.NDArray[np.csingle], + probe: npt.NDArray[np.csingle], + scan: npt.NDArray[np.single], + psi: npt.NDArray[np.csingle], + overwrite: bool = False, + **kwargs, + ) -> npt.NDArray[np.csingle]: + """Please see help(SingleSlice) for more info.""" + probe_adj = nearplane + psi_adj = self.xp.zeros_like(psi) + exitwave = [ + None, + ] * len(psi) + exitwave[0] = probe + for s in range(1, self.nslices): + exitwave[s] = self.diffraction.fwd( + psi=psi[s - 1], + scan=scan, + probe=exitwave[s - 1], + ) + for s in range(self.nslices - 1, -1, -1): + psi_adj[s] = self.diffraction.adj( + nearplane=probe_adj, + probe=exitwave[s], + scan=scan, + overwrite=False, + ) + probe_adj = self.diffraction.adj_probe( + nearplane=probe_adj, + scan=scan, + psi=psi[s], + ) + # FIXME: Why is division by nslices needed here? + return psi_adj / self.nslices, probe_adj + + @property + def patch(self): + return self.diffraction.patch + + @property + def pad(self): + return self.diffraction.pad + + @property + def end(self): + return self.diffraction.end + class SingleSlice(Multislice): """Single slice wavefield propgation""" @@ -80,15 +156,6 @@ def __init__( self.n = n self.nslices = 1 - def __enter__(self): - self.propagation.__enter__() - self.diffraction.__enter__() - return self - - def __exit__(self, type, value, traceback): - self.propagation.__exit__(type, value, traceback) - self.diffraction.__exit__(type, value, traceback) - def fwd( self, probe: npt.NDArray[np.csingle], @@ -97,7 +164,7 @@ def fwd( **kwargs, ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" - assert (psi.shape[0] == 1 and psi.ndim == 3) + assert psi.shape[0] == 1 and psi.ndim == 3 return self.diffraction.fwd( psi=psi[0], scan=scan, @@ -115,31 +182,16 @@ def adj( ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" assert psi is None or (psi.shape[0] == 1 and psi.ndim == 3) - return self.diffraction.adj( + psi_adj = self.diffraction.adj( nearplane=nearplane, probe=probe, scan=scan, - overwrite=True, - psi=psi[0] if psi is not None else None, + overwrite=False, )[None, ...] - - def adj_probe(self, nearplane, scan, psi, overwrite=False): - assert (psi.shape[0] == 1 and psi.ndim == 3) - return self.diffraction.adj_probe( + probe_adj = self.diffraction.adj_probe( nearplane=nearplane, scan=scan, psi=psi[0], - overwrite=overwrite, + overwrite=False, ) - - @property - def patch(self): - return self.diffraction.patch - - @property - def pad(self): - return self.diffraction.pad - - @property - def end(self): - return self.diffraction.end + return psi_adj, probe_adj diff --git a/tests/operators/test_multislice.py b/tests/operators/test_multislice.py new file mode 100644 index 00000000..8af3fe26 --- /dev/null +++ b/tests/operators/test_multislice.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import time +import unittest + +import numpy as np +from tike.operators import Multislice +from tike.operators.cupy.multislice import SingleSlice +import tike.precision +import tike.linalg + +from .util import random_complex, OperatorTests + +__author__ = "Daniel Ching" +__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC." +__docformat__ = 'restructuredtext en' + + +class TestMultislice(unittest.TestCase): + """Test the ptychography operator.""" + + def setUp(self, depth=7, pw=15, nscan=27): + """Load a dataset for reconstruction.""" + self.nscan = nscan + self.nprobe = 3 + self.probe_shape = (nscan, self.nprobe, pw, pw) + self.detector_shape = (pw, pw) + self.original_shape = (depth, 128, 128) + self.scan_shape = (nscan, 2) + print(Multislice) + + np.random.seed(0) + scan = np.random.rand(*self.scan_shape).astype( + tike.precision.floating) * (127 - 16) + probe = random_complex(*self.probe_shape) + original = random_complex(*self.original_shape) + farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) + + self.operator = Multislice( + nscan=self.scan_shape[-2], + probe_shape=self.probe_shape[-1], + detector_shape=self.detector_shape[-1], + nz=self.original_shape[-2], + n=self.original_shape[-1], + nslices=depth, + ) + self.operator.__enter__() + self.xp = self.operator.xp + + self.mkwargs = { + 'probe': self.xp.asarray(probe), + 'psi': self.xp.asarray(original), + 'scan': self.xp.asarray(scan), + } + self.dkwargs = { + 'nearplane': self.xp.asarray(farplane), + 'probe': self.xp.asarray(probe), + 'psi': self.xp.asarray(original), + 'scan': self.xp.asarray(scan), + } + + def test_adjoint(self): + """Check that the adjoint operator is correct.""" + d = self.operator.fwd(**self.mkwargs) + assert d.shape == self.dkwargs['nearplane'].shape + m0, m1 = self.operator.adj(**self.dkwargs) + assert m0.shape == self.mkwargs['psi'].shape + assert m1.shape == self.mkwargs['probe'].shape + a = tike.linalg.inner(d, self.dkwargs['nearplane']) + b = tike.linalg.inner(self.mkwargs['psi'], m0) + c = tike.linalg.inner(self.mkwargs['probe'], m1) + print() + print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) + print('< m0, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) + print('< m1, F*d> = {:.5g}{:+.5g}j'.format(c.real.item(), c.imag.item())) + # self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) + # self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) + + @unittest.skip('FIXME: This operator is not scaled.') + def test_scaled(self): + pass + + +class TestSingleslice(unittest.TestCase): + """Test the ptychography operator.""" + + def setUp(self, depth=1, pw=15, nscan=27): + """Load a dataset for reconstruction.""" + self.nscan = nscan + self.nprobe = 3 + self.probe_shape = (nscan, self.nprobe, pw, pw) + self.detector_shape = (pw, pw) + self.original_shape = (depth, 128, 128) + self.scan_shape = (nscan, 2) + print(Multislice) + + np.random.seed(0) + scan = np.random.rand(*self.scan_shape).astype( + tike.precision.floating) * (127 - 16) + probe = random_complex(*self.probe_shape) + original = random_complex(*self.original_shape) + farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) + + self.operator = SingleSlice( + nscan=self.scan_shape[-2], + probe_shape=self.probe_shape[-1], + detector_shape=self.detector_shape[-1], + nz=self.original_shape[-2], + n=self.original_shape[-1], + nslices=depth, + ) + self.operator.__enter__() + self.xp = self.operator.xp + + self.mkwargs = { + 'probe': self.xp.asarray(probe), + 'psi': self.xp.asarray(original), + 'scan': self.xp.asarray(scan), + } + self.dkwargs = { + 'nearplane': self.xp.asarray(farplane) + } + + def test_adjoint(self): + """Check that the adjoint operator is correct.""" + d = self.operator.fwd(**self.mkwargs) + assert d.shape == self.dkwargs['nearplane'].shape + m0, m1 = self.operator.adj(**self.mkwargs, **self.dkwargs) + assert m0.shape == self.mkwargs['psi'].shape + assert m1.shape == self.mkwargs['probe'].shape + a = tike.linalg.inner(d, self.dkwargs['nearplane']) + b = tike.linalg.inner(self.mkwargs['psi'], m0) + c = tike.linalg.inner(self.mkwargs['probe'], m1) + print() + print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) + print(' = {:.5g}{:+.5g}j'.format(c.real.item(), c.imag.item())) + print('< d, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) + self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) + + @unittest.skip('FIXME: This operator is not scaled.') + def test_scaled(self): + pass + + +if __name__ == '__main__': + unittest.main() From 832a5a66753224cb562ca8d44c33a66d59e8e4b5 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 29 May 2024 16:01:02 -0500 Subject: [PATCH 2/9] TST: Fixup single slice test --- tests/operators/test_multislice.py | 43 +++++++----------------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/tests/operators/test_multislice.py b/tests/operators/test_multislice.py index 8af3fe26..5e8c7373 100644 --- a/tests/operators/test_multislice.py +++ b/tests/operators/test_multislice.py @@ -5,8 +5,7 @@ import unittest import numpy as np -from tike.operators import Multislice -from tike.operators.cupy.multislice import SingleSlice +from tike.operators import Multislice, SingleSlice import tike.precision import tike.linalg @@ -17,7 +16,7 @@ __docformat__ = 'restructuredtext en' -class TestMultislice(unittest.TestCase): +class TestMultiSlice(unittest.TestCase): """Test the ptychography operator.""" def setUp(self, depth=7, pw=15, nscan=27): @@ -55,16 +54,13 @@ def setUp(self, depth=7, pw=15, nscan=27): } self.dkwargs = { 'nearplane': self.xp.asarray(farplane), - 'probe': self.xp.asarray(probe), - 'psi': self.xp.asarray(original), - 'scan': self.xp.asarray(scan), } def test_adjoint(self): """Check that the adjoint operator is correct.""" d = self.operator.fwd(**self.mkwargs) assert d.shape == self.dkwargs['nearplane'].shape - m0, m1 = self.operator.adj(**self.dkwargs) + m0, m1 = self.operator.adj(**self.dkwargs, **self.mkwargs) assert m0.shape == self.mkwargs['psi'].shape assert m1.shape == self.mkwargs['probe'].shape a = tike.linalg.inner(d, self.dkwargs['nearplane']) @@ -74,15 +70,17 @@ def test_adjoint(self): print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) print('< m0, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) print('< m1, F*d> = {:.5g}{:+.5g}j'.format(c.real.item(), c.imag.item())) - # self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) - # self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) @unittest.skip('FIXME: This operator is not scaled.') def test_scaled(self): pass -class TestSingleslice(unittest.TestCase): +class TestSingleSlice(TestMultiSlice): """Test the ptychography operator.""" def setUp(self, depth=1, pw=15, nscan=27): @@ -102,7 +100,7 @@ def setUp(self, depth=1, pw=15, nscan=27): original = random_complex(*self.original_shape) farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) - self.operator = SingleSlice( + self.operator = Multislice( nscan=self.scan_shape[-2], probe_shape=self.probe_shape[-1], detector_shape=self.detector_shape[-1], @@ -119,30 +117,9 @@ def setUp(self, depth=1, pw=15, nscan=27): 'scan': self.xp.asarray(scan), } self.dkwargs = { - 'nearplane': self.xp.asarray(farplane) + 'nearplane': self.xp.asarray(farplane), } - def test_adjoint(self): - """Check that the adjoint operator is correct.""" - d = self.operator.fwd(**self.mkwargs) - assert d.shape == self.dkwargs['nearplane'].shape - m0, m1 = self.operator.adj(**self.mkwargs, **self.dkwargs) - assert m0.shape == self.mkwargs['psi'].shape - assert m1.shape == self.mkwargs['probe'].shape - a = tike.linalg.inner(d, self.dkwargs['nearplane']) - b = tike.linalg.inner(self.mkwargs['psi'], m0) - c = tike.linalg.inner(self.mkwargs['probe'], m1) - print() - print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) - print(' = {:.5g}{:+.5g}j'.format(c.real.item(), c.imag.item())) - print('< d, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) - self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) - self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) - - @unittest.skip('FIXME: This operator is not scaled.') - def test_scaled(self): - pass - if __name__ == '__main__': unittest.main() From 3a92fbbda7f577234c5ad96074d24e36cbab2188 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 29 May 2024 16:23:51 -0500 Subject: [PATCH 3/9] NEW: Add propagator into Multislice Operator --- src/tike/operators/cupy/multislice.py | 43 +++++++++++++++++++-------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 04d0f12a..53923bed 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -21,7 +21,7 @@ def __init__( probe_shape: int, nz: int, n: int, - propagation: typing.Type[Propagation] = ZeroPropagation, + propagation: typing.Type[Propagation] = Propagation, diffraction: typing.Type[Convolution] = Convolution, norm: str = "ortho", nslices: int = 1, @@ -64,12 +64,16 @@ def fwd( ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" assert psi.shape[0] == self.nslices and psi.ndim == 3 - exitwave = probe - for s in range(self.nslices): + exitwave = self.diffraction.fwd( + psi=psi[0], + scan=scan, + probe=probe, + ) + for s in range(1, self.nslices): exitwave = self.diffraction.fwd( psi=psi[s], scan=scan, - probe=exitwave, + probe=self.propagation.fwd(exitwave), ) return exitwave @@ -83,22 +87,35 @@ def adj( **kwargs, ) -> npt.NDArray[np.csingle]: """Please see help(SingleSlice) for more info.""" - probe_adj = nearplane psi_adj = self.xp.zeros_like(psi) - exitwave = [ + probes = [ None, ] * len(psi) - exitwave[0] = probe + probes[0] = probe for s in range(1, self.nslices): - exitwave[s] = self.diffraction.fwd( - psi=psi[s - 1], - scan=scan, - probe=exitwave[s - 1], + probes[s] = self.propagation.fwd( + self.diffraction.fwd( + psi=psi[s - 1], + scan=scan, + probe=probes[s - 1], + ) ) - for s in range(self.nslices - 1, -1, -1): + psi_adj[self.nslices - 1] = self.diffraction.adj( + nearplane=nearplane, + probe=probes[self.nslices - 1], + scan=scan, + overwrite=False, + ) + probe_adj = self.diffraction.adj_probe( + nearplane=nearplane, + scan=scan, + psi=psi[self.nslices - 1], + ) + for s in range(self.nslices - 2, -1, -1): + probe_adj = self.propagation.adj(probe_adj) psi_adj[s] = self.diffraction.adj( nearplane=probe_adj, - probe=exitwave[s], + probe=probes[s], scan=scan, overwrite=False, ) From 8c4534266c847e1827779181fab066d8c24e48cc Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 30 May 2024 12:21:45 -0500 Subject: [PATCH 4/9] REF: Move difference map to new multi-slice API --- src/tike/operators/cupy/convolution.py | 4 +-- src/tike/ptycho/solvers/dm.py | 34 ++++++++++---------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/tike/operators/cupy/convolution.py b/src/tike/operators/cupy/convolution.py index 92b343be..970d92df 100644 --- a/src/tike/operators/cupy/convolution.py +++ b/src/tike/operators/cupy/convolution.py @@ -102,9 +102,9 @@ def fwd(self, psi, scan, probe): def adj(self, nearplane, scan, probe, psi=None, overwrite=False): """Combine probe shaped patches into a psi shaped grid by addition.""" - assert probe.shape[:-4] == scan.shape[:-2] + assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape) assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2] - assert nearplane.shape[:-3] == scan.shape[:-1] + assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape, scan.shape) if not overwrite: nearplane = nearplane.copy() nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj() diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index 48b3851a..30d39eef 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -215,27 +215,19 @@ def keep_some_args_constant( nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end, pad:end] - if object_options: - grad_psi = (cp.conj(varying_probe) * nearplane).reshape( - (hi - lo) * probe.shape[-3], *probe.shape[-2:]) - psi_update_numerator[0] = op.diffraction.patch.adj( - patches=grad_psi, - images=psi_update_numerator[0], - positions=scan[lo:hi], - nrepeat=probe.shape[-3], - ) - - if probe_options: - patches = op.diffraction.patch.fwd( - patches=cp.zeros_like(nearplane[..., 0, 0, :, :]), - images=psi[0], - positions=scan[lo:hi], - )[..., None, None, :, :] - probe_update_numerator += cp.sum( - cp.conj(patches) * nearplane, - axis=-5, - keepdims=True, - ) + grad_psi, grad_probe = op.diffraction.adj( + nearplane=nearplane[:, 0], + probe=varying_probe[:, 0], + scan=scan[lo:hi], + psi=psi, + ) + grad_probe = cp.sum( + grad_probe[:, None], + axis=-5, + keepdims=True, + ) + psi_update_numerator += grad_psi + probe_update_numerator += grad_probe tike.communicators.stream.stream_and_modify2( f=keep_some_args_constant, From 8093132bc713a427058769acc12ea8d4abb23213 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 31 May 2024 12:31:38 -0500 Subject: [PATCH 5/9] TST: Fix Ptycho operator adjoint tests --- src/tike/operators/cupy/ptycho.py | 85 ++----------------------------- tests/operators/test_ptycho.py | 55 ++++++++------------ 2 files changed, 24 insertions(+), 116 deletions(-) diff --git a/src/tike/operators/cupy/ptycho.py b/src/tike/operators/cupy/ptycho.py index 01d004f3..00c6dd2b 100644 --- a/src/tike/operators/cupy/ptycho.py +++ b/src/tike/operators/cupy/ptycho.py @@ -124,12 +124,12 @@ def adj( farplane: npt.NDArray[np.csingle], probe: npt.NDArray[np.csingle], scan: npt.NDArray[np.single], - psi: npt.NDArray[np.csingle] = None, + psi: npt.NDArray[np.csingle], overwrite: bool = False, **kwargs, ) -> npt.NDArray[np.csingle]: """Please see help(Ptycho) for more info.""" - return self.diffraction.adj( + psi_adj, probe_adj = self.diffraction.adj( nearplane=self.propagation.adj( farplane, overwrite=overwrite, @@ -139,25 +139,7 @@ def adj( overwrite=True, psi=psi, ) - - def adj_probe( - self, - farplane: npt.NDArray[np.csingle], - scan: npt.NDArray[np.single], - psi: npt.NDArray[np.csingle], - overwrite: bool = False, - **kwargs, - ) -> npt.NDArray[np.csingle]: - """Please see help(Ptycho) for more info.""" - return self.diffraction.adj_probe( - psi=psi, - scan=scan, - nearplane=self.propagation.adj( - farplane=farplane, - overwrite=overwrite, - )[..., 0, :, :, :], - overwrite=True, - )[..., None, :, :, :] + return psi_adj, probe_adj[..., None, :, :, :] def _compute_intensity( self, @@ -186,64 +168,3 @@ def cost( """Please see help(Ptycho) for more info.""" intensity, _ = self._compute_intensity(data, psi, scan, probe) return getattr(objective, model)(data, intensity) - - def grad_psi( - self, - data: npt.NDArray, - psi: npt.NDArray[np.csingle], - scan: npt.NDArray[np.single], - probe: npt.NDArray[np.csingle], - *, - model: str, - ) -> npt.NDArray[np.csingle]: - """Please see help(Ptycho) for more info.""" - intensity, farplane = self._compute_intensity(data, psi, scan, probe) - grad_obj = self.xp.zeros_like(psi) - grad_obj = self.adj( - farplane=getattr(objective, f'{model}_grad')( - data, - farplane, - intensity, - ), - probe=probe, - scan=scan, - psi=grad_obj, - overwrite=True, - ) - return grad_obj - - def grad_probe( - self, - data: npt.NDArray, - psi: npt.NDArray[np.csingle], - scan: npt.NDArray[np.single], - probe: npt.NDArray[np.csingle], - mode: typing.List[int] = None, - *, - model: str, - ) -> npt.NDArray[np.csingle]: - """Compute the gradient with respect to the probe(s). - - Parameters - ---------- - mode : list(int) - Only return the gradient with resepect to these probes. - - """ - mode = list(range(probe.shape[-3])) if mode is None else mode - intensity, farplane = self._compute_intensity(data, psi, scan, probe) - # Use the average gradient for all probe positions - return self.xp.mean( - self.adj_probe( - farplane=getattr(objective, f'{model}_grad')( - data, - farplane[..., mode, :, :], - intensity, - ), - psi=psi, - scan=scan, - overwrite=True, - ), - axis=0, - keepdims=True, - ) diff --git a/tests/operators/test_ptycho.py b/tests/operators/test_ptycho.py index ac5eccf3..9adb9263 100644 --- a/tests/operators/test_ptycho.py +++ b/tests/operators/test_ptycho.py @@ -16,7 +16,7 @@ __docformat__ = 'restructuredtext en' -class TestPtycho(unittest.TestCase, OperatorTests): +class TestPtycho(unittest.TestCase): """Test the ptychography operator.""" def setUp(self, depth=1, pw=15, nscan=27): @@ -46,46 +46,33 @@ def setUp(self, depth=1, pw=15, nscan=27): self.operator.__enter__() self.xp = self.operator.xp - self.m = self.xp.asarray(original) - self.m_name = 'psi' - self.kwargs = { - 'scan': self.xp.asarray(scan), - 'probe': self.xp.asarray(probe) + self.mkwargs = { + "scan": self.xp.asarray(scan), + "probe": self.xp.asarray(probe), + "psi": self.xp.asarray(original), } - - self.m1 = self.xp.asarray(probe) - self.m1_name = 'probe' - self.kwargs1 = { - 'scan': self.xp.asarray(scan), - 'psi': self.xp.asarray(original) - } - self.kwargs2 = { - 'scan': self.xp.asarray(scan), + self.dkwargs = { + "farplane": self.xp.asarray(farplane), } - self.d = self.xp.asarray(farplane) - self.d_name = 'farplane' - - def test_adjoint_probe(self): + def test_adjoint(self): """Check that the adjoint operator is correct.""" - d = self.operator.fwd(**{self.m1_name: self.m1}, **self.kwargs1) - assert d.shape == self.d.shape - m = self.operator.adj_probe(**{self.d_name: self.d}, **self.kwargs1) - assert m.shape == self.m1.shape - a = tike.linalg.inner(d, self.d) - b = tike.linalg.inner(self.m1, m) + d = self.operator.fwd(**self.mkwargs) + assert d.shape == self.dkwargs["farplane"].shape + m0, m1 = self.operator.adj(**self.dkwargs, **self.mkwargs) + assert m0.shape == self.mkwargs["psi"].shape + assert m1.shape == self.mkwargs["probe"].shape + a = tike.linalg.inner(d, self.dkwargs["farplane"]) + b = tike.linalg.inner(self.mkwargs["psi"], m0) + c = tike.linalg.inner(self.mkwargs["probe"], m1) print() - print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) - print('< d, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) + print(" = {:.5g}{:+.5g}j".format(a.real.item(), a.imag.item())) + print("< m0, F*d> = {:.5g}{:+.5g}j".format(b.real.item(), b.imag.item())) + print("< m1, F*d> = {:.5g}{:+.5g}j".format(c.real.item(), c.imag.item())) self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) - - def test_adj_probe_time(self): - """Time the adjoint operation.""" - start = time.perf_counter() - m = self.operator.adj_probe(**{self.d_name: self.d}, **self.kwargs1) - elapsed = time.perf_counter() - start - print(f"\n{elapsed:1.3e} seconds") + self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) + self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) @unittest.skip('FIXME: This operator is not scaled.') def test_scaled(self): From 084160042b967a85ca823e069d7311ef733e2a02 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 5 Jun 2024 14:14:41 -0500 Subject: [PATCH 6/9] STY: Address linter complaints --- tests/operators/test_multislice.py | 61 +++++++++++++++--------------- tests/operators/test_ptycho.py | 14 +++---- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/tests/operators/test_multislice.py b/tests/operators/test_multislice.py index 5e8c7373..b7acb684 100644 --- a/tests/operators/test_multislice.py +++ b/tests/operators/test_multislice.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import time import unittest import numpy as np @@ -9,15 +8,15 @@ import tike.precision import tike.linalg -from .util import random_complex, OperatorTests +from .util import random_complex __author__ = "Daniel Ching" __copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC." -__docformat__ = 'restructuredtext en' +__docformat__ = "restructuredtext en" class TestMultiSlice(unittest.TestCase): - """Test the ptychography operator.""" + """Test the MultiSlice operator.""" def setUp(self, depth=7, pw=15, nscan=27): """Load a dataset for reconstruction.""" @@ -30,8 +29,9 @@ def setUp(self, depth=7, pw=15, nscan=27): print(Multislice) np.random.seed(0) - scan = np.random.rand(*self.scan_shape).astype( - tike.precision.floating) * (127 - 16) + scan = np.random.rand(*self.scan_shape).astype(tike.precision.floating) * ( + 127 - 16 + ) probe = random_complex(*self.probe_shape) original = random_complex(*self.original_shape) farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) @@ -48,40 +48,40 @@ def setUp(self, depth=7, pw=15, nscan=27): self.xp = self.operator.xp self.mkwargs = { - 'probe': self.xp.asarray(probe), - 'psi': self.xp.asarray(original), - 'scan': self.xp.asarray(scan), + "probe": self.xp.asarray(probe), + "psi": self.xp.asarray(original), + "scan": self.xp.asarray(scan), } self.dkwargs = { - 'nearplane': self.xp.asarray(farplane), + "nearplane": self.xp.asarray(farplane), } def test_adjoint(self): """Check that the adjoint operator is correct.""" d = self.operator.fwd(**self.mkwargs) - assert d.shape == self.dkwargs['nearplane'].shape + assert d.shape == self.dkwargs["nearplane"].shape m0, m1 = self.operator.adj(**self.dkwargs, **self.mkwargs) - assert m0.shape == self.mkwargs['psi'].shape - assert m1.shape == self.mkwargs['probe'].shape - a = tike.linalg.inner(d, self.dkwargs['nearplane']) - b = tike.linalg.inner(self.mkwargs['psi'], m0) - c = tike.linalg.inner(self.mkwargs['probe'], m1) + assert m0.shape == self.mkwargs["psi"].shape + assert m1.shape == self.mkwargs["probe"].shape + a = tike.linalg.inner(d, self.dkwargs["nearplane"]) + b = tike.linalg.inner(self.mkwargs["psi"], m0) + c = tike.linalg.inner(self.mkwargs["probe"], m1) print() - print(' = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item())) - print('< m0, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item())) - print('< m1, F*d> = {:.5g}{:+.5g}j'.format(c.real.item(), c.imag.item())) + print(" = {:.5g}{:+.5g}j".format(a.real.item(), a.imag.item())) + print("< m0, F*d> = {:.5g}{:+.5g}j".format(b.real.item(), b.imag.item())) + print("< m1, F*d> = {:.5g}{:+.5g}j".format(c.real.item(), c.imag.item())) self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) - @unittest.skip('FIXME: This operator is not scaled.') + @unittest.skip("FIXME: This operator is not scaled.") def test_scaled(self): pass class TestSingleSlice(TestMultiSlice): - """Test the ptychography operator.""" + """Test the SingleSlice operator.""" def setUp(self, depth=1, pw=15, nscan=27): """Load a dataset for reconstruction.""" @@ -91,16 +91,17 @@ def setUp(self, depth=1, pw=15, nscan=27): self.detector_shape = (pw, pw) self.original_shape = (depth, 128, 128) self.scan_shape = (nscan, 2) - print(Multislice) + print(SingleSlice) np.random.seed(0) - scan = np.random.rand(*self.scan_shape).astype( - tike.precision.floating) * (127 - 16) + scan = np.random.rand(*self.scan_shape).astype(tike.precision.floating) * ( + 127 - 16 + ) probe = random_complex(*self.probe_shape) original = random_complex(*self.original_shape) farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) - self.operator = Multislice( + self.operator = SingleSlice( nscan=self.scan_shape[-2], probe_shape=self.probe_shape[-1], detector_shape=self.detector_shape[-1], @@ -112,14 +113,14 @@ def setUp(self, depth=1, pw=15, nscan=27): self.xp = self.operator.xp self.mkwargs = { - 'probe': self.xp.asarray(probe), - 'psi': self.xp.asarray(original), - 'scan': self.xp.asarray(scan), + "probe": self.xp.asarray(probe), + "psi": self.xp.asarray(original), + "scan": self.xp.asarray(scan), } self.dkwargs = { - 'nearplane': self.xp.asarray(farplane), + "nearplane": self.xp.asarray(farplane), } -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/operators/test_ptycho.py b/tests/operators/test_ptycho.py index 9adb9263..e292efe6 100644 --- a/tests/operators/test_ptycho.py +++ b/tests/operators/test_ptycho.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import time import unittest import numpy as np @@ -9,11 +8,11 @@ import tike.precision import tike.linalg -from .util import random_complex, OperatorTests +from .util import random_complex __author__ = "Daniel Ching" __copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC." -__docformat__ = 'restructuredtext en' +__docformat__ = "restructuredtext en" class TestPtycho(unittest.TestCase): @@ -30,8 +29,9 @@ def setUp(self, depth=1, pw=15, nscan=27): print(Ptycho) np.random.seed(0) - scan = np.random.rand(*self.scan_shape).astype( - tike.precision.floating) * (127 - 16) + scan = np.random.rand(*self.scan_shape).astype(tike.precision.floating) * ( + 127 - 16 + ) probe = random_complex(*self.probe_shape) original = random_complex(*self.original_shape) farplane = random_complex(*self.probe_shape[:-2], *self.detector_shape) @@ -74,10 +74,10 @@ def test_adjoint(self): self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) - @unittest.skip('FIXME: This operator is not scaled.') + @unittest.skip("FIXME: This operator is not scaled.") def test_scaled(self): pass -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From aa0ca3948471632b77a2556c560e57ef3942179b Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 5 Jun 2024 14:16:57 -0500 Subject: [PATCH 7/9] DOC: Update a developer note --- src/tike/operators/cupy/multislice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 53923bed..9b21a2f4 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -124,7 +124,7 @@ def adj( scan=scan, psi=psi[s], ) - # FIXME: Why is division by nslices needed here? + # FIXME: Why does correct adjoint require division by nslices? return psi_adj / self.nslices, probe_adj @property From cba8bc28d6f919f4056cbbea9258c14a6e5cc639 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 5 Jun 2024 14:23:23 -0500 Subject: [PATCH 8/9] REF: Revert changes to DM algorithm --- src/tike/ptycho/solvers/dm.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index 30d39eef..48b3851a 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -215,19 +215,27 @@ def keep_some_args_constant( nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end, pad:end] - grad_psi, grad_probe = op.diffraction.adj( - nearplane=nearplane[:, 0], - probe=varying_probe[:, 0], - scan=scan[lo:hi], - psi=psi, - ) - grad_probe = cp.sum( - grad_probe[:, None], - axis=-5, - keepdims=True, - ) - psi_update_numerator += grad_psi - probe_update_numerator += grad_probe + if object_options: + grad_psi = (cp.conj(varying_probe) * nearplane).reshape( + (hi - lo) * probe.shape[-3], *probe.shape[-2:]) + psi_update_numerator[0] = op.diffraction.patch.adj( + patches=grad_psi, + images=psi_update_numerator[0], + positions=scan[lo:hi], + nrepeat=probe.shape[-3], + ) + + if probe_options: + patches = op.diffraction.patch.fwd( + patches=cp.zeros_like(nearplane[..., 0, 0, :, :]), + images=psi[0], + positions=scan[lo:hi], + )[..., None, None, :, :] + probe_update_numerator += cp.sum( + cp.conj(patches) * nearplane, + axis=-5, + keepdims=True, + ) tike.communicators.stream.stream_and_modify2( f=keep_some_args_constant, From c73c3268add36631ace7722e4d4663a0286f0ec1 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 5 Jun 2024 14:31:54 -0500 Subject: [PATCH 9/9] REF: Use the fresnel spectrum propagator as multislice propagator --- src/tike/operators/cupy/multislice.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/tike/operators/cupy/multislice.py b/src/tike/operators/cupy/multislice.py index 9b21a2f4..2a128af5 100644 --- a/src/tike/operators/cupy/multislice.py +++ b/src/tike/operators/cupy/multislice.py @@ -6,22 +6,24 @@ import typing import numpy.typing as npt import numpy as np -import cupy as cp from .operator import Operator +from .fresnelspectprop import FresnelSpectProp from .propagation import Propagation, ZeroPropagation from .convolution import Convolution class Multislice(Operator): + """Multiple slice wavefield propgation""" + def __init__( self, detector_shape: int, probe_shape: int, nz: int, n: int, - propagation: typing.Type[Propagation] = Propagation, + propagation: typing.Type[Propagation] = FresnelSpectProp, diffraction: typing.Type[Convolution] = Convolution, norm: str = "ortho", nslices: int = 1, @@ -62,7 +64,7 @@ def fwd( psi: npt.NDArray[np.csingle], **kwargs, ) -> npt.NDArray[np.csingle]: - """Please see help(SingleSlice) for more info.""" + """Please see help(Multislice) for more info.""" assert psi.shape[0] == self.nslices and psi.ndim == 3 exitwave = self.diffraction.fwd( psi=psi[0], @@ -86,7 +88,7 @@ def adj( overwrite: bool = False, **kwargs, ) -> npt.NDArray[np.csingle]: - """Please see help(SingleSlice) for more info.""" + """Please see help(Multislice) for more info.""" psi_adj = self.xp.zeros_like(psi) probes = [ None,