Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental parametric tree solve WIP #1824

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CliqueStateMachine/services/CliqueStateMachine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ function updateFromSubgraph_StateMachine(csmc::CliqStateMachineContainer)
logCSM(
csmc,
"CSM-5 Clique $(csmc.cliq.id) finished, solveKey=$(csmc.solveKey)";
loglevel = Logging.Info,
loglevel = Logging.Debug,
)
return IncrementalInference.exitStateMachine
end
Expand Down
76 changes: 76 additions & 0 deletions src/Factors/GenericFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,82 @@ function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
return distanceTangent2Point(cf.factor.M, X, p, q)
end


## ======================================================================================
## adjoint factor - adjoint action applied to the measurement
## ======================================================================================
function Ad(::Union{typeof(SpecialEuclidean(2)), typeof(SpecialEuclidean(3))}, p, X)
t = p.x[1]
R = p.x[2]
v = X.x[1]
Ω = X.x[2]
ArrayPartition(-R*Ω*R'*t + R*v, R*Ω*R')
end

function Ad(::typeof(SpecialEuclidean(3)), p)
t = p.x[1]
R = p.x[2]
vcat(
hcat(R, skew(t)*R),
hcat(zero(SMatrix{3,3,Float64}), R)
)
end

function Ad(::typeof(SpecialEuclidean(2)), p)
t = p.x[1]
R = p.x[2]
vcat(
hcat(R, -SA[0 -1; 1 0]*t),
SA[0 0 1]
)
end

struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize
factor::F
end

function (cf::CalcFactor{<:AdFactor})(Xϵ, p, q)
# M = getManifold(cf.factor)
# p,q ∈ M
# Xϵ ∈ TϵM
# ϵ = identity_element(M)
# transform measurement from TϵM to TpM (global to local coordinates)
# Adₚ⁻¹ = AdjointMatrix(M, p)⁻¹ = AdjointMatrix(M, p⁻¹)
# Xp = Adₚ⁻¹ * Xϵᵛ
# ad = Ad(M, inv(M, p))
# Xp = Ad(M, inv(M, p), Xϵ)
# Xp = adjoint_action(M, inv(M, p), Xϵ)
#TODO is vector transport supposed to be the same?
# Xp = vector_transport_to(M, ϵ, Xϵ, p)

# Transform measurement covariance
# ᵉΣₚ = Adₚ ᵖΣₚ Adₚᵀ
#TODO test if transforming sqrt_iΣ is the same as Σ
# Σ = ad * inv(cf.sqrt_iΣ^2) * ad'
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), sqrt(inv(Σ)))
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), ad * cf.sqrt_iΣ * ad')
Xp = Xϵ

child_cf = CalcFactorResidual(
cf.faclbl,
cf.factor.factor,
cf.varOrder,
cf.varOrderIdxs,
cf.meas,
cf.sqrt_iΣ,
cf.cache,
)
return child_cf(Xp, p, q)
end

getMeasurementParametric(f::AdFactor) = getMeasurementParametric(f.factor)

getManifold(f::AdFactor) = getManifold(f.factor)
function getSample(cf::CalcFactor{<:AdFactor})
M = getManifold(cf)
return sampleTangent(M, cf.factor.factor.Z)
end

## ======================================================================================
## ManifoldPrior
## ======================================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/services/ManifoldSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ function getSample(cf::CalcFactor{<:AbstractPrior})
end

