Skip to content

Commit

Permalink
[Issue-122] Pass shortest arg for interp; optionally enforce non-nega…
Browse files Browse the repository at this point in the history
…tive scalar … (#123)
  • Loading branch information
jcao-bdai authored May 13, 2024
1 parent bc6103a commit 1cf7d92
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 13 deletions.
6 changes: 6 additions & 0 deletions spatialmath/base/quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def r2q(
check: Optional[bool] = False,
tol: float = 20,
order: Optional[str] = "sxyz",
shortest: bool = False,
) -> UnitQuaternionArray:
"""
Convert SO(3) rotation matrix to unit-quaternion
Expand All @@ -562,6 +563,8 @@ def r2q(
:param order: the order of the returned quaternion elements. Must be 'sxyz' or
'xyzs'. Defaults to 'sxyz'.
:type order: str
:param shortest: ensures the quaternion has non-negative scalar part.
:type shortest: bool, default to False
:return: unit-quaternion as Euler parameters
:rtype: ndarray(4)
:raises ValueError: for non SO(3) argument
Expand Down Expand Up @@ -633,6 +636,9 @@ def r2q(
e[1] = math.copysign(e[1], R[0, 2] + R[2, 0])
e[2] = math.copysign(e[2], R[2, 1] + R[1, 2])

if shortest and e[0] < 0:
e = -e

if order == "sxyz":
return e
elif order == "xyzs":
Expand Down
12 changes: 9 additions & 3 deletions spatialmath/base/transforms2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,16 +853,16 @@ def tr2jac2(T: SE2Array) -> R3x3:


@overload
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float) -> SO2Array:
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True) -> SO2Array:
...


@overload
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float) -> SE2Array:
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True) -> SE2Array:
...


def trinterp2(start, end, s):
def trinterp2(start, end, s, shortest: bool = True):
"""
Interpolate SE(2) or SO(2) matrices
Expand All @@ -872,6 +872,8 @@ def trinterp2(start, end, s):
:type end: ndarray(3,3) or ndarray(2,2)
:param s: interpolation coefficient, range 0 to 1
:type s: float
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated SE(2) or SO(2) matrix value
:rtype: ndarray(3,3) or ndarray(2,2)
:raises ValueError: bad arguments
Expand Down Expand Up @@ -917,6 +919,8 @@ def trinterp2(start, end, s):

th0 = math.atan2(start[1, 0], start[0, 0])
th1 = math.atan2(end[1, 0], end[0, 0])
if shortest:
th1 = th0 + smb.wrap_mpi_pi(th1 - th0)

th = th0 * (1 - s) + s * th1

Expand All @@ -937,6 +941,8 @@ def trinterp2(start, end, s):

th0 = math.atan2(start[1, 0], start[0, 0])
th1 = math.atan2(end[1, 0], end[0, 0])
if shortest:
th1 = th0 + smb.wrap_mpi_pi(th1 - th0)

p0 = transl2(start)
p1 = transl2(end)
Expand Down
16 changes: 9 additions & 7 deletions spatialmath/base/transforms3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,16 +1605,16 @@ def trnorm(T: SE3Array) -> SE3Array:


@overload
def trinterp(start: Optional[SO3Array], end: SO3Array, s: float) -> SO3Array:
def trinterp(start: Optional[SO3Array], end: SO3Array, s: float, shortest: bool = True) -> SO3Array:
...


@overload
def trinterp(start: Optional[SE3Array], end: SE3Array, s: float) -> SE3Array:
def trinterp(start: Optional[SE3Array], end: SE3Array, s: float, shortest: bool = True) -> SE3Array:
...


def trinterp(start, end, s):
def trinterp(start, end, s, shortest=True):
"""
Interpolate SE(3) matrices
Expand All @@ -1624,6 +1624,8 @@ def trinterp(start, end, s):
:type end: ndarray(4,4) or ndarray(3,3)
:param s: interpolation coefficient, range 0 to 1
:type s: float
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated SE(3) or SO(3) matrix value
:rtype: ndarray(4,4) or ndarray(3,3)
:raises ValueError: bad arguments
Expand Down Expand Up @@ -1663,12 +1665,12 @@ def trinterp(start, end, s):
if start is None:
# TRINTERP(T, s)
q0 = r2q(end)
qr = qslerp(qeye(), q0, s)
qr = qslerp(qeye(), q0, s, shortest=shortest)
else:
# TRINTERP(T0, T1, s)
q0 = r2q(start)
q1 = r2q(end)
qr = qslerp(q0, q1, s)
qr = qslerp(q0, q1, s, shortest=shortest)

return q2r(qr)

Expand All @@ -1679,7 +1681,7 @@ def trinterp(start, end, s):
q0 = r2q(t2r(end))
p0 = transl(end)

qr = qslerp(qeye(), q0, s)
qr = qslerp(qeye(), q0, s, shortest=shortest)
pr = s * p0
else:
# TRINTERP(T0, T1, s)
Expand All @@ -1689,7 +1691,7 @@ def trinterp(start, end, s):
p0 = transl(start)
p1 = transl(end)

qr = qslerp(q0, q1, s)
qr = qslerp(q0, q1, s, shortest=shortest)
pr = p0 * (1 - s) + s * p1

return rt2tr(q2r(qr), pr)
Expand Down
8 changes: 5 additions & 3 deletions spatialmath/baseposematrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,16 @@ def log(self, twist: Optional[bool] = False) -> Union[NDArray, List[NDArray]]:
else:
return log

def interp(self, end: Optional[bool] = None, s: Union[int, float] = None) -> Self:
def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shortest: bool = True) -> Self:
"""
Interpolate between poses (superclass method)
:param end: final pose
:type end: same as ``self``
:param s: interpolation coefficient, range 0 to 1, or number of steps
:type s: array_like or int
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated pose
:rtype: same as ``self``
Expand Down Expand Up @@ -432,13 +434,13 @@ def interp(self, end: Optional[bool] = None, s: Union[int, float] = None) -> Sel
if self.N == 2:
# SO(2) or SE(2)
return self.__class__(
[smb.trinterp2(start=self.A, end=end, s=_s) for _s in s]
[smb.trinterp2(start=self.A, end=end, s=_s, shortest=shortest) for _s in s]
)

elif self.N == 3:
# SO(3) or SE(3)
return self.__class__(
[smb.trinterp(start=self.A, end=end, s=_s) for _s in s]
[smb.trinterp(start=self.A, end=end, s=_s, shortest=shortest) for _s in s]
)

def interp1(self, s: float = None) -> Self:
Expand Down
7 changes: 7 additions & 0 deletions tests/base/test_quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def test_rotation(self):
)
nt.assert_array_almost_equal(qvmul([0, 1, 0, 0], [0, 0, 1]), np.r_[0, 0, -1])

