Skip to content

Commit

Permalink
API: Add depth dimension to psi
Browse files Browse the repository at this point in the history
  • Loading branch information
carterbox committed May 25, 2024
1 parent 6326998 commit 0efb364
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 207 deletions.
70 changes: 0 additions & 70 deletions src/tike/operators/cupy/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,73 +132,3 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
return patches

def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False):
"""Peform adj and adj_probe at the same time."""
assert probe.shape[:-4] == scan.shape[:-2]
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape,
scan.shape)

patches = self.patch.fwd(
# Could be xp.empty if scan positions are all in bounds
patches=self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
self.probe_shape, self.probe_shape),
),
images=psi,
positions=scan,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)
patches = patches.reshape((*scan.shape[:-1], nearplane.shape[-3],
self.probe_shape, self.probe_shape))
if rpie:
patches_amp = self.xp.sum(
patches * patches.conj(),
axis=-4,
keepdims=True,
)
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]

if not overwrite:
nearplane = nearplane.copy()
nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()
if rpie:
probe_amp = probe * probe.conj()
probe_amp = probe_amp.reshape(
(*scan.shape[:-2], -1, *nearplane.shape[-2:])
# (..., nscan * nprobe, probe_shape, probe_shape)
# (..., nprobe, probe_shape, probe_shape)
)
probe_amp = self.patch.adj(
patches=probe_amp,
images=self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], self.nz, self.n),
),
positions=scan,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)

apsi = self.patch.adj(
patches=nearplane.reshape(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:])),
images=self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], self.nz, self.n),
),
positions=scan,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)

if rpie:
return apsi, patches, patches_amp, probe_amp
else:
return apsi, patches
15 changes: 13 additions & 2 deletions src/tike/operators/cupy/multislice.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ def fwd(
**kwargs,
) -> npt.NDArray[np.csingle]:
"""Please see help(SingleSlice) for more info."""
assert (psi.shape[0] == 1 and psi.ndim == 3)
return self.diffraction.fwd(
psi=psi,
psi=psi[0],
scan=scan,
probe=probe,
)
Expand All @@ -113,12 +114,22 @@ def adj(
**kwargs,
) -> npt.NDArray[np.csingle]:
"""Please see help(SingleSlice) for more info."""
assert psi is None or (psi.shape[0] == 1 and psi.ndim == 3)
return self.diffraction.adj(
nearplane=nearplane,
probe=probe,
scan=scan,
overwrite=True,
psi=psi,
psi=psi[0] if psi is not None else None,
)[None, ...]

def adj_probe(self, nearplane, scan, psi, overwrite=False):
assert (psi.shape[0] == 1 and psi.ndim == 3)
return self.diffraction.adj_probe(
nearplane=nearplane,
scan=scan,
psi=psi[0],
overwrite=overwrite,
)

