Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ojwoodford/0af1c93bb244fe527b5c73381c2d81fc to your computer and use it in GitHub Desktop.
Save ojwoodford/0af1c93bb244fe527b5c73381c2d81fc to your computer and use it in GitHub Desktop.
Comparison of two Julia non-linear least squares solvers on a particular optimization problem
# Compare non-linear least squares solvers on a dense affine alignment problem
using Distances, LinearAlgebra, Random
# Pairwise Euclidean distances between the columns of X and those of Y.
# The square root is applied only to non-zero squared distances; exact zeros
# pass through unchanged (keeps the value — and, under autodiff, its
# derivative handling — well defined at zero).
function safe_dist(X, Y)
    sqdists = pairwise(SqEuclidean(), X, Y; dims=2)
    return [iszero(z) ? z : sqrt(z) for z in sqdists]
end
# Synthesize a random n-dimensional affine-alignment problem.
# Returns the two point sets (one point per column) and the target pairwise
# distance matrix between the ground-truth-warped X and Y.
function generate_data(n)
    npts = 2 * n
    # Two random point sets
    X = randn(n, npts)
    Y = randn(n, npts)
    # Ground-truth affine warp: near-identity linear part plus a translation
    A = randn(n, n) * 0.1 + I
    t = randn(n)
    # Target distance matrix induced by the ground-truth warp
    D = safe_dist(A * X .+ t, Y)
    return X, Y, D
end
# Residual vector of the affine alignment: the difference between the target
# distance matrix D and the pairwise distances obtained after warping X by
# the affine transform packed in x.
# x is the flattened n×(n+1) matrix [M T]: linear part M (n×n) followed by
# the translation column T.
function resfun(x, X, Y, D)
    # Shape the affine warp (size(X, 1) is the idiomatic form of size(X)[1])
    n = size(X, 1)
    M = reshape(x, n, n+1)
    # Compute the residuals; views avoid copying the slices of M
    return vec(D - safe_dist(view(M, :, 1:n) * X .+ view(M, :, n+1), Y))
end
################################################################################
# NLLS formulation
using NLLSsolver, Static
# Residual block for NLLSsolver holding the fixed problem data; the residual
# itself is computed by `resfun` from the optimized affine parameters.
struct MyResidual <: NLLSsolver.AbstractResidual
X::Matrix{Float64}   # first point set (one point per column)
Y::Matrix{Float64}   # second point set (one point per column)
D::Matrix{Float64}   # target pairwise distance matrix between warped X and Y
end
# NLLSsolver interface overloads for MyResidual.
Base.eltype(::MyResidual) = Float64
NLLSsolver.ndeps(::MyResidual) = static(1) # Residual depends on one variable
NLLSsolver.nres(res::MyResidual) = length(res.D) # Number of residual elements (one per distance entry)
NLLSsolver.varindices(::MyResidual) = 1 # Index of the (single) variable this residual uses
NLLSsolver.getvars(::MyResidual, vars::Vector) = (vars[1], ) # Extract that variable from the problem's variable list
NLLSsolver.computeresidual(res::MyResidual, x) = resfun(x, res.X, res.Y, res.D)
# Build an NLLSsolver problem for the affine alignment defined by (X, Y, D).
function nllssolver(X, Y, D)
    # Create the problem
    problem = NLLSsolver.NLLSProblem(Vector{Float64}, MyResidual)
    # size(X, 1) is the idiomatic form of size(X)[1]
    n = size(X, 1)
    # Initial estimate: the identity affine transform (M = I, T = 0),
    # flattened as vec([M T]); stride n+1 hits the diagonal of the n×(n+1)
    # column-major layout
    start = zeros(n*n + n)
    start[1:n+1:end] .= 1
    NLLSsolver.addvariable!(problem, start)
    NLLSsolver.addcost!(problem, MyResidual(X, Y, D))
    return problem
end
# Solve an NLLSsolver problem and report (iterations, start cost, best cost).
function optimize(model::NLLSProblem)
    opts = NLLSOptions(maxiters = 30000)
    res = NLLSsolver.optimize!(model, opts)
    return res.niterations, res.startcost, res.bestcost