large_rotation = math.pi + 0.01
q1 = r2q(tr.rotx(large_rotation), shortest=False)
q2 = r2q(tr.rotx(large_rotation), shortest=True)
self.assertLess(q1[0], 0)
self.assertGreater(q2[0], 0)
self.assertTrue(qisequal(q1=q1, q2=q2, unitq=True))

def test_slerp(self):
q1 = np.r_[0, 1, 0, 0]
q2 = np.r_[0, 0, 1, 0]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pose2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,16 @@ def test_interp(self):
array_compare(I.interp(TT, s=1), TT)
array_compare(I.interp(TT, s=0.5), SE2(1, -2, 0.3))

R1 = SO2(math.pi - 0.1)
R2 = SO2(-math.pi + 0.2)
array_compare(R1.interp(R2, s=0.5, shortest=False), SO2(0.05))
array_compare(R1.interp(R2, s=0.5, shortest=True), SO2(-math.pi + 0.05))

T1 = SE2(0, 0, math.pi - 0.1)
T2 = SE2(0, 0, -math.pi + 0.2)
array_compare(T1.interp(T2, s=0.5, shortest=False), SE2(0, 0, 0.05))
array_compare(T1.interp(T2, s=0.5, shortest=True), SE2(0, 0, -math.pi + 0.05))

def test_miscellany(self):
TT = SE2(1, 2, 0.3)

Expand Down

0 comments on commit 1cf7d92

Please sign in to comment.