Skip to content

Commit

Permalink
Misc fixes (#91)
Browse files Browse the repository at this point in the history
* implement stxm simulator
* fix imports when ptychonn not installed
* update XRF to use position matcher
* prototype propagators
* add defocus distance to disk and rect probe initializers
  • Loading branch information
stevehenke authored Jul 9, 2024
1 parent ee8860d commit f1ea91d
Show file tree
Hide file tree
Showing 34 changed files with 678 additions and 368 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install with minimal dependencies
Expand Down
6 changes: 6 additions & 0 deletions ptychodus/api/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def _radiusX(self) -> float:
def _radiusY(self) -> float:
return self.heightInPixels / 2

def getPixelGeometry(self) -> PixelGeometry:
return PixelGeometry(
widthInMeters=self.pixelWidthInMeters,
heightInMeters=self.pixelHeightInMeters,
)

def mapObjectPointToScanPoint(self, objectPoint: Point2D) -> Point2D:
x = self.centerXInMeters + self.pixelWidthInMeters * (objectPoint.x - self._radiusX)
y = self.centerYInMeters + self.pixelHeightInMeters * (objectPoint.y - self._radiusY)
Expand Down
9 changes: 4 additions & 5 deletions ptychodus/api/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TypeAlias

import numpy

from .geometry import ImageExtent, PixelGeometry
from .typing import ComplexArrayType, RealArrayType

WavefieldArrayType: TypeAlias = ComplexArrayType
from .propagator import WavefieldArrayType
from .typing import RealArrayType


@dataclass(frozen=True)
Expand Down Expand Up @@ -181,7 +179,8 @@ def getCoherence(self) -> float:
return numpy.sqrt(numpy.sum(numpy.square(self._modeRelativePower)))

def getIntensity(self) -> RealArrayType:
return numpy.absolute(self._array).sum(axis=-3)
intensity = numpy.real(self._array * numpy.conjugate(self._array))
return intensity.sum(axis=-3)


class ProbeFileReader(ABC):
Expand Down
184 changes: 184 additions & 0 deletions ptychodus/api/propagator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeAlias

from scipy.fft import fft2, fftfreq, fftshift, ifft2, ifftshift
import numpy

from .typing import ComplexArrayType, RealArrayType

WavefieldArrayType: TypeAlias = ComplexArrayType


@dataclass(frozen=True)
class PropagatorParameters:
wavelength_m: float
'''illumination wavelength in meters'''
width_px: int
'''number of pixels in the x-direction'''
height_px: int
'''number of pixels in the y-direction'''
pixel_width_m: float
'''source plane pixel width in meters'''
pixel_height_m: float
'''source plane pixel height in meters'''
propagation_distance_m: float
'''propagation distance in meters'''

@property
def dx(self) -> float:
'''pixel width in wavelengths'''
return self.pixel_width_m / self.wavelength_m

@property
def pixel_aspect_ratio(self) -> float:
'''pixel aspect ratio (width / height)'''
return self.pixel_width_m / self.pixel_height_m

@property
def z(self) -> float:
'''propagation distance in wavelengths'''
return self.propagation_distance_m / self.wavelength_m

@property
def fresnel_number(self) -> float:
'''fresnel number'''
return numpy.square(self.dx) / self.z

@property
def diffraction_plane_pixel_width(self) -> float:
'''diffraction plane pixel width in wavelengths'''
return self.z / (self.width_px * self.dx)

def get_spatial_coordinates(self) -> tuple[RealArrayType, RealArrayType]:
jj, ii = numpy.mgrid[:self.height_px, :self.width_px]
xx = ii - (self.width_px - 1) / 2
yy = jj - (self.height_px - 1) / 2
return yy, xx

def get_frequency_coordinates(self) -> tuple[RealArrayType, RealArrayType]:
fx = fftshift(fftfreq(self.width_px))
fy = fftshift(fftfreq(self.height_px))
FY, FX = numpy.meshgrid(fy, fx)
return FY, FX


class Propagator(ABC):

@abstractmethod
def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
pass


class AngularSpectrumPropagator(Propagator):

def __init__(self, parameters: PropagatorParameters) -> None:
ar = parameters.pixel_aspect_ratio

i2piz = 2j * numpy.pi * parameters.z
FY, FX = parameters.get_frequency_coordinates()
F2 = numpy.square(FX) + numpy.square(ar * FY)
ratio = F2 / numpy.square(parameters.dx)
tf = numpy.exp(i2piz * numpy.sqrt(1 - ratio)),

self._transfer_function = numpy.where(ratio < 1, tf, 0)

def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
return fftshift(ifft2(self._transfer_function * fft2(ifftshift(wavefield))))


class FresnelTransferFunctionPropagator(Propagator):

def __init__(self, parameters: PropagatorParameters) -> None:
Fr = parameters.fresnel_number
ar = parameters.pixel_aspect_ratio

i2pi = 2j * numpy.pi
FY, FX = parameters.get_frequency_coordinates()
F2 = numpy.square(FX) + numpy.square(ar * FY)

self._transfer_function = numpy.exp(i2pi * (parameters.z - 0.5 * F2 / Fr))

def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
return fftshift(ifft2(self._transfer_function * fft2(ifftshift(wavefield))))


class FresnelTransformPropagator(Propagator):

def __init__(self, parameters: PropagatorParameters) -> None:
Fr = parameters.fresnel_number
ar = parameters.pixel_aspect_ratio

i2piz = 2j * numpy.pi * parameters.z
ipi = 1j * numpy.pi
iar = 1j * ar

YY, XX = parameters.get_spatial_coordinates()
FY, FX = parameters.get_frequency_coordinates()
F2 = numpy.square(FX) + numpy.square(ar * FY)

self._A = numpy.exp(F2 * ipi / Fr) * numpy.exp(i2piz) * Fr / iar
self._B = numpy.exp(ipi * Fr * (numpy.square(XX) + numpy.square(YY / ar)))

def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
return self._A * fftshift(fft2(ifftshift(wavefield * self._B)))


class FresnelTransformLegacyPropagator(Propagator):

def __init__(self, parameters: PropagatorParameters) -> None:
self._parameters = parameters

def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
dxy = self._parameters.pixel_width_m
z = self._parameters.propagation_distance_m
wavelength = self._parameters.wavelength_m

(M, N) = wavefield.shape
k = 2 * numpy.pi / wavelength

# the coordinate grid
M_grid = numpy.arange(-1 * numpy.floor(M / 2), numpy.ceil(M / 2))
N_grid = numpy.arange(-1 * numpy.floor(N / 2), numpy.ceil(N / 2))
lx = M_grid * dxy
ly = N_grid * dxy

XX, YY = numpy.meshgrid(lx, ly)

# the coordinate grid on the output plane
fc = 1 / dxy
fu = wavelength * z * fc
lu = M_grid * fu / M
lv = N_grid * fu / N
Fx, Fy = numpy.meshgrid(lu, lv)

if z > 0:
pf = numpy.exp(1j * k * z) * numpy.exp(1j * k * (Fx**2 + Fy**2) / 2 / z)
kern = wavefield * numpy.exp(1j * k * (XX**2 + YY**2) / 2 / z)
cgh = fft2(fftshift(kern))
OUT = fftshift(cgh * fftshift(pf))
else:
pf = numpy.exp(1j * k * z) * numpy.exp(1j * k * (XX**2 + YY**2) / 2 / z)
cgh = ifft2(fftshift(wavefield * numpy.exp(1j * k * (Fx**2 + Fy**2) / 2 / z)))
OUT = fftshift(cgh) * pf

return OUT


class FraunhoferPropagator(Propagator):

def __init__(self, parameters: PropagatorParameters) -> None:
Fr = parameters.fresnel_number
ar = parameters.pixel_aspect_ratio

i2piz = 2j * numpy.pi * parameters.z
ipi = 1j * numpy.pi
iar = 1j * ar

FY, FX = parameters.get_frequency_coordinates()
F2 = numpy.square(FX) + numpy.square(ar * FY)

self._A = numpy.exp(F2 * ipi / Fr) * numpy.exp(i2piz) * Fr / iar

def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType:
return self._A * fftshift(fft2(ifftshift(wavefield)))
12 changes: 6 additions & 6 deletions ptychodus/controller/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None:
self._fileDialogFactory)
self._probeController = ProbeController.createInstance(
model.probeRepository, model.probeAPI, self._probeImageController,
model.probePropagator, model.probePropagatorVisualizationEngine, view.probeView,
self._fileDialogFactory)
model.probePropagator, model.probePropagatorVisualizationEngine, model.stxmSimulator,
model.stxmVisualizationEngine, model.exposureAnalyzer,
model.exposureVisualizationEngine, model.fluorescenceEnhancer,
model.fluorescenceVisualizationEngine, view.probeView, self._fileDialogFactory)
self._objectImageController = ImageController.createInstance(
model.objectVisualizationEngine, view.objectImageView, view.statusBar(),
self._fileDialogFactory)
self._objectController = ObjectController.createInstance(
model.objectRepository, model.objectAPI, self._objectImageController,
model.fourierRingCorrelator, model.stxmAnalyzer, model.stxmVisualizationEngine,
model.exposureAnalyzer, model.exposureVisualizationEngine, model.fluorescenceEnhancer,
model.fluorescenceVisualizationEngine, model.xmcdAnalyzer,
model.xmcdVisualizationEngine, view.objectView, self._fileDialogFactory)
model.fourierRingCorrelator, model.xmcdAnalyzer, model.xmcdVisualizationEngine,
view.objectView, self._fileDialogFactory)
self._reconstructorParametersController = ReconstructorController.createInstance(
model.reconstructorPresenter,
model.productRepository,
Expand Down
65 changes: 5 additions & 60 deletions ptychodus/controller/object/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from ptychodus.api.observer import SequenceObserver

from ...model.analysis import (ExposureAnalyzer, FluorescenceEnhancer, FourierRingCorrelator,
STXMAnalyzer, XMCDAnalyzer)
from ...model.analysis import FourierRingCorrelator, XMCDAnalyzer
from ...model.product import ObjectAPI, ObjectRepository
from ...model.product.object import ObjectRepositoryItem
from ...model.visualization import VisualizationEngine
Expand All @@ -16,12 +15,9 @@
from ..data import FileDialogFactory
from ..image import ImageController
from .editorFactory import ObjectEditorViewControllerFactory
from .exposure import ExposureViewController
from .frc import FourierRingCorrelationViewController
from .stxm import STXMViewController
from .treeModel import ObjectTreeModel
from .xmcd import XMCDViewController
from .fluorescence import FluorescenceViewController

logger = logging.getLogger(__name__)

Expand All @@ -30,13 +26,9 @@ class ObjectController(SequenceObserver[ObjectRepositoryItem]):

def __init__(self, repository: ObjectRepository, api: ObjectAPI,
imageController: ImageController, correlator: FourierRingCorrelator,
stxmAnalyzer: STXMAnalyzer, stxmVisualizationEngine: VisualizationEngine,
exposureAnalyzer: ExposureAnalyzer,
exposureVisualizationEngine: VisualizationEngine,
fluorescenceEnhancer: FluorescenceEnhancer,
fluorescenceVisualizationEngine: VisualizationEngine, xmcdAnalyzer: XMCDAnalyzer,
xmcdVisualizationEngine: VisualizationEngine, view: RepositoryTreeView,
fileDialogFactory: FileDialogFactory, treeModel: ObjectTreeModel) -> None:
xmcdAnalyzer: XMCDAnalyzer, xmcdVisualizationEngine: VisualizationEngine,
view: RepositoryTreeView, fileDialogFactory: FileDialogFactory,
treeModel: ObjectTreeModel) -> None:
super().__init__()
self._repository = repository
self._api = api
Expand All @@ -47,32 +39,18 @@ def __init__(self, repository: ObjectRepository, api: ObjectAPI,
self._editorFactory = ObjectEditorViewControllerFactory()

self._frcViewController = FourierRingCorrelationViewController(correlator, treeModel)
self._stxmViewController = STXMViewController(stxmAnalyzer, stxmVisualizationEngine,
fileDialogFactory)
self._exposureViewController = ExposureViewController(exposureAnalyzer,
exposureVisualizationEngine,
fileDialogFactory)
self._fluorescenceViewController = FluorescenceViewController(
fluorescenceEnhancer, fluorescenceVisualizationEngine, fileDialogFactory)
self._xmcdViewController = XMCDViewController(xmcdAnalyzer, xmcdVisualizationEngine,
fileDialogFactory, treeModel)

@classmethod
def createInstance(cls, repository: ObjectRepository, api: ObjectAPI,
imageController: ImageController, correlator: FourierRingCorrelator,
stxmAnalyzer: STXMAnalyzer, stxmVisualizationEngine: VisualizationEngine,
exposureAnalyzer: ExposureAnalyzer,
exposureVisualizationEngine: VisualizationEngine,
fluorescenceEnhancer: FluorescenceEnhancer,
fluorescenceVisualizationEngine: VisualizationEngine,
xmcdAnalyzer: XMCDAnalyzer, xmcdVisualizationEngine: VisualizationEngine,
view: RepositoryTreeView,
fileDialogFactory: FileDialogFactory) -> ObjectController:
# TODO figure out good fix when saving NPY file without suffix (numpy adds suffix)
treeModel = ObjectTreeModel(repository, api)
controller = cls(repository, api, imageController, correlator, stxmAnalyzer,
stxmVisualizationEngine, exposureAnalyzer, exposureVisualizationEngine,
fluorescenceEnhancer, fluorescenceVisualizationEngine, xmcdAnalyzer,
controller = cls(repository, api, imageController, correlator, xmcdAnalyzer,
xmcdVisualizationEngine, view, fileDialogFactory, treeModel)
repository.addObserver(controller)

Expand Down Expand Up @@ -103,15 +81,6 @@ def createInstance(cls, repository: ObjectRepository, api: ObjectAPI,
frcAction = view.buttonBox.analyzeMenu.addAction('Fourier Ring Correlation...')
frcAction.triggered.connect(controller._analyzeFRC)

stxmAction = view.buttonBox.analyzeMenu.addAction('STXM...')
stxmAction.triggered.connect(controller._analyzeSTXM)

exposureAction = view.buttonBox.analyzeMenu.addAction('Exposure...')
exposureAction.triggered.connect(controller._analyzeExposure)

fluorescenceAction = view.buttonBox.analyzeMenu.addAction('Enhance Fluorescence...')
fluorescenceAction.triggered.connect(controller._enhanceFluorescence)

xmcdAction = view.buttonBox.analyzeMenu.addAction('XMCD...')
xmcdAction.triggered.connect(controller._analyzeXMCD)

Expand Down Expand Up @@ -202,30 +171,6 @@ def _analyzeFRC(self) -> None:
else:
self._frcViewController.analyze(itemIndex, itemIndex)

def _analyzeSTXM(self) -> None:
itemIndex = self._getCurrentItemIndex()

if itemIndex < 0:
logger.warning('No current item!')
else:
self._stxmViewController.analyze(itemIndex)

def _analyzeExposure(self) -> None:
itemIndex = self._getCurrentItemIndex()

if itemIndex < 0:
logger.warning('No current item!')
else:
self._exposureViewController.analyze(itemIndex)

def _enhanceFluorescence(self) -> None:
itemIndex = self._getCurrentItemIndex()

if itemIndex < 0:
logger.warning('No current item!')
else:
self._fluorescenceViewController.launch(itemIndex)

def _analyzeXMCD(self) -> None:
itemIndex = self._getCurrentItemIndex()

Expand Down
Loading

0 comments on commit f1ea91d

Please sign in to comment.