function getSample(cf::CalcFactor{<:AbstractRelative})
M =getManifold(cf)
M = getManifold(cf)
if hasfield(typeof(cf.factor), :Z)
X = sampleTangent(M, cf.factor.Z)
else
Expand Down
195 changes: 184 additions & 11 deletions src/parametric/services/ParametricCSMFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Notes
- Parametric state machine function nr. 3
"""
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
function solveUp_ParametricStateMachine_Old(csmc::CliqStateMachineContainer)
infocsm(csmc, "Par-3, Solving Up")

setCliqueDrawColor!(csmc.cliq, "red")
Expand Down Expand Up @@ -96,6 +96,145 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
return waitForDown_StateMachine
end

# solve relatives ignoring any priors keeping `from` at ϵ
# if clique has priors : solve to get a prior on `from`
# send messages as factors or just the beliefs? for now factors
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
infocsm(csmc, "Par-3, Solving Up")

setCliqueDrawColor!(csmc.cliq, "red")
# csmc.drawtree ? drawTree(csmc.tree, show=false, filepath=joinpath(getSolverParams(csmc.dfg).logpath,"bt.pdf")) : nothing

msgfcts = Symbol[]

for (idx, upmsg) in getMessageBuffer(csmc.cliq).upRx #get cached messages taken from children saved in this clique
child_factors = addMsgFactors_Parametric!(csmc.cliqSubFg, upmsg, UpwardPass)
append!(msgfcts, getLabel.(child_factors)) # addMsgFactors_Parametric!
end
logCSM(csmc, "length mgsfcts=$(length(msgfcts))")
infocsm(csmc, "length mgsfcts=$(length(msgfcts))")

# store the cliqSubFg for later debugging
_dbgCSMSaveSubFG(csmc, "fg_beforeupsolve")

subfg = csmc.cliqSubFg

frontals = getCliqFrontalVarIds(csmc.cliq)
separators = getCliqSeparatorVarIds(csmc.cliq)

# if its a root do full solve
if length(getParent(csmc.tree, csmc.cliq)) == 0
# M, vartypeslist, lm_r, Σ = solve_RLM(subfg; is_sparse=false, finiteDiffCovariance=true)
autoinitParametric!(subfg)
M, vartypeslist, lm_r, Σ = solveGraphParametric!(subfg; is_sparse=false, finiteDiffCovariance=true, damping_term_min=1e-18)

else

# select first seperator as constant reference at the identity element
isempty(separators) && @warn "empty separators solving cliq $(csmc.cliq.id.value)" ls(subfg) lsf(subfg)
from = first(separators)
from_v = getVariable(subfg, from)
getSolverData(from_v, :parametric).val[1] = getPointIdentity(getVariableType(from_v))

#TODO handle priors
# Variables that are free to move
free_vars = [frontals; separators[2:end]]
# Solve for the free variables

@assert !isempty(lsf(subfg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(subfg)) lsf=$(lsf(subfg))"

# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from];)
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from]; finiteDiffCovariance=false, damping_term_min=1e-18)

end

# FIXME check solve convergence
if !true
@error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve" result
# propagate error to cleanly exit all cliques
putErrorUp(csmc)
if length(getParent(csmc.tree, csmc.cliq)) == 0
putErrorDown(csmc)
return IncrementalInference.exitStateMachine
end

return waitForDown_StateMachine
end

logCSM(csmc, "$(csmc.cliq.id): subfg solve converged sending messages")

# Pack results in massage factors

sigmas = extractMarginalsAP(M, vartypeslist, Σ)

# FIXME fix MsgRelativeType
relative_message_factors = MsgRelativeType();
for (i, to) in enumerate(vartypeslist)
if to in separators
#assume full dim factor
factype = selectFactorType(subfg, from, to)
# make S symetrical
# S = sigmas[i] # FIXME for some reason SMatrix is not invertable even though it is!!!!!!!!
S = Matrix(sigmas[i])# FIXME
S = (S + S') / 2
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")


M_to = getManifold(getVariableType(subfg, to))
ϵ = getPointIdentity(M_to)
μ = vee(M_to, ϵ, log(M_to, ϵ, lm_r[i]))

message_factor = AdFactor(factype(MvNormal(μ, S)))


# logCSM(csmc, "$(csmc.cliq.id): Z=$(getMeasurementParametric(message_factor))"; loglevel = Logging.Warn)

push!(relative_message_factors, (variables=[from, to], likelihood=message_factor))
end
end

# Done with solve delete factors
#TODO confirm, maybe don't delete mesage factors on subgraph, maybe delete if its priors, but not conditionals
# deleteMsgFactors!(csmc.cliqSubFg)

# store the cliqSubFg for later debugging
_dbgCSMSaveSubFG(csmc, "fg_afterupsolve")

# cliqueLikelihood = calculateMarginalCliqueLikelihood(vardict, Σ, varIds, cliqSeparatorVarIds)

#Fill in CliqueLikelihood
beliefMsg = LikelihoodMessage(;
sender = (; id = csmc.cliq.id.value, step = csmc._csm_iter),
status = UPSOLVED,
variableOrder = separators,
# cliqueLikelihood,
jointmsg = _MsgJointLikelihood(;relatives=relative_message_factors),
msgType = ParametricMessage(),
)

# @assert length(separators) <= 2 "TODO length(separators) = $(length(separators)) > 2 in clique $(csmc.cliq.id.value)"
@assert isempty(lsfPriors(csmc.cliqSubFg)) || csmc.cliq.id.value == 1 "TODO priors in clique $(csmc.cliq.id.value)"
# if length(lsfPriors(csmc.cliqSubFg)) > 0 || length(separators) > 2
# for si in cliqSeparatorVarIds
# vnd = getSolverData(getVariable(csmc.cliqSubFg, si), :parametric)
# beliefMsg.belief[si] = TreeBelief(deepcopy(vnd))
# end
# end

for e in getEdgesParent(csmc.tree, csmc.cliq)
logCSM(csmc, "$(csmc.cliq.id): put! on edge $(e)")
getMessageBuffer(csmc.cliq).upTx = deepcopy(beliefMsg)
putBeliefMessageUp!(csmc.tree, e, beliefMsg)
end

return waitForDown_StateMachine
end

global g_n = nothing

"""
$SIGNATURES

