Skip to content

Commit

Permalink
add cython dist function
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtuck committed Dec 27, 2023
1 parent b4d766c commit 4512640
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/crbfgs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,34 @@ def rlbfgs(np.ndarray[double, ndim=1, mode="c"] q1, np.ndarray[double, ndim=1, m
cdef np.ndarray[np.double_t,ndim=1] gam = vec_to_numpy(out, None)

return(gam)


def rlbfgs_dist(np.ndarray[double, ndim=2, mode="c"] q1, np.ndarray[double, ndim=1, mode="c"] q2):
d = np.zeros(q1.shape[1])
M = q1.shape[0]
alpha = 0.5
time = np.linspace(0, 1, M)
for i in range(q1.shape[1]):
gam = rlbfgs(q1[:,i], time, q2)
# warp q
gam_dev = np.gradient(gam, 1.0 / (M - 1))
tmp = np.interp((time[-1] - time[0]) * gam + time[0], time, q2)

qw = tmp * np.sqrt(gam_dev)
Dy = np.sqrt(np.trapz((qw - q1[:,i]) ** 2, time))

binsize = np.mean(np.diff(time))
psi = np.sqrt(np.gradient(gam, binsize))
q1dotq2 = np.trapz(psi, time)
if q1dotq2 > 1:
q1dotq2 = 1
elif q1dotq2 < -1:
q1dotq2 = -1

Dx = np.real(np.arccos(q1dotq2))

d[i] = alpha * Dy + (1-alpha) * Dx

return d


0 comments on commit 4512640

Please sign in to comment.