-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #544 from JuliaRobotics/maint/2Q20/mvrex
standardizing ros and slam frontend, wip pynn
- Loading branch information
Showing
22 changed files
with
1,195 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
136 changes: 136 additions & 0 deletions
136
examples/learning/hybrid/NeuralPose2Pose2/FluxModelsPose2Pose2.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# FluxModelsPose2Pose2 | ||
|
||
using Random, Statistics | ||
using DistributedFactorGraphs, TransformUtils | ||
@everywhere using Random, Statistics | ||
@everywhere using DistributedFactorGraphs, TransformUtils | ||
|
||
using Flux | ||
@everywhere using Flux | ||
|
||
|
||
@everywhere begin | ||
import Base: convert | ||
import IncrementalInference: getSample | ||
|
||
struct FluxModelsPose2Pose2{P,D<:AbstractArray,M<:SamplableBelief} <: FunctorPairwise | ||
predictFnc::P | ||
joyVelData::D | ||
naiveModel::M | ||
naiveFrac::Float64 | ||
Zij::Pose2Pose2 | ||
specialSampler::Function # special keyword field name used to invoke 'specialSampler' logic | ||
end | ||
|
||
|
||
function sampleNeuralPose2(nfb::FluxModelsPose2Pose2, | ||
N::Int, | ||
fmd::FactorMetadata, | ||
Xi::DFGVariable, | ||
Xj::DFGVariable)::Tuple | ||
# | ||
|
||
# calculate naive model and Predictive fraction of samples, respectively | ||
Nn = round(Int, nfb.naiveFrac*N) | ||
# calculate desired number of predicted values | ||
Np = N - Nn | ||
len = size(nfb.joyVelData,1) # expect this to be only 25 at developmenttime, likely to change | ||
|
||
# model samples (all for theta at this time) | ||
smpls_mAll = rand(nfb.naiveModel, N) | ||
|
||
# sample predictive fraction | ||
iT, jT = getTimestamp(Xi), getTimestamp(Xj) | ||
iPts, jPts = (getKDE(Xi) |> getPoints), (getKDE(Xj) |> getPoints) | ||
@assert size(jPts,2) == size(iPts,2) "sampleNeuralPose2 can currently only evaluate equal population size variables" | ||
|
||
# calculate an average velocity component | ||
DT = jT - iT | ||
DXY = (@view jPts[1:2,:]) - (@view jPts[1:2,:]) | ||
# rotate delta position from world to local iX frame | ||
for i in 1:size(iPts,2) | ||
DXY[1:2,i] .= TransformUtils.R(iPts[3,i])'*DXY[1:2,i] | ||
end | ||
# replace delta (velocity) values for this sampling | ||
mVXY = Statistics.mean(DXY, dims=2) | ||
# divide time to get velocity | ||
mVXY ./= 1e-3*DT.value | ||
mVXY[1] = isnan(mVXY[1]) ? 0.0 : mVXY[1] | ||
mVXY[2] = isnan(mVXY[2]) ? 0.0 : mVXY[2] | ||
|
||
for i in 1:len | ||
nfb.joyVelData[i,3:4] = mVXY | ||
end | ||
# and predict | ||
# A = [rand(4) for i in 1:25] | ||
smpls_pAll = nfb.predictFnc(nfb.joyVelData) | ||
if size(smpls_pAll,2) == 2 | ||
smpls_pAll = hcat(smpls_pAll, zeros(size(smpls_pAll,1))) | ||
end | ||
|
||
# number of predictors to choose from, and choose random subset | ||
Npreds = size(smpls_pAll,1) | ||
allPreds = 1:Npreds |> collect | ||
# randomly select particles for prediction (with possible duplicates forwhen Np > size(iPts,2)) | ||
Npp = Np < Npreds ? Np : Npreds | ||
Nnn = N - Npp | ||
selPreds = @view shuffle!(allPreds)[1:Npp] # TODO better in-place | ||
smpls_p = smpls_pAll[selPreds,:] # sample per row?? | ||
smpls_p[:,3] .= smpls_mAll[3,Nnn+1:N] # use naive delta theta at this time | ||
|
||
# naive fraction | ||
smpls_n = @view smpls_mAll[:, 1:Nnn] # sample per column | ||
# join and shuffle predicted odo values | ||
shfSmpl = shuffle!(1:N |> collect) | ||
smpls = hcat(smpls_n, smpls_p')[:,shfSmpl] | ||
|
||
# @show Statistics.mean(smpls, dims=2) | ||
# @show maximum(smpls, dims=2) | ||
# @show minimum(smpls, dims=2) | ||
|
||
return (smpls, ) | ||
end | ||
|
||
# Convenience function to help call the right constuctor | ||
FluxModelsPose2Pose2(nn::P, | ||
jvd::D, | ||
md::M, | ||
naiveFrac::Float64=0.4, | ||
ss::Function=sampleNeuralPose2) where {P, M <: SamplableBelief, D <: AbstractMatrix} = FluxModelsPose2Pose2{P,D,M}(nn,jvd,md,naiveFrac,Pose2Pose2(MvNormal(zeros(3),diagm(ones(3)))),ss ) | ||
# | ||
|
||
|
||
function (nfb::FluxModelsPose2Pose2)( | ||
res::AbstractArray{<:Real}, | ||
userdata::FactorMetadata, | ||
idx::Int, | ||
meas::Tuple{AbstractArray{<:Real},}, | ||
Xi::AbstractArray{<:Real,2}, | ||
Xj::AbstractArray{<:Real,2} ) | ||
# | ||
nfb.Zij(res,userdata,idx,meas,Xi,Xj) | ||
nothing | ||
end | ||
|
||
|
||
|
||
## packing converters | ||
|
||
struct PackedFluxModelsPose2Pose2 <: IncrementalInference.PackedInferenceType | ||
joyVelData::Matrix{Float64} | ||
naiveModel::String | ||
naiveFrac::Float64 | ||
end | ||
|
||
|
||
function convert(::Type{FluxModelsPose2Pose2}, d::PackedFluxModelsPose2Pose2) | ||
FluxModelsPose2Pose2(PyTFOdoPredictorPoint2,d.joyVelData,extractdistribution(d.naiveModel),d.naiveFrac) | ||
end | ||
|
||
function convert(::Type{PackedFluxModelsPose2Pose2}, d::FluxModelsPose2Pose2) | ||
PackedFluxModelsPose2Pose2(d.joyVelData, string(d.naiveModel), d.naiveFrac) | ||
end | ||
|
||
end # everywhere | ||
|
||
# |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
using DelimitedFiles | ||
using Flux | ||
@everywhere using Flux | ||
|
||
# load a specialized model format | ||
function loadPyNNTxt(dest::AbstractString) | ||
mw = Vector{Array{Float32}}() | ||
@show files = readdir(dest) | ||
for f in files | ||
push!(mw, readdlm(joinpath(dest,f))) | ||
end | ||
return mw | ||
end | ||
|
||
|
||
## Utility functions to take values from tf | ||
|
||
|
||
function buildPyNNModel_01_FromElements(W1::AbstractMatrix{<:Real}=zeros(4,8), | ||
b1::AbstractVector{<:Real}=zeros(8), | ||
W2::AbstractMatrix{<:Real}=zeros(8,48), | ||
b2::AbstractVector{<:Real}=zeros(8), | ||
W3::AbstractMatrix{<:Real}=zeros(2,8), | ||
b3::AbstractVector{<:Real}=zeros(2)) | ||
# | ||
# W1 = randn(Float32, 4,8) | ||
# b1 = randn(Float32,8) | ||
modjl = Chain( | ||
x -> (x*W1)' .+ b1 .|> relu, | ||
x -> reshape(x', 25,8,1), | ||
x -> maxpool(x, PoolDims(x, 4)), | ||
# x -> reshape(x[:,:,1]',1,:), | ||
x -> reshape(x[:,:,1]',:), | ||
Dense(48,8,relu), | ||
Dense(8,2) | ||
) | ||
|
||
modjl[5].W .= W2 | ||
modjl[5].b .= b2 | ||
|
||
modjl[6].W .= W3 | ||
modjl[6].b .= b3 | ||
|
||
return modjl | ||
end | ||
|
||
# As loaded from tensorflow get_weights | ||
# Super specialized function | ||
function buildPyNNModel_01_FromWeights(pywe) | ||
buildPyNNModel_01_FromElements(pywe[1], pywe[2][:], pywe[3]', pywe[4][:], pywe[5]', pywe[6][:]) | ||
end | ||
|
||
# convenience function to load specific model format from tensorflow | ||
function loadTfModelIntoFlux(dest::AbstractString) | ||
weights = loadPyNNTxt(dest::AbstractString) | ||
buildPyNNModel_01_FromWeights(weights) | ||
end |
Oops, something went wrong.