@property
Expand Down
38 changes: 8 additions & 30 deletions src/tike/operators/cupy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,31 @@ class Ptycho(Operator):
----------
detector_shape : int
The pixel width and height of the (square) detector grid.
nz, n : int
The pixel width and height of the reconstructed grid.
d, nz, n : int
The pixel depth, width, and height of the reconstructed grid.
probe_shape : int
The pixel width and height of the (square) probe illumination.
propagation : :py:class:`Operator`
The wave propagation operator being used.
diffraction : :py:class:`Operator`
The object probe interaction operator being used.
data : (..., FRAME, WIDE, HIGH) float32
data : (FRAME, WIDE, HIGH) float32
The intensity (square of the absolute value) of the propagated
wavefront; i.e. what the detector records.
farplane: (..., POSI, 1, SHARED, detector_shape, detector_shape) complex64
farplane: (POSI, 1, SHARED, detector_shape, detector_shape) complex64
The wavefronts hitting the detector respectively.
probe : {(..., 1, 1, SHARED, WIDE, HIGH), (..., POSI, 1, SHARED, WIDE, HIGH)} complex64
probe : {(1, 1, SHARED, WIDE, HIGH), (POSI, 1, SHARED, WIDE, HIGH)} complex64
The complex illumination function.
psi : (..., WIDE, HIGH) complex64
psi : (DEPTH, WIDE, HIGH) complex64
The wavefront modulation coefficients of the object.
scan : (..., POSI, 2) float32
scan : (POSI, 2) float32
Coordinates of the minimum corner of the probe grid for each
measurement in the coordinate system of psi. Coordinate order
consistent with WIDE, HIGH order.
.. versionchanged:: 0.25.0 Removed the model and ntheta parameters.
.. versionchanged:: 0.26.0 Added depth dimension to psi array
"""

Expand Down Expand Up @@ -246,26 +247,3 @@ def grad_probe(
axis=0,
keepdims=True,
)

def adj_all(
self,
farplane: npt.NDArray[np.csingle],
probe: npt.NDArray[np.csingle],
scan: npt.NDArray[np.single],
psi: npt.NDArray[np.csingle],
overwrite: bool = False,
rpie: bool = False,
) -> typing.Tuple[npt.NDArray, ...]:
"""Please see help(Ptycho) for more info."""
result = self.diffraction.adj_all(
nearplane=self.propagation.adj(
farplane,
overwrite=overwrite,
)[..., 0, :, :, :],
probe=probe[..., 0, :, :, :],
scan=scan,
overwrite=True,
psi=psi,
rpie=rpie,
)
return (result[0], result[1][..., None, :, :, :], *result[2:])
2 changes: 0 additions & 2 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _compute_intensity(
eigen_probe=None,
fly=1,
):
leading = psi.shape[:-2]
intensity = 0
for m in range(probe.shape[-3]):
farplane = operator.fwd(
Expand All @@ -111,7 +110,6 @@ def _compute_intensity(
)
intensity += np.sum(
np.square(np.abs(farplane)).reshape(
*leading,
scan.shape[-2] // fly,
fly,
operator.detector_shape,
Expand Down
6 changes: 4 additions & 2 deletions src/tike/ptycho/solvers/_preconditioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def _psi_preconditioner(
operator: tike.operators.Ptycho,
) -> npt.NDArray:

# FIXME: Generated only one preconditioner for all slices
psi_update_denominator = cp.zeros(
shape=psi.shape,
shape=psi.shape[-2:],
dtype=psi.dtype,
)

Expand Down Expand Up @@ -92,8 +93,9 @@ def make_certain_args_constant(
) -> None:
nonlocal probe_update_denominator

# FIXME: Only use the first slice for the probe preconditioner
patches = operator.diffraction.patch.fwd(
images=psi,
images=psi[0],
positions=scan[lo:hi],
patch_width=probe.shape[-1],
)
Expand Down
16 changes: 7 additions & 9 deletions src/tike/ptycho/solvers/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,22 @@ def keep_some_args_constant(
nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end,
pad:end]

patches = op.diffraction.patch.fwd(
patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
images=psi,
positions=scan[lo:hi],
)[..., None, None, :, :]

if object_options:

grad_psi = (cp.conj(varying_probe) * nearplane).reshape(
(hi - lo) * probe.shape[-3], *probe.shape[-2:])
psi_update_numerator = op.diffraction.patch.adj(
psi_update_numerator[0] = op.diffraction.patch.adj(
patches=grad_psi,
images=psi_update_numerator,
images=psi_update_numerator[0],
positions=scan[lo:hi],
nrepeat=probe.shape[-3],
)

if probe_options:
patches = op.diffraction.patch.fwd(
patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
images=psi[0],
positions=scan[lo:hi],
)[..., None, None, :, :]
probe_update_numerator += cp.sum(
cp.conj(patches) * nearplane,
axis=-5,
Expand Down
8 changes: 4 additions & 4 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,10 @@ def keep_some_args_constant(
# (24b)
object_update_proj = cp.conj(bunique_probe[blo:bhi]) * bchi[blo:bhi]
# (25b) Common object gradient.
object_upd_sum = op.diffraction.patch.adj(
object_upd_sum[0] = op.diffraction.patch.adj(
patches=object_update_proj.reshape(
len(scan[lo:hi]) * bchi.shape[-3], *bchi.shape[-2:]),
images=object_upd_sum,
images=object_upd_sum[0],
positions=scan[lo:hi],
nrepeat=bchi.shape[-3],
)
Expand All @@ -601,7 +601,7 @@ def keep_some_args_constant(
if recover_probe:
bpatches[blo:bhi] = op.diffraction.patch.fwd(
patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]),
images=psi,
images=psi[0],
positions=scan[lo:hi],
)[..., None, None, :, :]
# (24a)
Expand Down Expand Up @@ -724,7 +724,7 @@ def _precondition_nearplane_gradients(

object_update_proj = op.diffraction.patch.fwd(
patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
images=object_update_precond,
images=object_update_precond[0],
positions=scan[lo:hi],
)
dOP = object_update_proj[..., None,
Expand Down
6 changes: 3 additions & 3 deletions src/tike/ptycho/solvers/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ def __post_init__(self):
"It should be (1, 1, S, W, H) "
"where S >=1 is the number of probes, and "
"W, H >= 1 are the square probe grid dimensions.")
if (self.psi.ndim != 2 or np.any(
np.asarray(self.psi.shape) <= np.asarray(self.probe.shape[-2:]))
if (self.psi.ndim != 3 or np.any(
np.asarray(self.psi.shape[-2:]) <= np.asarray(self.probe.shape[-2:]))
):
raise ValueError(
f"psi shape {self.psi.shape} is incorrect. "
"It should be (W, H) where W, H > probe.shape[-2:].")
"It should be (D, W, H) where W, H > probe.shape[-2:].")
check_allowed_positions(
self.scan,
self.psi,
Expand Down
6 changes: 3 additions & 3 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,9 @@ def keep_some_args_constant(
if object_options:
grad_psi = (cp.conj(unique_probe) * diff / probe.shape[-3]).reshape(
scan[lo:hi].shape[0] * probe.shape[-3], *probe.shape[-2:])
psi_update_numerator = op.diffraction.patch.adj(
psi_update_numerator[0] = op.diffraction.patch.adj(
patches=grad_psi,
images=psi_update_numerator,
images=psi_update_numerator[0],
positions=scan[lo:hi],
nrepeat=probe.shape[-3],
)
Expand All @@ -463,7 +463,7 @@ def keep_some_args_constant(

patches = op.diffraction.patch.fwd(
patches=cp.zeros_like(diff[..., 0, 0, :, :]),
images=psi,
images=psi[0],
positions=scan[lo:hi],
)[..., None, None, :, :]

Expand Down
36 changes: 0 additions & 36 deletions tests/operators/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,5 @@ def test_adj_probe_time(self):
def test_scaled(self):
pass

def test_adjoint_all(self):
"""Check that the adjoint operator is correct."""
d = self.operator.fwd(
**{
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert d.shape == self.d.shape
m, m1 = self.operator.adj_all(
**{
self.d_name: self.d,
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert m.shape == self.m.shape
assert m1.shape == self.m1.shape
a = tike.linalg.inner(d, self.d)
b = tike.linalg.inner(self.m, m)
c = tike.linalg.inner(self.m1, m1)
print()
print('< Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(),
a.imag.item()))
print('< d0, F*d0> = {:.6f}{:+.6f}j'.format(b.real.item(),
b.imag.item()))
print('< d1, F*d1> = {:.6f}{:+.6f}j'.format(c.real.item(),
c.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0)
self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-3, atol=0)
self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 0efb364

Please sign in to comment.