Skip to content

Commit

Permalink
Merge branch 'main' into object-leading-dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
carterbox committed May 28, 2024
2 parents 48b62a3 + 54895ec commit 81a7ffd
Show file tree
Hide file tree
Showing 12 changed files with 476 additions and 59 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/apptainer.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This workflow builds an apptainer with tike installed

name: Publish Apptainer

on:
workflow_dispatch:
release:
types: [published]
push:
branches: [main]

permissions:
contents: read
packages: write

jobs:

publish-apptainer-to-ghcr:
runs-on: ubuntu-latest
strategy:
matrix:
cuda-version:
- "11.8"
- "12.0"
target-arch:
- "x86_64"
- "aarch64"
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
platforms: arm64
- uses: eWaterCycle/setup-apptainer@v2
with:
apptainer-version: 1.3.0
- name: Build container from definition
run: >
apptainer build
--build-arg cuda_version=${{ matrix.cuda-version }}
--build-arg target_arch=${{ matrix.target-arch }}
--build-arg pkg_version=${{ github.ref_name }}
apptainer.sif
apptainer/${{ github.event.repository.name }}.def
- name: Upload to container registry
run: |
echo ${{ secrets.GITHUB_TOKEN }} | apptainer registry login -u ${{ github.actor }} --password-stdin oras://ghcr.io
apptainer push apptainer.sif oras://ghcr.io/${GITHUB_REPOSITORY,,}:${{ github.ref_name }}-${{ matrix.target-arch }}-cuda${{ matrix.cuda-version }}
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/tests/result/
archive/

# apptainer images
*.sif

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
28 changes: 28 additions & 0 deletions apptainer/tike.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Bootstrap: docker
From: registry.fedoraproject.org/fedora-minimal:40-{{ target_arch }}

%arguments
target_arch=x86_64
cuda_version=12.0
pkg_version=main

%post
curl -L -o conda-installer.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-{{ target_arch }}.sh
bash conda-installer.sh -b -p "/opt/miniconda"
rm conda-installer.sh
/opt/miniconda/bin/conda install unzip --yes
curl -L -o source.zip https://github.com/AdvancedPhotonSource/tike/archive/{{ pkg_version }}.zip
/opt/miniconda/bin/unzip source.zip
rm source.zip
cd tike*
CONDA_OVERRIDE_CUDA={{ cuda_version }} /opt/miniconda/bin/conda install cuda-version={{ cuda_version }} --file requirements.txt --file requirements-container.txt -c conda-forge --yes
/opt/miniconda/bin/conda clean --all --yes
/opt/miniconda/bin/pip install . --no-deps --no-build-isolation
/opt/miniconda/bin/pip check
cd ..
rm tike* -rf
cd /opt/miniconda
rm -r man cmake lib/cmake lib/pkgconfig include share var

%runscript
/opt/miniconda/bin/python "$@"
5 changes: 5 additions & 0 deletions requirements-container.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
h5py
hdf5plugin
scikit-image
scipy
toml
180 changes: 169 additions & 11 deletions src/tike/operators/cupy/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .operator import Operator
from .patch import Patch
from .shift import Shift


