Skip to content

Commit

Permalink
2D averages of picked positions (#46)
Browse files Browse the repository at this point in the history
* utils to get closest normal vector

* utils for 2D rotational averages and clustering

* 2D average example notebook

* remove redundant cells

* restructure notebook

* feature extraction and tsne

* feature extraction and tsne

* add clustering based on extracted features

* remove cross-correlation functionalities
LorenzLamm authored May 11, 2024

Verified

This commit was signed with the committer’s verified signature.
Girgias Gina Peter Banyard
1 parent 16ec565 commit 14d05b3
Showing 3 changed files with 809 additions and 0 deletions.
554 changes: 554 additions & 0 deletions examples/example_usecase_2D_averages.ipynb

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions src/surforama/utils/geometry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union

import numpy as np
import trimesh
from scipy.spatial.transform import Rotation as R


@@ -45,3 +46,57 @@ def rotate_around_vector(
rotation_object = R.from_rotvec(rotation_vector)

return rotation_object.apply(to_rotate)


def find_closest_triangle(mesh: trimesh.Trimesh, point: np.ndarray) -> int:
"""
Find the index of the triangle in a mesh that is closest to a specified point.
Parameters
----------
mesh : trimesh.Trimesh
The mesh in which to find the closest triangle.
point : np.ndarray
A 3D point (as a NumPy array) for which the closest triangle is to be found.
Returns
-------
int
The index of the triangle that is closest to the given point.
Notes
-----
This function calculates the geometric center of each triangle and finds the one
closest to the specified point using Euclidean distance.
"""
triangle_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
distances = np.linalg.norm(triangle_centers - point, axis=1)
return np.argmin(distances)


def find_closest_normal(
mesh: trimesh.Trimesh, point: np.ndarray
) -> np.ndarray:
"""
Find the normal of the closest triangle in a mesh to a specified point.
Parameters
----------
mesh : trimesh.Trimesh
The mesh from which to find the closest triangle normal.
point : np.ndarray
A 3D point (as a NumPy array) for which the closest triangle normal is to be found.
Returns
-------
np.ndarray
The normal vector of the closest triangle to the given point.
Notes
-----
This function uses `find_closest_triangle` to determine the closest triangle and then
retrieves the normal vector associated with that triangle from the mesh's `face_normals`.
"""
triangle_index = find_closest_triangle(mesh, point)
face_normals = mesh.face_normals
return face_normals[triangle_index]
200 changes: 200 additions & 0 deletions src/surforama/utils/twoD_averages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import numpy as np
import trimesh
from scipy.ndimage import map_coordinates

from surforama.utils.geometry import find_closest_normal


def define_rotation_kernel(shape: tuple) -> np.ndarray:
"""
Define a set of rotation kernels based on a specified shape.
Parameters
----------
shape : tuple
The shape (dimensions) for which the kernel is to be defined.
Returns
-------
np.ndarray
A 3D array containing rotation kernels for each radius.
"""
img_center = np.round(0.5 * np.asarray(shape)) - 1
rad_voxels = min(
int(shape[0] - img_center[0]), int(shape[1] - img_center[1])
) # number of radius rings

X, Y = np.meshgrid(
np.arange(shape[0]), np.arange(shape[1]), indexing="ij"
) # X, Y, meshgrid for x and y coordinates
kernels = np.zeros(
(rad_voxels, shape[0], shape[1]), dtype=float
) # kernels, rotation kernels

distances_to_center = np.sqrt(
(X - img_center[0] - 0.5) ** 2 + (Y - img_center[1] - 0.5) ** 2
)
for i in range(rad_voxels):
rad_ring = (
np.abs(distances_to_center - i) - 2
) # ring of radius i with thickness 2
rad_ring[rad_ring > 0] = (
0 # set everything further than two voxels from the ring to zero
)
kernels[i, :, :] = -0.5 * rad_ring

return kernels


def avg_vol_2D(vol: np.ndarray, mirror: bool = False) -> np.ndarray:
"""
Rotationally average an input volume to create a 2D image.
Parameters
----------
vol : np.ndarray
A 3D numpy array representing the volume.
mirror : bool, optional
Flag to mirror the averaged results horizontally.
Returns
-------
np.ndarray
A 2D image representing the averaged volume.
"""
shape = vol.shape
kernels = define_rotation_kernel(shape)

# Average
avg = np.zeros(
(min(vol.shape[0], vol.shape[1]), kernels.shape[0]), dtype=float
)

for i in range(vol.shape[2]):
hold = vol[:, :, i]
for j, kernel in enumerate(kernels):
avg[i, j] = (hold * kernel).sum() / (kernel.sum() + 1e-6)

if mirror:
avg = np.concatenate((avg[:, -1:0:-1], avg), axis=1)
return avg


def extract_normal_volume(
point: np.ndarray, normal: np.ndarray, tomogram: np.ndarray, shape: tuple
) -> np.ndarray:
"""
Extract a volume from a tomogram that is aligned with a point and a normal vector.
Parameters
----------
point : np.ndarray
A 3D point in the tomogram.
normal : np.ndarray
The normal vector.
tomogram : np.ndarray
A 3D numpy array of the tomogram.
shape : tuple
The desired shape of the extracted volume.
Returns
-------
np.ndarray
A 3D numpy array representing the extracted volume aligned with the normal vector.
"""
# Normalize the normal vector
z_axis = normal / np.linalg.norm(normal)

# Arbitrary choice for X-axis (just make sure it's not parallel to Z)
x_axis = (
np.array([1, 0, 0])
if z_axis[0] == 0 or z_axis[1] == 0
else np.array([0, 0, 1])
)
x_axis = (
x_axis - np.dot(x_axis, z_axis) * z_axis
) # Remove the component parallel to Z
x_axis /= np.linalg.norm(x_axis) # Normalize
# Y-axis to complete the right-handed system
y_axis = np.cross(z_axis, x_axis)

# Create rotation matrix from the original axes to the new axes
rotation_matrix = np.array([x_axis, y_axis, z_axis])

# Define the local coordinates around the point
local_x = np.linspace(-shape[0] / 2, shape[0] / 2, shape[0])
local_y = np.linspace(-shape[1] / 2, shape[1] / 2, shape[1])
local_z = np.linspace(-shape[2] / 2, shape[2] / 2, shape[2])
local_grid = np.array(
np.meshgrid(local_x, local_y, local_z, indexing="ij")
)

# Flatten and rotate the grid, then add the point
local_grid_flat = local_grid.reshape(3, -1)
global_grid_flat = np.dot(rotation_matrix.T, local_grid_flat).T + point

# Use scipy's map_coordinates to extract the aligned volume
extracted_volume = map_coordinates(
tomogram, global_grid_flat.T, order=1, mode="nearest"
).reshape(shape)

return extracted_volume


def create_2D_averages(
positions: list,
mesh: trimesh.Trimesh,
tomogram: np.ndarray,
shape: tuple = (20, 20, 20),
mirror: bool = True,
) -> np.ndarray:
"""
Create 2D averages from a list of positions using a given mesh and tomogram.
Parameters
----------
positions : list
A list of 3D points.
mesh :
A mesh object used to find normals corresponding to the given positions.
tomogram : np.ndarray
A 3D numpy array containing the tomogram data.
mirror : bool, optional
Flag to mirror the results horizontally for each 2D average.
Returns
-------
np.ndarray
An array of 2D averages calculated for each position.
"""
averages = []
for position in positions:
normal = find_closest_normal(mesh, position)
volume = extract_normal_volume(position, normal, tomogram, shape)
avg = avg_vol_2D(volume, mirror=mirror)
averages.append(avg)
return np.array(averages)


def normalize_averages(avgs: np.ndarray) -> np.ndarray:
"""
Normalize the average images.
This is done by subtracting the mean and dividing by the standard deviation.
Mean and standard deviation are calculated across all the averages.
Parameters
----------
avgs : np.ndarray
The array of averages to be normalized.
Returns
-------
np.ndarray
The normalized averages.
"""
mean = np.mean(avgs)
std = np.std(avgs)
avgs = (avgs - mean) / std
return avgs

0 comments on commit 14d05b3

Please sign in to comment.