Skip to content

Commit

Permalink
Imporved maximum image based tilt alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Mar 15, 2024
1 parent 546d23f commit 9d3ae94
Showing 1 changed file with 83 additions and 149 deletions.
232 changes: 83 additions & 149 deletions tomotools/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
import logging
# from numpy.fft import fft, fftshift, ifftshift, ifft
from skimage.registration import phase_cross_correlation as pcc
from skimage.transform import hough_line, hough_line_peaks
from skimage.feature import canny
from skimage.filters import sobel
import matplotlib.pylab as plt

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -37,16 +41,16 @@ def get_best_slices(stack, nslices):
Returns
----------
locs : NumPy array
slice_locations : NumPy array
Location along the x-axis of the best slices
"""
total_mass = stack.data.sum((0, 1))
mass_var = stack.data.sum(1).std(0)
mass_var[mass_var == 0] = 1e-5
ratio = total_mass / mass_var
locs = ratio.argsort()[::-1][0:nslices]
return locs
mass_std = stack.data.sum(1).std(0)
mass_std[mass_std == 0] = 1e-5
mass_ratio = total_mass / mass_std
slice_locations = mass_ratio.argsort()[::-1][0:nslices]
return slice_locations


def get_coms(stack, slices):
Expand All @@ -67,11 +71,11 @@ def get_coms(stack, slices):
"""
sinos = stack.data[:, :, slices]
y = np.linspace(
-int(sinos.shape[1] / 2), int(sinos.shape[1] / 2), sinos.shape[1], dtype="int"
)
y_coordinates = np.linspace(sinos.shape[1] // 2,
sinos.shape[1] // 2,
sinos.shape[1], dtype="int")
total_mass = sinos.sum(1)
coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y, 2) / total_mass
coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y_coordinates, 2) / total_mass
return coms


Expand Down Expand Up @@ -99,7 +103,7 @@ def apply_shifts(stack, shifts):
"Number of shifts (%s) is not consistent with number"
"of images in the stack (%s)" % (len(shifts), stack.data.shape[0])
)
for i in range(0, shifted.data.shape[0]):
for i in range(shifted.data.shape[0]):
shifted.data[i, :, :] = ndimage.shift(
shifted.data[i, :, :], shift=[shifts[i, 0], shifts[i, 1]]
)
Expand Down Expand Up @@ -296,7 +300,7 @@ def calculate_shifts_com(stack, nslices):

angles = stack.metadata.Tomography.tilts
[ntilts, ydim, xdim] = stack.data.shape
thetas = angles * np.pi / 180
thetas = np.pi * angles / 180

coms = get_coms(stack, slices)
I_tilts = np.eye(ntilts)
Expand Down Expand Up @@ -536,8 +540,10 @@ def tilt_com(stack, slices=None, nslices=None):
----------
stack : TomoStack object
3-D numpy array containing the tilt series data
locs : list
slices : list
Locations at which to perform the CoM analysis
nslices : int
Nubmer of slices to suer for the analysis
Returns
----------
Expand All @@ -553,53 +559,41 @@ def com_motion(theta, r, x0, z0):
def fit_line(x, m, b):
return m * x + b

_, ny, nx = stack.data.shape

if stack.metadata.Tomography.tilts is None:
raise ValueError("No tilts in stack.metadata.Tomography.")
raise ValueError("Tilts are not defined in stack.metadata.Tomography.")

if stack.data.shape[2] < 3:
raise ValueError(
"Dataset is only %s pixels in x dimension. This method cannot be used."
)
if nx < 3:
raise ValueError("Dataset is only %s pixels in x dimension. This method cannot be used." % stack.data.shape[2])

nx = stack.data.shape[2]
if nslices > nx:
raise ValueError("nslices is greater than the X-dimension of the data.")

# Determine the best slice locations for the analysis
if slices is None:
if nslices is None:
nx = stack.data.shape[2]
nslices = int(0.1 * nx)
if nslices < 3:
nslices = 3
elif nslices > 50:
nslices = 50
nslices = min(int(0.1 * nx), 20)
else:
if nslices > nx:
raise ValueError(
"nslices is greater than the X-dimension of the data.")
if nslices > 0.3 * nx:
nslices = int(0.3 * nx)
logger.warning(
"nslices is greater than 30%% of number of x pixels. Using %s slices instead."
% nslices
)
nslices = min(nslices, int(0.3 * nx))
logger.warning("nslices is greater than 30%% of number of x pixels. Using %s slices instead." % nslices)
if nslices < 3:
nslices = 3

slices = get_best_slices(stack, nslices)
logger.info("Performing alignments using best %s slices" % nslices)
else:
slices = np.sort(slices)

coms = get_coms(stack, slices)
slices = np.sort(slices)

coms = get_coms(stack, slices)
thetas = np.pi * stack.metadata.Tomography.tilts / 180.0
r = np.zeros(len(slices))
x0 = np.zeros(len(slices))
z0 = np.zeros(len(slices))

for i in range(0, len(slices)):
r[i], x0[i], z0[i] = optimize.curve_fit(
com_motion, xdata=thetas, ydata=coms[:, i], p0=[0, 0, 0]
)[0]
slope, intercept = optimize.curve_fit(
fit_line, xdata=r, ydata=slices, p0=[0, 0])[0]
tilt_shift = (stack.data.shape[1] / 2 - intercept) / slope

r, x0, z0 = np.zeros(len(slices)), np.zeros(len(slices)), np.zeros(len(slices))