class Convolution(Operator):
Expand Down Expand Up @@ -40,6 +41,7 @@ class Convolution(Operator):
first, horizontal coordinates second.
"""

def __init__(self, probe_shape, nz, n, ntheta=None,
detector_shape=None, **kwargs): # yapf: disable
self.probe_shape = probe_shape
Expand All @@ -65,14 +67,22 @@ def fwd(self, psi, scan, probe):
if self.detector_shape == self.probe_shape:
patches = self.xp.empty_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3],
self.detector_shape, self.detector_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
else:
patches = self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3],
self.detector_shape, self.detector_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
patches = self.patch.fwd(
patches=patches,
Expand All @@ -81,8 +91,12 @@ def fwd(self, psi, scan, probe):
patch_width=self.probe_shape,
nrepeat=probe.shape[-3],
)
patches = patches.reshape((*scan.shape[:-1], probe.shape[-3],
self.detector_shape, self.detector_shape))
patches = patches.reshape((
*scan.shape[:-1],
probe.shape[-3],
self.detector_shape,
self.detector_shape,
))
patches[..., self.pad:self.end, self.pad:self.end] *= probe
return patches

Expand All @@ -101,9 +115,11 @@ def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
)
assert psi.shape[:-2] == scan.shape[:-2]
return self.patch.adj(
patches=nearplane.reshape(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:])),
patches=nearplane.reshape((
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:],
)),
images=psi,
positions=scan,
patch_width=self.probe_shape,
Expand All @@ -117,8 +133,12 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
patches = self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
self.probe_shape, self.probe_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
self.probe_shape,
self.probe_shape,
),
)
patches = self.patch.fwd(
patches=patches,
Expand All @@ -132,3 +152,141 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
return patches

class ConvolutionFFT(Operator):
"""A 2D Convolution operator with linear interpolation.
Compute the product two arrays at specific relative positions.
Attributes
----------
nscan : int
The number of scan positions at each angular view.
probe_shape : int
The pixel width and height of the (square) probe illumination.
nz, n : int
The pixel width and height of the reconstructed grid.
ntheta : int
The number of angular partitions of the data.
Parameters
----------
psi : (..., nz, n) complex64
The complex wavefront modulation of the object.
probe : complex64
The (..., nscan, nprobe, probe_shape, probe_shape) or
(..., 1, nprobe, probe_shape, probe_shape) complex illumination
function.
nearplane: complex64
The (...., nscan, nprobe, probe_shape, probe_shape)
wavefronts after exiting the object.
scan : (..., nscan, 2) float32
Coordinates of the minimum corner of the probe grid for each
measurement in the coordinate system of psi. Vertical coordinates
first, horizontal coordinates second.
"""

def __init__(self, probe_shape, nz, n, ntheta=None,
detector_shape=None, **kwargs): # yapf: disable
self.probe_shape = probe_shape
self.nz = nz
self.n = n
if detector_shape is None:
self.detector_shape = probe_shape
else:
self.detector_shape = detector_shape
self.pad = (self.detector_shape - self.probe_shape) // 2
self.end = self.probe_shape + self.pad
self.patch = Patch()
self.shift = Shift()

def __enter__(self):
self.shift.__enter__()
return self

def __exit__(self, type, value, traceback):
self.shift.__exit__(type, value, traceback)

def fwd(self, psi, scan, probe):
"""Extract probe shaped patches from the psi at each scan position.
The patches within the bounds of psi are linearly interpolated, and
indices outside the bounds of psi are not allowed.
"""
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape)
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
if self.detector_shape == self.probe_shape:
patches = self.xp.empty_like(
psi,
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
else:
patches = self.xp.zeros_like(
psi,
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
index, shift = self.xp.divmod(scan, 1.0)
shift = shift.reshape((*scan.shape[:-1], 1, 2))

patches = self.patch.fwd(
patches=patches,
images=psi,
positions=index,
patch_width=self.probe_shape,
nrepeat=probe.shape[-3],
)

patches = patches.reshape((
*scan.shape[:-1],
probe.shape[-3],
self.detector_shape,
self.detector_shape,
))
patches = self.shift.adj(patches, shift, overwrite=False)

patches[..., self.pad:self.end, self.pad:self.end] *= probe
return patches

def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
"""Combine probe shaped patches into a psi shaped grid by addition."""
assert probe.shape[:-4] == scan.shape[:-2]
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
assert nearplane.shape[:-3] == scan.shape[:-1]
if not overwrite:
nearplane = nearplane.copy()
nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()

index, shift = self.xp.divmod(scan, 1.0)
shift = shift.reshape((*scan.shape[:-1], 1, 2))

nearplane = self.shift.fwd(nearplane, shift, overwrite=True)

if psi is None:
psi = self.xp.zeros_like(
nearplane,
shape=(*scan.shape[:-2], self.nz, self.n),
)
assert psi.shape[:-2] == scan.shape[:-2]
return self.patch.adj(
patches=nearplane.reshape((
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:],
)),
images=psi,
positions=index,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)
8 changes: 5 additions & 3 deletions src/tike/operators/cupy/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def fwd(self, a, shift, overwrite=False, cval=None):
if shift is None:
return a
shape = a.shape
padded = a.reshape(-1, *shape[-2:])
padded = a.reshape(*shape)
padded = self._fft2(
padded,
axes=(-2, -1),
Expand All @@ -33,8 +33,10 @@ def fwd(self, a, shift, overwrite=False, cval=None):
self.xp.fft.fftfreq(padded.shape[-2]).astype(shift.dtype),
)
padded *= self.xp.exp(
-2j * self.xp.pi *
(x * shift[..., 1, None, None] + y * shift[..., 0, None, None]))
-2j
* self.xp.pi
* (x * shift[..., 1, None, None] + y * shift[..., 0, None, None])
)
padded = self._ifft2(padded, axes=(-2, -1), overwrite_x=True)
return padded.reshape(*shape)

Expand Down
Loading

0 comments on commit 81a7ffd

Please sign in to comment.