Skip to content

Commit

Permalink
add inplace version
Browse files Browse the repository at this point in the history
  • Loading branch information
francescoalemanno committed Mar 23, 2022
1 parent 220204a commit 70dcdf7
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FixedPoint"
uuid = "3325f569-5a18-4e7d-8356-246b69339eea"
authors = ["Francesco Alemanno <[email protected]>"]
version = "1.0.0"
version = "1.0.1"

[compat]
julia = "^1.5"
Expand Down
63 changes: 62 additions & 1 deletion src/FixedPoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,66 @@ function afps(
end
(x = x_n, error = grad_norm(f(x_n) - x_n), iters = runs)
end
export afps


"""
afps!(f!, x; iters::Int = 5000, vel::Float64 = 0.9, ep::Float64 = 0.01, tol::Float64 = 1e-12, grad_norm=x->maximum(abs,x))
solve equation `f(x) = x` according to:
`f!` : inplace version of function to find fixed point for, calling `f!(out,x)` should amount to writing `out = f(x)`
`x` : initial condition, ideally it should be close to the final solution
`vel` : amount of Nesterov acceleration in [0,1]
`ep` : learning rate, typically in ]0,1[
`tol` : absolute tolerance on |f(x)-x|
`grad_norm` : function to evaluate the norm for |f(x)-x|
returns a named tuple (x, error, iters) where:
`x` : is the solution found for f(x)=x
`error` : is the norm of f(x)-x at the solution point
`iters` : total number of iterations performed
"""
function afps!(
f!::Fun,
x_n::Mat;
iters::Int = 5000,
vel::T = 0.9,
ep::T = 0.01,
tol::T = 1e-12,
grad_norm = x -> maximum(abs, x),
) where {T <: Number, Mat<:Union{AbstractArray{T},T}, Fun <: Function}
v_n = zero(x_n)
trial = zero(x_n)
g = zero(x_n)
β = vel
ϵ = ep
runs = 0
for _ = 1:iters
#prediction using velocity
x_n .+= β .* v_n
#eval gradient
f!(g,x_n)
g .-= x_n
#correction
x_n .+= ϵ .* g
#update velocity
v_n .= β .* v_n .+ ϵ .* g
runs += 1
if grad_norm(g) < tol
break
end
end
(x = x_n, error = grad_norm(g), iters = runs)
end


export afps,afps!
end
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,30 @@ end
s = afps(x -> 2 - x^2 + x, 1.3)
@test s.x 2
end


@testset "Inplace Vector Equation" begin
Ts = LinRange(0.01, 2.0, 500)
βs = 1 ./ Ts
function f!(out,x)
@. out = tanh(βs * x)
end
x = zero(βs) .+ 1
mag = afps!(
f!,
x,
grad_norm = x -> maximum(abs, x),
iters = 5000,
)
@test mag.error < 1e-4
@test maximum(abs, x .- tanh.(βs .* x)) < 1e-4
end

@testset "Inplace Equation" begin
x = [1.3]
function f!(out, x)
out[1] = 2 - x[1]^2 + x[1]
end
s = afps!(f!, x)
@test s.x[1] 2
end

0 comments on commit 70dcdf7

Please sign in to comment.