Skip to content

Commit

Permalink
Updated docstrings and harmonzied type comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
maurerv committed Jul 31, 2024
1 parent b93a452 commit 928acd7
Show file tree
Hide file tree
Showing 18 changed files with 218 additions and 41 deletions.
186 changes: 186 additions & 0 deletions tme/backends/_jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
""" Utility functions for jax backend.
Copyright (c) 2023-2024 European Molecular Biology Laboratory
Author: Valentin Maurer <[email protected]>
"""
from typing import Tuple
from functools import partial

import jax.numpy as jnp
from jax import pmap, lax

from ..types import BackendArray
from ..backends import backend as be
from ..matching_utils import normalize_template as _normalize_template


def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
"""
Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
"""
template_ft = jnp.fft.rfftn(template)
template_ft = template_ft.at[:].multiply(ft_target)
correlation = jnp.fft.irfftn(template_ft)
return correlation


def _flc_scoring(
template: BackendArray,
template_mask: BackendArray,
ft_target: BackendArray,
ft_target2: BackendArray,
n_observations: BackendArray,
eps: float,
**kwargs,
) -> BackendArray:
"""
Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
"""
correlation = _correlate(template=template, ft_target=ft_target)
inv_denominator = _reciprocal_target_std(
ft_target=ft_target,
ft_target2=ft_target2,
template_mask=template_mask,
eps=eps,
n_observations=n_observations,
)
correlation = correlation.at[:].multiply(inv_denominator)
return correlation


def _flcSphere_scoring(
template: BackendArray,
ft_target: BackendArray,
inv_denominator: BackendArray,
**kwargs,
) -> BackendArray:
"""
Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
"""
correlation = _correlate(template=template, ft_target=ft_target)
correlation = correlation.at[:].multiply(inv_denominator)
return correlation


def _reciprocal_target_std(
ft_target: BackendArray,
ft_target2: BackendArray,
template_mask: BackendArray,
n_observations: float,
eps: float,
) -> BackendArray:
"""
Computes reciprocal standard deviation of a target given a mask.
See Also
--------
:py:meth:`tme.matching_exhaustive.flc_scoring`.
"""
ft_template_mask = jnp.fft.rfftn(template_mask)

# E(X^2)- E(X)^2
exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
exp_sq = exp_sq.at[:].divide(n_observations)

ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
sq_exp = jnp.fft.irfftn(ft_template_mask)
sq_exp = sq_exp.at[:].divide(n_observations)
sq_exp = sq_exp.at[:].power(2)

exp_sq = exp_sq.at[:].add(-sq_exp)
exp_sq = exp_sq.at[:].max(0)
exp_sq = exp_sq.at[:].power(0.5)

exp_sq = exp_sq.at[:].set(
jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
)
return exp_sq


def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
arr_ft = jnp.fft.rfftn(arr)
arr_ft = arr_ft.at[:].multiply(arr_filter)
return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))


def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
return arr


@partial(
pmap,
in_axes=(0,) + (None,) * 6,
static_broadcasted_argnums=[6, 7],
)
def scan(
target: BackendArray,
template: BackendArray,
template_mask: BackendArray,
rotations: BackendArray,
template_filter: BackendArray,
target_filter: BackendArray,
fast_shape: Tuple[int],
rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
eps = jnp.finfo(template.dtype).resolution

if hasattr(target_filter, "shape"):
target = _apply_fourier_filter(target, target_filter)

ft_target = jnp.fft.rfftn(target)
ft_target2 = jnp.fft.rfftn(jnp.square(target))
inv_denominator, target, scoring_func = None, None, _flc_scoring
if not rotate_mask:
n_observations = jnp.sum(template_mask)
inv_denominator = _reciprocal_target_std(
ft_target=ft_target,
ft_target2=ft_target2,
template_mask=be.topleft_pad(template_mask, fast_shape),
eps=eps,
n_observations=n_observations,
)
ft_target2, scoring_func = None, _flcSphere_scoring

_template_filter_func = _identity
if template_filter.shape != ():
_template_filter_func = _apply_fourier_filter

def _sample_transform(ret, rotation_matrix):
max_scores, rotations, index = ret
template_rot, template_mask_rot = be.rigid_transform(
arr=template,
arr_mask=template_mask,
rotation_matrix=rotation_matrix,
order=1, # thats all we get for now
)

n_observations = jnp.sum(template_mask_rot)
template_rot = _template_filter_func(template_rot, template_filter)
template_rot = _normalize_template(
template_rot, template_mask_rot, n_observations
)
template_rot = be.topleft_pad(template_rot, fast_shape)
template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)

