Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtuck committed Dec 27, 2023
1 parent da26910 commit 20711b7
Showing 1 changed file with 26 additions and 30 deletions.
56 changes: 26 additions & 30 deletions src/crbfgs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,37 @@ def rlbfgs(np.ndarray[double, ndim=1, mode="c"] q1, np.ndarray[double, ndim=1, m
return(gam)


def rlbfgs_dist(np.ndarray[double, ndim=2, mode="c"] q1, np.ndarray[double, ndim=2, mode="c"] q2,
np.ndarray[double, ndim=2, mode="c"] idx):
def rlbfgs_dist(np.ndarray[double, ndim=2, mode="c"] q1, np.ndarray[double, ndim=1, mode="c"] q2):
q1 = np.ascontiguousarray(q1)
q2 = np.ascontiguousarray(q2)
d = np.zeros((q2.shape[1], idx.shape[1]))
d = np.zeros(q1.shape[1])
M = q1.shape[0]
N = q2.shape[1]
alpha = 0.5
time = np.linspace(0, 1, M)
for i in prange(N, nogil=True):
for j in range(idx.shape[1]):
q1t = q1[:,idx[i,j]]
q1t = np.ascontiguousarray(q1t)
q2t = q2[:, i]
q2t = np.ascontiguousarray(q2t)
gam = rlbfgs(q1t, time, q2t)
# warp q
gam_dev = np.gradient(gam, 1.0 / (M - 1))
tmp = np.interp((time[-1] - time[0]) * gam + time[0], time, q2t)

qw = tmp * np.sqrt(gam_dev)
Dy = np.sqrt(np.trapz((qw - q1t) ** 2, time))

binsize = np.mean(np.diff(time))
psi = np.sqrt(np.gradient(gam, binsize))
q1dotq2 = np.trapz(psi, time)
if q1dotq2 > 1:
q1dotq2 = 1
elif q1dotq2 < -1:
q1dotq2 = -1

Dx = np.real(np.arccos(q1dotq2))

d[i,j] = alpha * Dy + (1-alpha) * Dx

for i in range(q1.shape[1]):
q1t = q1[:,1]
q1t = np.ascontiguousarray(q1t)
gam = rlbfgs(q1t, time, q2)
# warp q
gam_dev = np.gradient(gam, 1.0 / (M - 1))
tmp = np.interp((time[-1] - time[0]) * gam + time[0], time, q2)

qw = tmp * np.sqrt(gam_dev)
Dy = np.sqrt(np.trapz((qw - q1t) ** 2, time))

binsize = np.mean(np.diff(time))
psi = np.sqrt(np.gradient(gam, binsize))
q1dotq2 = np.trapz(psi, time)
if q1dotq2 > 1:
q1dotq2 = 1
elif q1dotq2 < -1:
q1dotq2 = -1

Dx = np.real(np.arccos(q1dotq2))

d[i] = alpha * Dy + (1-alpha) * Dx

return d



0 comments on commit 20711b7

Please sign in to comment.