Skip to content

Commit

Permalink
continue numba conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtuck committed Dec 17, 2023
1 parent 97688f5 commit a2ca34c
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 22 deletions.
6 changes: 3 additions & 3 deletions fdasrsf/rbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def solve(self, maxiter=30, verb=0, lam=0, penalty="roughness"):

return

def alignment_cost(self, h, q2k, lam=0, penalty="roughness"):
def alignment_cost(self, q2k, h, lam=0, penalty="roughness"):
r"""
Evaluate the cost function :math:`f = ||q1 - ((q2,hk),h)||^2`.
:math:`h=sqrt{\dot{\gamma}}` is a sequential update of cumulative
Expand Down Expand Up @@ -458,7 +458,7 @@ def linesearch_hint(self, d, f0, df0, q2k, lam=0, penalty="roughness"):

# Make the chosen step and compute cost there
newh = self.exp(hid, d, alpha)
newf = self.alignment_cost(newh, q2k, lam, penalty)
newf = self.alignment_cost(q2k, newh, lam, penalty)
cost_evaluations = 1

# backtrack while the Armijo criterion is not satisfied
Expand All @@ -472,7 +472,7 @@ def linesearch_hint(self, d, f0, df0, q2k, lam=0, penalty="roughness"):

# look closer down the line
newh = self.exp(hid, d, alpha)
newf = self.alignment_cost(newh, q2k, lam, penalty)
newf = self.alignment_cost(q2k, newh, lam, penalty)
cost_evaluations += 1
tst = newh <= 0

Expand Down
225 changes: 206 additions & 19 deletions fdasrsf/umap_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import numba
from fdasrsf.rbfgs import rlbfgs
from numba.core.typing import cffi_utils
from _DP import ffi, lib
import _DP
Expand All @@ -21,14 +20,18 @@
arange,
int64,
trapz,
cumtrapz,
cumsum,
ones,
)
from numpy import (
ascontiguousarray,
roll,
eye,
inf,
nan,
cos,
sin,
insert,
arccos,
)
from numpy import kron, floor, mean
Expand Down Expand Up @@ -72,17 +75,183 @@ def warp(q1, q2):

return gam


@numba.njit()
def strict_inc_func(t):
return 1e-4 * t


@numba.njit()
def cumtrapz(y, x):
dt = diff(x)/2.0
z = cumsum(dt*(y[0:-1]+y[1:]))
z = insert(z, 0, 0)

return z


@numba.njit()
def alignment_cost(q1, q2k, h):
q2new = group_action_SRVF(q2k, h)
f = normL2(q1-q2new)**2
return f


@numba.njit()
def group_action_SRVF(q, h):
p = q.shape[0]
t = linspace(0, 1, p)
gamma = zeros(p)
gamma[1:] = cumtrapz(h**2, t)
h = sqrt(grad(gamma, mean(diff(t))))
qnew = zeros(p)
qnew = interp(gamma, t, q) * h

return qnew


@numba.njit()
def normL2(f):
val = sqrt(innerProdL2(f, f))
return val


@numba.njit()
def innerProdL2(f1, f2):
tmp = f1 * f2
val = trapz(tmp, linspace(0,1,f1.shape[0]))
return val


@numba.njit()
def alignment_costgrad(q1, q2k, h):
T = q1.shape[0]
t = linspace(0, 1, T)
f = normL2(q1-q2k)**2
q2kdot = grad(q2k, 1/(T-1))
dq = q1-q2k
v = zeros(T)
tmp = dq * q2kdot
tmp1 = dq * q2k
v[1:] = 2 * cumtrapz(tmp, t)
v = v - tmp1
g = v - trapz(v, t)

return f, g


@numba.njit()
def norm2(v):
t = linspace(0, 1, v.shape[0])
val = sqrt(trapz(v**2, t))
return val


@numba.njit()
def stoppingcriterion(options, info, last):
stop = 0
stats = info[last]

if stats["gradnorm"] <= options["tolgradnorm"]:
stop = 2

if stats["time"] >= options["maxtime"]:
stop = 3

if stats["iter"] >= options["maxiter"]:
stop = 4

return stop


@numba.njit()
def inner(v1, v2):
return trapz(v1 * v2, linspace(0, 1, v1.shape[0]))


@numba.njit()
def getDirection(hCurGradient, sHistory, yHistory, rhoHistory, scaleFactor, j):
q = hCurGradient
inner_s_q = zeros(j)

for i in range(j, 0, -1):
inner_s_q[i - 1] = rhoHistory[i - 1] * inner(sHistory[i - 1], q)
q = q - inner_s_q[i - 1] * yHistory[i - 1]

r = scaleFactor * q

for i in range(0, j):
omega = rhoHistory[i] * inner(yHistory[i], r)
r = r + (inner_s_q[i] - omega) * sHistory[i]

direction = -r

return direction


@numba.njit()
def exp(f1, v, delta=1):
vd = delta * v
nrm_vd = norm2(vd)

if nrm_vd > 0:
f2 = f1 * cos(nrm_vd) + vd * (sin(nrm_vd) / nrm_vd)
else:
f2 = f1

return f2


@numba.njit()
def linesearch_hint(d, f0, df0, q2k):
T = q2k.shape[0]
contraction_factor = 0.5
suff_decr = 1e-6
max_ls_steps = 25
ls_backtrack = True
ls_force_decrease = True

alpha = 1

hid = ones(T)

newh = exp(hid, d, alpha)
newf = alignment_cost(q2k, newh)
cost_evaluations = 1

tst = newh <= 0
while ls_backtrack and (
(newf > (f0 + suff_decr * alpha * df0)) or (tst.sum() > 0)
):
# reduce the step size
alpha *= contraction_factor

newh = exp(hid, d, alpha)
newf = alignment_cost(q2k, newh)
cost_evaluations += 1
tst = newh <= 0

if cost_evaluations >= max_ls_steps:
break

if ls_force_decrease and newf > f0:
alpha = 0
newh = hid
newf = f0

norm_d = norm2(d)
stepsize = alpha * norm_d

lsstats = {"costevals": cost_evaluations, "stepsize": stepsize, "alpha": alpha}

return stepsize, newh, lsstats


@numba.njit()
def warp_rbfgs(q1, q2):
T = q1.shape[0]
time = linspace(0, 1, q1.shape[0])
maxiter = 30
lam = 0
penalty = "roughness"
tolgradnorm = 1e-3
maxtime = inf
# minimum norm of tangent vector that points from current to next
Expand Down Expand Up @@ -117,24 +286,37 @@ def warp_rbfgs(q1, q2):
stepsize = 1
accepted = True

hCurCost, hCurGradient = alignment_costgrad(q2tilde, htilde, lam, penalty)
hCurCost, hCurGradient = alignment_costgrad(q1, q2tilde, htilde)

hCurGradNorm = norm2(hCurGradient)

lsstats = {"costevals": 0, "stepsize": 0.0, "alpha": 0.0}

ultimatum = False

info = []
stats = {
"iter": k,
"cost": hCurCost,
"gradnorm": hCurGradNorm,
"stepsize": nan,
"accepted": None,
"linesearch": lsstats,
}
info.append(stats)

while True:
stop = stoppingcriterion(options, k)
stop = stoppingcriterion(options, info, k)

if stop == 0:
if stepsize < options["minstepsize"]:
if not ultimatum:
j = 0
ultimatum = True
else:
stop = 1
if stepsize < options["minstepsize"]:
if not ultimatum:
j = 0
ultimatum = True
else:
ultimatum = False
stop = 1
else:
ultimatum = False

if stop > 0:
break
Expand All @@ -149,18 +331,14 @@ def warp_rbfgs(q1, q2):
)

in_prod = inner(hCurGradient, p)
stepsize, hNext, lsstats = linesearch_hint(
p, hCurCost, in_prod, q2tilde, lam, penalty
)
stepsize, hNext, lsstats = linesearch_hint(p, hCurCost, in_prod, q2tilde)
htilde = group_action_SRVF(htilde, hNext)
q2tilde = group_action_SRVF(q2tilde, hNext)

alpha = stepsize / norm2(p)
step = alpha * p

hNextCost, hNextGradient = alignment_costgrad(
q2tilde, hNext, lam, penalty
)
hNextCost, hNextGradient = alignment_costgrad(q1, q2tilde, hNext)

sk = step
yk = hNextGradient - hCurGradient
Expand Down Expand Up @@ -210,6 +388,15 @@ def warp_rbfgs(q1, q2):
hCurCost = hNextCost

k += 1
stats = {
"iter": k,
"cost": hCurCost,
"gradnorm": hCurGradNorm,
"stepsize": nan,
"accepted": accepted,
"linesearch": lsstats,
}
info.append(stats)

gam = cumtrapz(htilde**2, time)

Expand Down

0 comments on commit a2ca34c

Please sign in to comment.