scores = scoring_func(
template=template_rot,
template_mask=template_mask_rot,
ft_target=ft_target,
ft_target2=ft_target2,
inv_denominator=inv_denominator,
n_observations=n_observations,
eps=eps,
)
max_scores, rotations = be.max_score_over_rotations(
scores, max_scores, rotations, index
)
return (max_scores, rotations, index + 1), None

score_space = jnp.zeros(fast_shape)
rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
(score_space, rotation_space, _), _ = lax.scan(
_sample_transform, (score_space, rotation_space, 0), rotations
)

return score_space, rotation_space
2 changes: 1 addition & 1 deletion tme/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.metadata = metadata

def __repr__(self):
response = "Density object at {}\nOrigin: {}, sampling_rate: {}, Shape: {}"
response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
return response.format(
hex(id(self)),
tuple(np.round(self.origin, 3)),
Expand Down
7 changes: 3 additions & 4 deletions tme/matching_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def _load_array(arr: NDArray):
NDArray
Loaded array.
"""

if type(arr) == np.memmap:
if isinstance(arr, np.memmap):
return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
return arr

Expand Down Expand Up @@ -153,13 +152,13 @@ def subset_array(
arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)

if type(arr) == Density:
if isinstance(arr, Density):
if isinstance(arr.data, np.memmap):
arr = Density.from_file(arr.data.filename, subset=arr_slice).data
else:
arr = np.asarray(arr.data[*arr_mesh])
else:
if type(arr) == np.memmap:
if isinstance(arr, np.memmap):
arr = np.memmap(
arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
)
Expand Down
4 changes: 2 additions & 2 deletions tme/matching_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,8 @@ def mcc_scoring(
tol = 1e3 * eps * be.max(be.abs(temp2), axis=axes, keepdims=True)

temp2[temp2 < tol] = 1
be.divide(numerator, temp2, out=temp)
be.clip(temp, a_min=-1, a_max=1, out=temp)
temp = be.divide(numerator, temp2, out=temp)
temp = be.clip(temp, a_min=-1, a_max=1, out=temp)

# Apply overlap ratio threshold
number_px_threshold = overlap_ratio * be.max(
Expand Down
2 changes: 1 addition & 1 deletion tme/matching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def memmap_to_array(arr: NDArray) -> NDArray:
obj:`numpy.ndarray`
In-memory version of ``arr``.
"""
if type(arr) == np.memmap:
if isinstance(arr, np.memmap):
memmap_filepath = arr.filename
arr = np.array(arr)
os.remove(memmap_filepath)
Expand Down
9 changes: 2 additions & 7 deletions tme/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@

class MatchingMemoryUsage(ABC):
"""
Base class for estimating the memory usage of template matching.
This class provides a template for estimating memory usage for
different matching methods. Users should subclass it and implement the
`base_usage` and `per_fork` methods to specify custom memory usage
estimates.
Class specification for estimating the memory requirements of template matching.
Parameters
----------
Expand Down Expand Up @@ -80,7 +75,7 @@ def per_fork(self) -> int:

class CCMemoryUsage(MatchingMemoryUsage):
"""
Memory usage estimation for the CC fitter.
Memory usage estimation for CC scoring.
See Also
--------
Expand Down
6 changes: 3 additions & 3 deletions tme/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class Orientations:
Array with additional orientation details (n, ).
"""

#: Array with translations of each orientation (n x d).
#: Array with translations of each orientation (n, d).
translations: np.ndarray

#: Array with zyx euler angles of each orientation (n x d).
#: Array with zyx euler angles of each orientation (n, d).
rotations: np.ndarray

#: Array with scores of each orientation (n, ).
Expand Down Expand Up @@ -158,7 +158,7 @@ def to_file(self, filename: str, file_format: type = None, **kwargs) -> None:
the file_format from the typical extension. Supported formats are
+---------------+----------------------------------------------------+
| text | pyTME's standard tab-separated orientations file |
| text | pytme's standard tab-separated orientations file |
+---------------+----------------------------------------------------+
| relion | Creates a STAR file of orientations |
+---------------+----------------------------------------------------+
Expand Down
7 changes: 3 additions & 4 deletions tme/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ def parse_input(self, lines: List[str]) -> Dict:

class PDBParser(Parser):
"""
A Parser subclass for converting PDB file data into a dictionary representation.
This class is specifically designed to work with PDB file format.
Convert PDB file data into a dictionary representation [1]_.
References
----------
Expand Down Expand Up @@ -228,8 +227,8 @@ def parse_input(self, lines: List[str]) -> Dict:

class MMCIFParser(Parser):
"""
A Parser subclass for converting MMCIF file data into a dictionary representation.
This implementation heavily relies on the atomium library [1]_.
Convert MMCIF file data into a dictionary representation. This implementation
heavily relies on the atomium library [1]_.
References
----------
Expand Down
3 changes: 1 addition & 2 deletions tme/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from typing import Tuple, List

import numpy as np
from numpy.typing import NDArray

from ..types import BackendArray
from ..backends import backend as be
from ..backends import NumpyFFTWBackend
from ..types import BackendArray, NDArray
from ..matching_utils import euler_to_rotationmatrix


Expand Down
3 changes: 1 addition & 2 deletions tme/preprocessing/tilt_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from .. import Preprocessor
from ..types import NDArray
from ..backends import backend as be
from ..matching_utils import euler_to_rotationmatrix

from ._utils import (
frequency_grid_at_angle,
compute_tilt_shape,
Expand Down
2 changes: 1 addition & 1 deletion tme/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ def _hlpf_fitness(
orig = int((f_mask.size - 1) / 2)
dist = np.arange(-orig, orig + 1) * T
t, c, k = splrep(x=dist, y=f_mask, k=3)
i_max = np.ceil(np.divide(f_mask.shape, M))
i_max = np.ceil(np.divide(f_mask.shape, M)).astype(int)[0]
coarse_mask = np.arange(-i_max, i_max + 1) * M
spline = BSpline(t, c, k)
coarse_values = spline(coarse_mask)
Expand Down
4 changes: 2 additions & 2 deletions tme/tests/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import numpy as np

from tme.backends import backend
from tme.backends import backend as be
from tme.analyzer import (
MaxScoreOverRotations,
PeakCaller,
Expand Down Expand Up @@ -162,7 +162,7 @@ def test__iter__(self, use_memmap: bool):
score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
res = tuple(score_analyzer)
assert np.allclose(res[0].shape, self.data.shape)
assert res[0].dtype == backend._float_dtype
assert res[0].dtype == be._float_dtype
assert res[1].size == self.data.ndim
assert np.allclose(res[2].shape, self.data.shape)
assert len(res) == 4
Expand Down
4 changes: 2 additions & 2 deletions tme/tests/test_density.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from os import remove
from tempfile import mkstemp
from itertools import permutations
from os import remove

import pytest
import numpy as np
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_repr(self):
density = Density(data, origin, sampling_rate)
repr_str = density.__repr__()

response = "Density object at {}\nOrigin: {}, sampling_rate: {}, Shape: {}"
response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
response = response.format(
hex(id(density)),
tuple(np.round(density.origin, 3)),
Expand Down
Loading

0 comments on commit 928acd7

Please sign in to comment.