From 9d3ae9419873502598841a7528a67907996a5857 Mon Sep 17 00:00:00 2001 From: Andrew Herzing Date: Fri, 15 Mar 2024 16:27:27 -0400 Subject: [PATCH] Imporved maximum image based tilt alignment --- tomotools/align.py | 232 ++++++++++++++++----------------------------- 1 file changed, 83 insertions(+), 149 deletions(-) diff --git a/tomotools/align.py b/tomotools/align.py index c4479d24..bc6d82ea 100644 --- a/tomotools/align.py +++ b/tomotools/align.py @@ -16,6 +16,10 @@ import logging # from numpy.fft import fft, fftshift, ifftshift, ifft from skimage.registration import phase_cross_correlation as pcc +from skimage.transform import hough_line, hough_line_peaks +from skimage.feature import canny +from skimage.filters import sobel +import matplotlib.pylab as plt logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -37,16 +41,16 @@ def get_best_slices(stack, nslices): Returns ---------- - locs : NumPy array + slice_locations : NumPy array Location along the x-axis of the best slices """ total_mass = stack.data.sum((0, 1)) - mass_var = stack.data.sum(1).std(0) - mass_var[mass_var == 0] = 1e-5 - ratio = total_mass / mass_var - locs = ratio.argsort()[::-1][0:nslices] - return locs + mass_std = stack.data.sum(1).std(0) + mass_std[mass_std == 0] = 1e-5 + mass_ratio = total_mass / mass_std + slice_locations = mass_ratio.argsort()[::-1][0:nslices] + return slice_locations def get_coms(stack, slices): @@ -67,11 +71,11 @@ def get_coms(stack, slices): """ sinos = stack.data[:, :, slices] - y = np.linspace( - -int(sinos.shape[1] / 2), int(sinos.shape[1] / 2), sinos.shape[1], dtype="int" - ) + y_coordinates = np.linspace(sinos.shape[1] // 2, + sinos.shape[1] // 2, + sinos.shape[1], dtype="int") total_mass = sinos.sum(1) - coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y, 2) / total_mass + coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y_coordinates, 2) / total_mass return coms @@ -99,7 +103,7 @@ def apply_shifts(stack, shifts): "Number of shifts (%s) is not consistent with number" "of images in the stack (%s)" % (len(shifts), stack.data.shape[0]) ) - for i in range(0, shifted.data.shape[0]): + for i in range(shifted.data.shape[0]): shifted.data[i, :, :] = ndimage.shift( shifted.data[i, :, :], shift=[shifts[i, 0], shifts[i, 1]] ) @@ -296,7 +300,7 @@ def calculate_shifts_com(stack, nslices): angles = stack.metadata.Tomography.tilts [ntilts, ydim, xdim] = stack.data.shape - thetas = angles * np.pi / 180 + thetas = np.pi * angles / 180 coms = get_coms(stack, slices) I_tilts = np.eye(ntilts) @@ -536,8 +540,10 @@ def tilt_com(stack, slices=None, nslices=None): ---------- stack : TomoStack object 3-D numpy array containing the tilt series data - locs : list + slices : list Locations at which to perform the CoM analysis + nslices : int + Nubmer of slices to suer for the analysis Returns ---------- @@ -553,53 +559,41 @@ def com_motion(theta, r, x0, z0): def fit_line(x, m, b): return m * x + b + _, ny, nx = stack.data.shape + if stack.metadata.Tomography.tilts is None: - raise ValueError("No tilts in stack.metadata.Tomography.") + raise ValueError("Tilts are not defined in stack.metadata.Tomography.") - if stack.data.shape[2] < 3: - raise ValueError( - "Dataset is only %s pixels in x dimension. This method cannot be used." - ) + if nx < 3: + raise ValueError("Dataset is only %s pixels in x dimension. This method cannot be used." % stack.data.shape[2]) - nx = stack.data.shape[2] + if nslices > nx: + raise ValueError("nslices is greater than the X-dimension of the data.") + + # Determine the best slice locations for the analysis if slices is None: if nslices is None: - nx = stack.data.shape[2] - nslices = int(0.1 * nx) - if nslices < 3: - nslices = 3 - elif nslices > 50: - nslices = 50 + nslices = min(int(0.1 * nx), 20) else: - if nslices > nx: - raise ValueError( - "nslices is greater than the X-dimension of the data.") - if nslices > 0.3 * nx: - nslices = int(0.3 * nx) - logger.warning( - "nslices is greater than 30%% of number of x pixels. Using %s slices instead." - % nslices - ) + nslices = min(nslices, int(0.3 * nx)) + logger.warning("nslices is greater than 30%% of number of x pixels. Using %s slices instead." % nslices) + if nslices < 3: + nslices = 3 slices = get_best_slices(stack, nslices) logger.info("Performing alignments using best %s slices" % nslices) - else: - slices = np.sort(slices) - coms = get_coms(stack, slices) + slices = np.sort(slices) + coms = get_coms(stack, slices) thetas = np.pi * stack.metadata.Tomography.tilts / 180.0 - r = np.zeros(len(slices)) - x0 = np.zeros(len(slices)) - z0 = np.zeros(len(slices)) - - for i in range(0, len(slices)): - r[i], x0[i], z0[i] = optimize.curve_fit( - com_motion, xdata=thetas, ydata=coms[:, i], p0=[0, 0, 0] - )[0] - slope, intercept = optimize.curve_fit( - fit_line, xdata=r, ydata=slices, p0=[0, 0])[0] - tilt_shift = (stack.data.shape[1] / 2 - intercept) / slope + + r, x0, z0 = np.zeros(len(slices)), np.zeros(len(slices)), np.zeros(len(slices)) + + for idx, i in enumerate(slices): + r[idx], x0[idx], z0[idx] = optimize.curve_fit(com_motion, xdata=thetas, ydata=coms[:, i], p0=[0, 0, 0])[0] + slope, intercept = optimize.curve_fit(fit_line, xdata=r, ydata=slices, p0=[0, 0])[0] + tilt_shift = (ny / 2 - intercept) / slope tilt_rotation = -(180 * np.arctan(1 / slope) / np.pi) final = stack.trans_stack(yshift=tilt_shift, angle=tilt_rotation) @@ -612,122 +606,62 @@ def fit_line(x, m, b): return final -def tilt_maximage(data, limit=10, delta=0.3, show_progressbar=False): +def tilt_maximage(stack, limit=10, delta=0.1, plot_results=False): """ Perform automated determination of the tilt axis of a TomoStack. - The projected maximum image by is rotated positively and negatively, - filtered using a Hamming window, and the rotation angle is determined by - iterative histogram analysis + The projected maximum image used to determine the tilt axis by a + combination of Sobel filtering and Hough transform analysis. Args ---------- - data : TomoStack object + stack : TomoStack object 3-D numpy array containing the tilt series data limit : integer or float - Maximum rotation angle to use for MaxImage calculation + Maximum rotation angle to use for calculation delta : float - Angular increment for MaxImage calculation - show_progressbar : boolean - Enable/disable progress bar + Angular increment for calculation + plot_results : boolean + If True, plot the maximum image along with the lines determined + by Hough analysis Returns ---------- - opt_angle : TomoStack object - Calculated rotation to set the tilt axis vertical + rotated : TomoStack object + Rotated version of the input stack """ + image = stack.data.max(0) - def hamming(img): - """ - Apply hamming window to the image to remove edge effects. - - Args - ---------- - img : Numpy array - Input image - Returns - ---------- - out : Numpy array - Filtered image - - """ - # if img.shape[0] < img.shape[1]: - # center_loc = np.int32((img.shape[1] - img.shape[0]) / 2) - # img = img[:, center_loc:-center_loc] - # if img.shape[0] != img.shape[1]: - # img = img[:, 0:-1] - # h = np.hamming(img.shape[0]) - # ham2d = np.sqrt(np.outer(h, h)) - # elif img.shape[1] < img.shape[0]: - # center_loc = np.int32((img.shape[0] - img.shape[1]) / 2) - # img = img[center_loc:-center_loc, :] - # if img.shape[0] != img.shape[1]: - # img = img[0:-1, :] - # h = np.hamming(img.shape[1]) - # ham2d = np.sqrt(np.outer(h, h)) - # else: - h = np.hamming(img.shape[0]) - ham2d = np.sqrt(np.outer(h, h)) - out = ham2d * img - return out - - def find_score(im, angle): - """ - Perform histogram analysis to measure the rotation angle. - - Args - ---------- - im : Numpy array - Input image - angle : float - Angle by which to rotate the input image before analysis - - Returns - ---------- - hist : Numpy array - Result of integrating image along the vertical axis - score : numpy array - Score calculated from hist - - """ - im = ndimage.rotate(im, angle, reshape=False, order=3) - hist = np.sum(im, axis=1) - score = np.sum((hist[1:] - hist[:-1]) ** 2) - return hist, score - - image = np.max(data.data, 0) - - if image.shape[0] != image.shape[1]: - raise ValueError( - "Invalid data shape. Currently only square signal dimensions are supported." - ) - rot_pos = ndimage.rotate(hamming(image), -limit / 2, reshape=False, order=3) - rot_neg = ndimage.rotate(hamming(image), limit / 2, reshape=False, order=3) - angles = np.arange(-limit, limit + delta, delta) - scores_pos = [] - scores_neg = [] - for rotation_angle in tqdm.tqdm(angles, disable=(not show_progressbar)): - hist_pos, score_pos = find_score(rot_pos, rotation_angle) - hist_neg, score_neg = find_score(rot_neg, rotation_angle) - scores_pos.append(score_pos) - scores_neg.append(score_neg) - - best_score_pos = max(scores_pos) - best_score_neg = max(scores_neg) - pos_angle = -angles[scores_pos.index(best_score_pos)] - neg_angle = -angles[scores_neg.index(best_score_neg)] - opt_angle = (pos_angle + neg_angle) / 2 - - logger.info("Optimum positive rotation angle: {}".format(pos_angle)) - logger.info("Optimum negative rotation angle: {}".format(neg_angle)) - logger.info("Optimum positive rotation angle: {}".format(opt_angle)) - - out = copy.deepcopy(data) - out = out.trans_stack(xshift=0, yshift=0, angle=opt_angle) - out.data = np.transpose(out.data, (0, 2, 1)) - out.metadata.Tomography.tiltaxis = opt_angle - return out + edges = sobel(image) + + # Apply Canny edge detector for further edge enhancement + edges = canny(edges) + + # Perform Hough transform to detect lines + angles = np.pi * np.arange(-limit, limit, delta) / 180. + h, theta, d = hough_line(edges, angles) + + # Find peaks in Hough space + _, angles, dists = hough_line_peaks(h, theta, d, num_peaks=5) + + # Calculate average angle from detected lines + rotation_angle = np.degrees(np.mean(angles)) + print(rotation_angle) + + if plot_results: + fig, ax = plt.subplots(1) + ax.imshow(image, cmap='gray') + + for i in range(len(angles)): + (x0, y0) = dists[i] * np.array([np.cos(angles[i]), np.sin(angles[i])]) + ax.axline((x0, y0), slope=np.tan(angles[i] + np.pi / 2)) + + plt.tight_layout() + + rotated = stack.trans_stack(angle=-rotation_angle) + rotated.metadata.Tomography.tiltaxis = -rotation_angle + return rotated def align_to_other(stack, other):