Expand All @@ -120,6 +259,15 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
vnd.val .= belief.val
vnd.bw .= belief.bw

p = belief.val[1]

S = belief.bw
S = (S + S') / 2
vnd.bw .= S

nd = MvNormal(getCoordinates(Main.Pose2, p), S)
addFactor!(csmc.cliqSubFg, [msym], Main.PriorPose2(nd))
end
end
end
Expand All @@ -132,23 +280,48 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
#only down solve if its not a root
if length(getParent(csmc.tree, csmc.cliq)) != 0
frontals = getCliqFrontalVarIds(csmc.cliq)
vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
# vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
#TEMP testing difference
# vardict, result = solveGraphParametric(csmc.cliqSubFg)
# Pack all results in variables
if result.g_converged || result.f_converged
@assert !isempty(lsf(csmc.cliqSubFg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(csmc.cliqSubFg)) lsf=$(lsf(csmc.cliqSubFg))"

# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(csmc.cliqSubFg, frontals; finiteDiffCovariance=false, damping_term_min=1e-18)
M, vartypeslist, lm_r, Σ = solve_RLM(csmc.cliqSubFg; finiteDiffCovariance=false, damping_term_min=1e-18)
sigmas = extractMarginalsAP(M, vartypeslist, Σ)

if true # TODO check for convergence result.g_converged || result.f_converged
logCSM(
csmc,
"$(csmc.cliq.id): subfg optim converged updating variables";
loglevel = Logging.Info,
loglevel = Logging.Debug,
)
for (v, val) in vardict
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
#Update subfg variables
vnd.val[1] = val.val
vnd.bw .= val.cov
for (i, v) in enumerate(vartypeslist)
if v in frontals
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)

S = Matrix(sigmas[i])# FIXME
S = (S + S') / 2
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")


#Update subfg variables
vnd.val[1] = lm_r[i]
vnd.bw .= S
end
end
# for (v, val) in vardict
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
# vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)

# #Update subfg variables
# vnd.val[1] = val.val
# vnd.bw .= val.cov
# end
else
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result
#propagate error to cleanly exit all cliques
Expand All @@ -169,7 +342,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
for fi in cliqFrontalVarIds
vnd = getSolverData(getVariable(csmc.cliqSubFg, fi), :parametric)
beliefMsg.belief[fi] = TreeBelief(vnd)
logCSM(csmc, "$(csmc.cliq.id): down message $fi : $beliefMsg"; loglevel = Logging.Info)
logCSM(csmc, "$(csmc.cliq.id): down message $fi"; beliefMsg=beliefMsg.belief[fi], loglevel = Logging.Debug)
end

# pass through the frontal variables that were sent from above
Expand Down
Loading
Loading