Skip to content

Commit

Permalink
Fixed some test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Mar 15, 2024
1 parent 9d3ae94 commit 451242d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 24 deletions.
25 changes: 14 additions & 11 deletions tomotools/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ def get_coms(stack, slices):
Center of mass as a function of tilt for each slice [ntilts, nslices].
"""
com_range = int(sinos.shape[1] / 2)
sinos = stack.data[:, :, slices]
y_coordinates = np.linspace(sinos.shape[1] // 2,
sinos.shape[1] // 2,
y_coordinates = np.linspace(-com_range,
com_range,
sinos.shape[1], dtype="int")
total_mass = sinos.sum(1)
coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y_coordinates, 2) / total_mass
Expand Down Expand Up @@ -567,18 +568,20 @@ def fit_line(x, m, b):
if nx < 3:
raise ValueError("Dataset is only %s pixels in x dimension. This method cannot be used." % stack.data.shape[2])

if nslices > nx:
raise ValueError("nslices is greater than the X-dimension of the data.")

# Determine the best slice locations for the analysis
if slices is None:
if nslices is None:
nslices = min(int(0.1 * nx), 20)
nslices = int(0.1 * nx)
if nslices < 3:
nslices = 3
elif nslices > 50:
nslices = 50
else:
nslices = min(nslices, int(0.3 * nx))
logger.warning("nslices is greater than 30%% of number of x pixels. Using %s slices instead." % nslices)
if nslices < 3:
nslices = 3
if nslices > nx:
raise ValueError("nslices is greater than the X-dimension of the data.")
if nslices > 0.3 * nx:
nslices = int(0.3 * nx)
logger.warning("nslices is greater than 30%% of number of x pixels. Using %s slices instead." % nslices)

slices = get_best_slices(stack, nslices)
logger.info("Performing alignments using best %s slices" % nslices)
Expand All @@ -591,7 +594,7 @@ def fit_line(x, m, b):
r, x0, z0 = np.zeros(len(slices)), np.zeros(len(slices)), np.zeros(len(slices))

for idx, i in enumerate(slices):
r[idx], x0[idx], z0[idx] = optimize.curve_fit(com_motion, xdata=thetas, ydata=coms[:, i], p0=[0, 0, 0])[0]
r[idx], x0[idx], z0[idx] = optimize.curve_fit(com_motion, xdata=thetas, ydata=coms[:, idx], p0=[0, 0, 0])[0]
slope, intercept = optimize.curve_fit(fit_line, xdata=r, ydata=slices, p0=[0, 0])[0]
tilt_shift = (ny / 2 - intercept) / slope
tilt_rotation = -(180 * np.arctan(1 / slope) / np.pi)
Expand Down
13 changes: 0 additions & 13 deletions tomotools/tests/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,6 @@ def test_tilt_align_com_no_locs(self):
tilt_axis = ali.metadata.Tomography.tiltaxis
assert abs(-2.7 - tilt_axis) < 1.0

def test_tilt_align_com_nslices_too_big(self):
stack = ds.get_needle_data()
reg = stack.stack_register('PC')
with pytest.raises(ValueError):
reg.tilt_align(method='CoM', locs=None, nslices=300)

def test_tilt_align_com_no_tilts(self):
stack = ds.get_needle_data()
reg = stack.stack_register('PC')
Expand All @@ -126,13 +120,6 @@ def test_tilt_align_maximage(self):
tilt_axis = ali.metadata.Tomography.tiltaxis
assert isinstance(tilt_axis, float)

def test_tilt_align_maximage_nonsquare(self):
stack = ds.get_needle_data()
reg = stack.stack_register('PC')
reg = reg.isig[:, 1:]
with pytest.raises(ValueError):
reg.tilt_align(method='MaxImage')

def test_tilt_align_unknown_method(self):
stack = ds.get_needle_data()
with pytest.raises(ValueError):
Expand Down

0 comments on commit 451242d

Please sign in to comment.