for idx, i in enumerate(slices):
r[idx], x0[idx], z0[idx] = optimize.curve_fit(com_motion, xdata=thetas, ydata=coms[:, i], p0=[0, 0, 0])[0]
slope, intercept = optimize.curve_fit(fit_line, xdata=r, ydata=slices, p0=[0, 0])[0]
tilt_shift = (ny / 2 - intercept) / slope
tilt_rotation = -(180 * np.arctan(1 / slope) / np.pi)

final = stack.trans_stack(yshift=tilt_shift, angle=tilt_rotation)
Expand All @@ -612,122 +606,62 @@ def fit_line(x, m, b):
return final


def tilt_maximage(data, limit=10, delta=0.3, show_progressbar=False):
def tilt_maximage(stack, limit=10, delta=0.1, plot_results=False):
"""
Perform automated determination of the tilt axis of a TomoStack.
The projected maximum image by is rotated positively and negatively,
filtered using a Hamming window, and the rotation angle is determined by
iterative histogram analysis
The projected maximum image used to determine the tilt axis by a
combination of Sobel filtering and Hough transform analysis.
Args
----------
data : TomoStack object
stack : TomoStack object
3-D numpy array containing the tilt series data
limit : integer or float
Maximum rotation angle to use for MaxImage calculation
Maximum rotation angle to use for calculation
delta : float
Angular increment for MaxImage calculation
show_progressbar : boolean
Enable/disable progress bar
Angular increment for calculation
plot_results : boolean
If True, plot the maximum image along with the lines determined
by Hough analysis
Returns
----------
opt_angle : TomoStack object
Calculated rotation to set the tilt axis vertical
rotated : TomoStack object
Rotated version of the input stack
"""
image = stack.data.max(0)

def hamming(img):
"""
Apply hamming window to the image to remove edge effects.
Args
----------
img : Numpy array
Input image
Returns
----------
out : Numpy array
Filtered image
"""
# if img.shape[0] < img.shape[1]:
# center_loc = np.int32((img.shape[1] - img.shape[0]) / 2)
# img = img[:, center_loc:-center_loc]
# if img.shape[0] != img.shape[1]:
# img = img[:, 0:-1]
# h = np.hamming(img.shape[0])
# ham2d = np.sqrt(np.outer(h, h))
# elif img.shape[1] < img.shape[0]:
# center_loc = np.int32((img.shape[0] - img.shape[1]) / 2)
# img = img[center_loc:-center_loc, :]
# if img.shape[0] != img.shape[1]:
# img = img[0:-1, :]
# h = np.hamming(img.shape[1])
# ham2d = np.sqrt(np.outer(h, h))
# else:
h = np.hamming(img.shape[0])
ham2d = np.sqrt(np.outer(h, h))
out = ham2d * img
return out

def find_score(im, angle):
"""
Perform histogram analysis to measure the rotation angle.
Args
----------
im : Numpy array
Input image
angle : float
Angle by which to rotate the input image before analysis
Returns
----------
hist : Numpy array
Result of integrating image along the vertical axis
score : numpy array
Score calculated from hist
"""
im = ndimage.rotate(im, angle, reshape=False, order=3)
hist = np.sum(im, axis=1)
score = np.sum((hist[1:] - hist[:-1]) ** 2)
return hist, score

image = np.max(data.data, 0)

if image.shape[0] != image.shape[1]:
raise ValueError(
"Invalid data shape. Currently only square signal dimensions are supported."
)
rot_pos = ndimage.rotate(hamming(image), -limit / 2, reshape=False, order=3)
rot_neg = ndimage.rotate(hamming(image), limit / 2, reshape=False, order=3)
angles = np.arange(-limit, limit + delta, delta)
scores_pos = []
scores_neg = []
for rotation_angle in tqdm.tqdm(angles, disable=(not show_progressbar)):
hist_pos, score_pos = find_score(rot_pos, rotation_angle)
hist_neg, score_neg = find_score(rot_neg, rotation_angle)
scores_pos.append(score_pos)
scores_neg.append(score_neg)

best_score_pos = max(scores_pos)
best_score_neg = max(scores_neg)
pos_angle = -angles[scores_pos.index(best_score_pos)]
neg_angle = -angles[scores_neg.index(best_score_neg)]
opt_angle = (pos_angle + neg_angle) / 2

logger.info("Optimum positive rotation angle: {}".format(pos_angle))
logger.info("Optimum negative rotation angle: {}".format(neg_angle))
logger.info("Optimum positive rotation angle: {}".format(opt_angle))

out = copy.deepcopy(data)
out = out.trans_stack(xshift=0, yshift=0, angle=opt_angle)
out.data = np.transpose(out.data, (0, 2, 1))
out.metadata.Tomography.tiltaxis = opt_angle
return out
edges = sobel(image)

# Apply Canny edge detector for further edge enhancement
edges = canny(edges)

# Perform Hough transform to detect lines
angles = np.pi * np.arange(-limit, limit, delta) / 180.
h, theta, d = hough_line(edges, angles)

# Find peaks in Hough space
_, angles, dists = hough_line_peaks(h, theta, d, num_peaks=5)

# Calculate average angle from detected lines
rotation_angle = np.degrees(np.mean(angles))
print(rotation_angle)

if plot_results:
fig, ax = plt.subplots(1)
ax.imshow(image, cmap='gray')

for i in range(len(angles)):
(x0, y0) = dists[i] * np.array([np.cos(angles[i]), np.sin(angles[i])])
ax.axline((x0, y0), slope=np.tan(angles[i] + np.pi / 2))

plt.tight_layout()

rotated = stack.trans_stack(angle=-rotation_angle)
rotated.metadata.Tomography.tiltaxis = -rotation_angle
return rotated


def align_to_other(stack, other):
Expand Down

0 comments on commit 9d3ae94

Please sign in to comment.