Skip to content

Commit

Permalink
GI: added resolution estimation and mask binning support, upgraded to…
Browse files Browse the repository at this point in the history
… Python 3.9

Signed-off-by: Nicola VIGANO <[email protected]>
  • Loading branch information
Obi-Wan committed Mar 13, 2024
1 parent 17e01ce commit 52ee7e8
Showing 1 changed file with 77 additions and 14 deletions.
91 changes: 77 additions & 14 deletions corrct/struct_illum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Provide structured illumination support.
Expand All @@ -9,24 +8,27 @@
and ESRF - The European Synchrotron, Grenoble, France
"""

import numpy as np
from scipy import linalg as spalg

from numpy.typing import NDArray, DTypeLike
from typing import Union, Sequence, Optional, Tuple
import copy as cp

from abc import abstractmethod, ABC
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg as spalg

from . import operators
from numpy.typing import DTypeLike, NDArray

from tqdm.auto import tqdm

from . import operators, processing

import copy as cp

NDArrayInt = NDArray[np.integer]


def reorder_masks(masks: NDArray, buckets: NDArray, shift: int) -> Tuple[NDArray, NDArray]:
def reorder_masks(masks: NDArray, buckets: NDArray, shift: int) -> tuple[NDArray, NDArray]:
"""Reorder masks, with a simple rot-n algorithm.
Parameters
Expand All @@ -52,7 +54,7 @@ def reorder_masks(masks: NDArray, buckets: NDArray, shift: int) -> Tuple[NDArray
return masks.reshape(masks_shape), buckets


def decompose_qr_masks(masks: NDArray, verbose: bool = False) -> Tuple[NDArray, NDArray]:
def decompose_qr_masks(masks: NDArray, verbose: bool = False) -> tuple[NDArray, NDArray]:
"""Compute QR decomposition of the given masks.
Parameters
Expand Down Expand Up @@ -83,6 +85,47 @@ def decompose_qr_masks(masks: NDArray, verbose: bool = False) -> Tuple[NDArray,
return Qt.reshape(masks_shape), R1t


def estimate_resolution(masks: NDArray, verbose: bool = True, plot_result: bool = True) -> tuple[float, float]:
"""Estimate the mask collection resolution through auto-correlation.
Parameters
----------
masks : NDArray
The list of encoding masks
verbose : bool, optional
Whether to produce verbose output, by default True
plot_result : bool, optional
Whether to plot the results, by default True
Returns
-------
tuple[float, float]
The mean and minimum HWHM of the auto-correlation functions for all the masks.
"""
masks = masks.reshape([-1, *masks.shape[-2:]])

resolutions = np.zeros(len(masks))
for ind, mask in enumerate(tqdm(masks, desc="Computing auto-correlations", disable=not verbose)):
_, auto_corr = processing.misc.norm_cross_corr(mask, plot=False)
point = processing.misc.lines_intersection(auto_corr, 0.5, position="first")
resolutions[ind] = point[0] if point is not None else 0

res_mean = resolutions.mean()
res_min = resolutions.min()

if plot_result:
fig, axs = plt.subplots(1, 1, figsize=[9, 4])
axs.plot(resolutions, label="HWHM auto-correlation")
axs.hlines(res_mean, 0, len(resolutions), colors=["C1"], label=f"Mean ({res_mean:.3})")
axs.hlines(res_min, 0, len(resolutions), colors=["C2"], label=f"Min ({res_min:.3})")
axs.grid()
axs.legend()
fig.tight_layout()
plt.show(block=False)

return res_mean, res_min


class MaskCollection:
"""Define mask collection class."""

Expand All @@ -97,7 +140,7 @@ class MaskCollection:
def __init__(
self,
masks_enc: NDArray,
masks_dec: Optional[NDArray] = None,
masks_dec: Union[NDArray, None] = None,
mask_dims: int = 2,
mask_type: str = "measured",
mask_support: Union[None, Sequence[int], NDArrayInt] = None,
Expand Down Expand Up @@ -220,7 +263,7 @@ def get_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool =

return mask[tuple(mask_inds_vu)]

def get_QR_decomposition(self, buckets: NDArray, shift: int = 0) -> Tuple["MaskCollection", NDArray]:
def get_QR_decomposition(self, buckets: NDArray, shift: int = 0) -> tuple["MaskCollection", NDArray]:
"""Compute and return the QR decomposition of the masks.
Parameters
Expand Down Expand Up @@ -254,6 +297,26 @@ def get_QR_decomposition(self, buckets: NDArray, shift: int = 0) -> Tuple["MaskC

return new_masks, new_buckets

def bin_masks(self, binning: float) -> "MaskCollection":
"""Bin the masks.
Parameters
----------
binning : float
The binning size.
Returns
-------
MaskCollection
A new collection of binned masks.
"""
new_masks = cp.deepcopy(self)

new_masks.masks_enc = processing.pre.bin_imgs(new_masks.masks_enc, binning=binning)
new_masks.masks_dec = processing.pre.bin_imgs(new_masks.masks_dec, binning=binning)

return new_masks

def inspect_masks(self, mask_inds_vu: Union[None, Sequence[int], NDArrayInt] = None):
"""Inspect the encoding and decoding masks at the requested shifts.
Expand Down Expand Up @@ -505,7 +568,7 @@ def get_random_shifts(self, num_shifts: int, axes_order: Sequence[int] = (-2, -1
return [disp[perms] for disp in disps]

def get_sequential_shifts(
self, num_shifts: Optional[int] = None, axes_order: Sequence[int] = (-2, -1)
self, num_shifts: Union[int, None] = None, axes_order: Sequence[int] = (-2, -1)
) -> Sequence[NDArray]:
"""Produce shifts for the "sequential" shift type.
Expand Down

0 comments on commit 52ee7e8

Please sign in to comment.