Updated use of weights in procrustes analysis.
Till Schnabel committed Oct 30, 2024
1 parent 6ed17ba commit c73170c
Showing 1 changed file with 43 additions and 44 deletions.
87 changes: 43 additions & 44 deletions trimesh/registration.py
@@ -209,10 +209,8 @@ def procrustes(
     Finds the transformation T mapping a to b which minimizes the
     square sum distances between Ta and b, also called the cost.
-    Optionally specify different weights for the points in a to minimize
-    the weighted square sum distances between Ta and b, which can
-    improve transformation robustness on noisy data if the points'
-    probability distribution is known.
+    Optionally filter the points in a and b via a binary weights array.
+    Non-uniform weights are also supported, but won't yield the optimal rotation.
 
     Parameters
     ----------
@@ -221,7 +219,11 @@ def procrustes(
     b : (n,3) float
       List of points in space
     weights : (n,) float
-      List of floats representing how much weight is assigned to each point of a
+      List of floats representing how much weight is assigned to each point.
+      Binary entries can be used to filter the arrays; normalization is not required.
+      Translation and scaling are adjusted according to the weighting.
+      Note, however, that this method does not yield the optimal rotation for non-uniform weighting,
+      as this would require an iterative, nonlinear optimization approach.
     reflection : bool
       If the transformation is allowed reflections
     translation : bool
@@ -241,56 +243,51 @@ def procrustes(
       The cost of the transformation
     """
 
-    a = np.asanyarray(a, dtype=np.float64)
-    b = np.asanyarray(b, dtype=np.float64)
-    if not util.is_shape(a, (-1, 3)) or not util.is_shape(b, (-1, 3)):
-        raise ValueError("points must be (n,3)!")
-    if len(a) != len(b):
-        raise ValueError("a and b must contain same number of points!")
-    if weights is not None:
-        w = np.asanyarray(weights, dtype=np.float64)
-        if len(w) != len(a):
-            raise ValueError("weights must have same length as a and b!")
-        w_norm = (w / w.sum()).reshape((-1, 1))
+    a_original = np.asanyarray(a, dtype=np.float64)
+    b_original = np.asanyarray(b, dtype=np.float64)
+    if not util.is_shape(a_original, (-1, 3)) or not util.is_shape(b_original, (-1, 3)):
+        raise ValueError('points must be (n,3)!')
+    if len(a_original) != len(b_original):
+        raise ValueError('a and b must contain same number of points!')
+    # weights are set to uniform if not provided.
+    if weights is None:
+        weights = np.ones(len(a_original))
+    w = np.maximum(np.asanyarray(weights, dtype=np.float64), 0)
+    if len(w) != len(a):
+        raise ValueError("weights must have same length as a and b!")
+    w_norm = (w / w.sum()).reshape((-1, 1))
+
+    # All zero entries are removed from further computations.
+    # If weights is a binary array, the optimal solution can still be found
+    # by simply removing the zero entries.
+    nonzero_weights = w_norm[:, 0] > 0
+    a_nonzero = a_original[nonzero_weights]
+    b_nonzero = b_original[nonzero_weights]
+    w_norm = w_norm[nonzero_weights]
 
     # Remove translation component
     if translation:
-        # acenter is a weighted average of the individual points.
-        if weights is None:
-            acenter = a.mean(axis=0)
-        else:
-            acenter = (a * w_norm).sum(axis=0)
-        bcenter = b.mean(axis=0)
+        # centers are (weighted) averages of the individual points.
+        acenter = (a_nonzero * w_norm).sum(axis=0)
+        bcenter = (b_nonzero * w_norm).sum(axis=0)
     else:
-        acenter = np.zeros(a.shape[1])
-        bcenter = np.zeros(b.shape[1])
+        acenter = np.zeros(a_nonzero.shape[1])
+        bcenter = np.zeros(b_nonzero.shape[1])
 
     # Remove scale component
     if scale:
-        if weights is None:
-            ascale = np.sqrt(((a - acenter) ** 2).sum() / len(a))
-        # ascale is the square root of weighted average of the
-        # squared difference
-        # between each point and acenter.
-        else:
-            ascale = np.sqrt((((a - acenter) ** 2) * w_norm).sum())
-
-        bscale = np.sqrt(((b - bcenter) ** 2).sum() / len(b))
+        # scale is the square root of the (weighted) average of the
+        # squared difference between each point and the center.
+        ascale = np.sqrt((((a_nonzero - acenter) ** 2) * w_norm).sum())
+        bscale = np.sqrt((((b_nonzero - bcenter) ** 2) * w_norm).sum())
     else:
        ascale = 1
        bscale = 1
 
     # Use SVD to find optimal orthogonal matrix R
     # constrained to det(R) = 1 if necessary.
-    # w_mat is multiplied with the centered and scaled a, such that the points
-    # can be weighted differently.
-
-    if weights is None:
-        target = np.dot(((b - bcenter) / bscale).T, ((a - acenter) / ascale))
-    else:
-        target = np.dot(
-            ((b - bcenter) / bscale).T, ((a - acenter) / ascale) * w.reshape((-1, 1))
-        )
+    target = np.dot(((b_nonzero - bcenter) / bscale).T,
+                    ((a_nonzero - acenter) / ascale))
 
     u, _s, vh = np.linalg.svd(target)
 
@@ -308,9 +305,11 @@ def procrustes(
     matrix = np.vstack((matrix, np.array([0.0] * (a.shape[1]) + [1.0]).reshape(1, -1)))
 
     if return_cost:
-        transformed = transform_points(a, matrix)
-        # return the mean euclidean distance squared as the cost
-        cost = ((b - transformed) ** 2).mean()
+        # Transform the original input array, including zero-weighted points.
+        transformed = transform_points(a_original, matrix)
+        # The cost is the weighted average of the squared euclidean distances
+        # between the transformed source points and the target points.
+        cost = (((b_nonzero - transformed[nonzero_weights]) ** 2) * w_norm).sum()
         return matrix, transformed, cost
     else:
         return matrix
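
Below is a minimal usage sketch of the updated weights behaviour. It is not part of the commit: the point arrays, the outlier index, and the keyword values are made-up example data, and it assumes a trimesh build that includes this change. Setting a weight to zero removes that correspondence from the fit entirely.

import numpy as np
from trimesh import registration

# Made-up example data: b is a translated copy of a with one corrupted
# correspondence acting as an outlier.
rng = np.random.default_rng(0)
a = rng.random((100, 3))
b = a + [1.0, 2.0, 3.0]
b[0] += 10.0

# Binary weights: zero out the outlier so it is ignored by the fit.
weights = np.ones(len(a))
weights[0] = 0.0

matrix, transformed, cost = registration.procrustes(
    a, b, weights=weights, reflection=False, scale=False, return_cost=True
)
# matrix should be close to a pure translation by (1, 2, 3); the
# zero-weighted point is still transformed but does not contribute to cost.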

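The new docstring claims that binary weights simply filter the inputs. Under that reading, the following sanity-check sketch (again with invented data, assuming the same patched trimesh) should produce identical transforms for the weighted call and for a call on the pre-filtered subset.

import numpy as np
from trimesh import registration

rng = np.random.default_rng(1)
a = rng.random((50, 3))
# b is a rotated and translated copy of a (90 degrees about z).
rotation = np.array([[0.0, -1.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0]])
b = a @ rotation.T + [0.5, -0.2, 0.1]

keep = rng.random(50) > 0.3   # random binary mask
weights = keep.astype(np.float64)

# Binary weights on the full arrays ...
m_weighted = registration.procrustes(a, b, weights=weights, return_cost=False)
# ... should match pre-filtering the points and passing no weights at all.
m_filtered = registration.procrustes(a[keep], b[keep], return_cost=False)

assert np.allclose(m_weighted, m_filtered)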