Skip to content

Commit

Permalink
Merge pull request #809 from arismavridis/patch-1
Browse files Browse the repository at this point in the history
Added precompilation for nonnegative least squares
  • Loading branch information
Vaibhavdixit02 authored Aug 29, 2024
2 parents b2e0d1d + 38471df commit 9c4f74d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/OptimizationOptimJL/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"

[compat]
Optim = "1"
Expand Down
30 changes: 30 additions & 0 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,34 @@ function SciMLBase.__solve(cache::OptimizationCache{
stats = stats)
end

using PrecompileTools
PrecompileTools.@compile_workload begin

function obj_f(x, p)
A = p[1]
b = p[2]
return sum((A * x - b) .^ 2)
end

function solve_nonnegative_least_squares(A, b, solver)

optf = Optimization.OptimizationFunction(obj_f, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, ones(size(A, 2)), (A, b), lb=zeros(size(A, 2)), ub=Inf * ones(size(A, 2)))
x = OptimizationOptimJL.solve(prob, solver, maxiters=5000, maxtime=100)

return x
end

solver_list = [OptimizationOptimJL.LBFGS(),
OptimizationOptimJL.ConjugateGradient(),
OptimizationOptimJL.GradientDescent(),
OptimizationOptimJL.BFGS()]

for solver in solver_list
x = solve_nonnegative_least_squares(rand(4, 4), rand(4), solver)
x = solve_nonnegative_least_squares(rand(35, 35), rand(35), solver)
x = solve_nonnegative_least_squares(rand(35, 10), rand(35), solver)
end
end

end

0 comments on commit 9c4f74d

Please sign in to comment.