Skip to content

Commit

Permalink
fixed initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtuck committed Dec 6, 2024
1 parent a90025b commit ceee59d
Showing 1 changed file with 15 additions and 24 deletions.
39 changes: 15 additions & 24 deletions R/curve_karcher_mean.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,26 @@ curve_karcher_mean <- function(beta,
cent <- matrix(0, nrow = L, ncol = N)

for (n in 1:N) {
beta1 <- beta[ , , n]
beta1 <- beta[, , n]
centroid1 <- calculatecentroid(beta1)
cent[, n] <- -1 * centroid1
dim(centroid1) <- c(length(centroid1), 1)
beta1 <- beta1 - repmat(centroid1, 1, M)
beta[ , , n] <- beta1
beta[, , n] <- beta1
out <- curve_to_q(beta1, scale = TRUE)
q[, , n] <- out$q
len[n] <- out$len
len_q[n] <- out$lenq
}

mu <- q[, , 1]
bmu <- beta[, , 1]
# Initialize mu as one of the shapes
mnq <- rowMeans(q[1, , ])
dqq <- sqrt(colSums((q[1, , ] - matrix(
mnq, ncol = N, nrow = M
)) ^ 2))
min_ind <- which.min(dqq)
mu <- q[, , min_ind]
bmu <- beta[, , min_ind]
delta <- 0.5
tolv <- 1e-04
told <- 5 * 0.001
Expand All @@ -135,29 +141,13 @@ curve_karcher_mean <- function(beta,
v_d <- array(0, dim = c(L, M, N)) # include array to hold v_i / d_i
}

cli::cli_alert_info("Initializing...")
gam <- foreach::foreach(n = 1:N, .combine = cbind, .packages = "fdasrvf") %dopar% {
find_rotation_seed_unique(
q1 = mu,
q2 = q[, , n],
mode = mode,
rotation = rotated,
scale = TRUE,
lambda = lambda
)$gambest
}

gamI <- SqrtMeanInverse(gam)
bmu <- group_action_by_gamma_coord(bmu, gamI)
mu <- curve_to_q(bmu)$q
mu[is.nan(mu)] <- 0

while (itr < maxit) {
cli::cli_alert_info("Iteration {itr}/{maxit}...")

mu <- mu / sqrt(innerprod_q2(mu, mu))

if (mode == "C") basis <- find_basis_normal(mu)
if (mode == "C")
basis <- find_basis_normal(mu)

outfor <- foreach::foreach(n = 1:N, .combine = cbind, .packages='fdasrvf') %dopar% {
out <- karcher_calc(
Expand All @@ -184,7 +174,7 @@ curve_karcher_mean <- function(beta,
d_i <- unlist(outfor[4, ])
dim(d_i) <- N

sumd[itr + 1] = sumd[itr + 1] + sum(dist^2)
sumd[itr + 1] <- sumd[itr + 1] + sum(dist ^ 2)

if (ms == "median") {
# run for median only
Expand All @@ -207,7 +197,8 @@ curve_karcher_mean <- function(beta,

mu <- cos(delta * normvbar[itr]) * mu +
sin(delta * normvbar[itr]) * vbar / normvbar[itr]
if (mode == "C") mu <- project_curve(mu)
if (mode == "C")
mu <- project_curve(mu)
x <- q_to_curve(mu)
a <- -1 * calculatecentroid(x)
dim(a) <- c(length(a), 1)
Expand Down

0 comments on commit ceee59d

Please sign in to comment.