end
################################################################################
################################################################################
# LeastSquaresOptim formulation
import LeastSquaresOptim
# In-place residual function for LeastSquaresOptim: write resfun's result
# into the pre-allocated buffer `out`. Broadcasting directly into `out`
# avoids the extra `out[:]` view allocation of the original.
function costfun!(out, x, X, Y, D)
    out .= resfun(x, X, Y, D)
    return out
end
# Build a LeastSquaresOptim problem for the affine alignment defined by
# (X, Y, D), using a ForwardDiff-based Jacobian (autodiff=:forward).
function lso(X, Y, D)
    # size(X, 1) is the idiomatic form of size(X)[1]
    n = size(X, 1)
    # Initial estimate: the identity affine transform, flattened as vec([M T]);
    # stride n+1 hits the diagonal of the n×(n+1) column-major layout
    start = zeros(n*n + n)
    start[1:n+1:end] .= 1
    return LeastSquaresOptim.LeastSquaresProblem(; x=start, f! = (out, x)->costfun!(out, x, X, Y, D), output_length=length(D), autodiff=:forward)
end
# Solve a LeastSquaresOptim problem and report (iterations, start cost, best cost).
function optimize(model::LeastSquaresOptim.LeastSquaresProblem)
    # Levenberg-Marquardt with the same iteration cap as the NLLSsolver path
    res = LeastSquaresOptim.optimize!(model, LeastSquaresOptim.LevenbergMarquardt(); iterations=30000)
    # The start cost is not exposed by LeastSquaresOptim, so report NaN;
    # halving ssr presumably matches NLLSsolver's half-sum-of-squares cost
    # convention — confirm against NLLSsolver's cost definition
    return res.iterations, NaN64, res.ssr * 0.5
end
################################################################################
################################################################################
# Run the test and display results
using Plots, Printf
# Axis-tick formatter: render a number compactly using C-style "%g" formatting.
function tickformatter(value)
    return @sprintf("%g", value)
end
# Benchmark each solver on a range of problem sizes and plot the timing curves.
# `sizes` is a collection of problem dimensions; `solvers` is a collection of
# label => constructor pairs, each constructor taking (X, Y, D) and returning
# a problem that `optimize` can solve.
# Fixes over the original: comment typo ("optimzation"), non-idiomatic float
# literal `1.e-10`, and removal of leftover commented-out code.
function runtest(name, sizes, solvers)
    result = Matrix{Float64}(undef, (4, length(sizes)))
    p = plot()
    # For each optimizer
    for (label, constructor) in solvers
        # First run the optimization once to compile everything (timing discarded)
        optimize(constructor(generate_data(sizes[1])...))
        # Go over each problem, recording the time, iterations and start and end cost
        for (i, n) in enumerate(sizes)
            # Reset the random number generator so every solver sees identical data
            Random.seed!(i + 1234)
            # Generate the data
            X, Y, D = generate_data(n)
            # Construct the problem
            model = constructor(X, Y, D)
            # Optimize, timing only the solve itself
            result[1,i] = @elapsed res = optimize(model)
            result[2,i] = res[1]  # iteration count
            result[3,i] = res[2]  # start cost
            result[4,i] = res[3]  # final cost
            # Report runs that did not reach (numerically) zero cost
            if res[3] > 1e-10
                cost = res[3]
                println("$label on size $n converged to a cost of $cost")
            end
        end
        # Plot the timing curve for this solver
        plot!(p, vec(sizes), vec(result[1,:]), label=label)
    end
    yaxis!(p, minorgrid=true, formatter=tickformatter)
    xaxis!(p, minorgrid=true, formatter=tickformatter)
    plot!(p, legend=:topleft, yscale=:log10, xscale=:log2)
    title!(p, "Speed comparison: $name")
    xlabel!(p, "Problem size")
    ylabel!(p, "Optimization time (s)")
    display(p)
end
################################################################################
# Benchmark both solvers on increasingly large dense problems
const problemsizes = [2 4 8 16 32]
const testsolvers = ["LeastSquaresOptim LM" => lso, "NLLSsolver LM" => nllssolver]
runtest("Large dense problem", problemsizes, testsolvers)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment