Last active
April 6, 2023 01:05
-
-
Save MilesCranmer/a8277006c6e6411ddfa9a28abdc4342b to your computer and use it in GitHub Desktop.
Compare forward-mode and reverse-mode automatic differentiation as a function of the number of parameters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using BenchmarkTools | |
using ForwardDiff | |
using ReverseDiff | |
using Random | |
using Plots | |
using Statistics: mean, quantile, std | |
using Measurements | |
using Printf: @sprintf | |
using Colors | |
# Okabe & Ito colorblind-friendly palette, written as 8-bit (r, g, b) triples
# and rescaled to [0, 1] channels:
colors = [
    RGB(r / 255, g / 255, b / 255) for (r, g, b) in (
        (0, 114, 178),   # blue
        (230, 159, 0),   # orange
        (0, 158, 115),   # green
        (204, 121, 167), # reddish purple
        (86, 180, 233),  # sky blue
        (213, 94, 0),    # vermilion
        (240, 228, 66),  # yellow
    )
]
# Apply `N` rounds of the elementwise update y ↦ cos(y + θ_k), where θ_k = params[k]
# for k = 1, …, N, starting from y = x. `N` is carried in the type domain (`Val`)
# so the loop count is a compile-time constant in each specialization ("just in case").
# With N == 0 the input `x` is returned unchanged.
function _f(x, params, ::Val{N}) where {N}
    y = x
    for k in 1:N
        θ = params[k]
        y = cos.(y .+ θ)
    end
    return y
end
# Model under test: forward the call to `_f` with the number of parameters lifted
# into the type domain. `length(params)` is only known at run time, so this call
# is dynamically dispatched, with one `_f` specialization per distinct length.
f(x, params) = _f(x, params, Val(length(params)))
# Top-level benchmark suite, with one sub-group per differentiation mode:
const suite = BenchmarkGroup()
for mode in ("forward", "reverse")
    suite[mode] = BenchmarkGroup()
end
# Gradient of sum(f(x, p)) with respect to p, evaluated at p = params, using
# forward-mode AD (ForwardDiff). `x` is held fixed.
function forward(x, params)
    objective(p) = sum(f(x, p))
    return ForwardDiff.gradient(objective, params)
end
# Gradient of sum(f(x, p)) with respect to p, evaluated at p = params, using
# reverse-mode AD (ReverseDiff). `x` is held fixed.
# NOTE(review): this name shadows `Base.reverse` for the rest of this script;
# confirm nothing here needs `Base.reverse` unqualified.
function reverse(x, params)
    objective(p) = sum(f(x, p))
    return ReverseDiff.gradient(objective, params)
end
# Warmup: call each gradient once so the common call path is compiled before
# benchmarking. The `Val{n}` specialization of `_f` for every other parameter
# count is still compiled on its first call.
forward([1.0], [1.0])
reverse([1.0], [1.0])

# Parameter counts: round(10^l) for l = 0, 0.25, …, 3 (1 to 1000, log-spaced).
all_num_params = [round(Int, 10^l) for l = 0.0:0.25:3.0]
# Input lengths: 1, 10, 100, 1000.
all_num_x = [round(Int, 10^l) for l = 0:3]

for num_params in all_num_params
    suite["forward"][num_params] = BenchmarkGroup()
    suite["reverse"][num_params] = BenchmarkGroup()

    # Evaluations per sample: 100 at num_params = 1 and 10 at num_params = 1000,
    # linear in log space in between.
    num_evals = round(Int, 10^(2.0 - log10(num_params) / 3.0))

    for num_x in all_num_x
        # `setup` draws fresh inputs in [0, 6.28) before each sample.
        suite["forward"][num_params][num_x] =
            @benchmarkable $forward(x, params) evals = num_evals samples = 1000 setup =
                (x = rand($num_x) .* 6.28; params = rand($num_params) .* 6.28)
        suite["reverse"][num_params][num_x] =
            @benchmarkable $reverse(x, params) evals = num_evals samples = 1000 setup =
                (x = rand($num_x) .* 6.28; params = rand($num_params) .* 6.28)
    end
end

# `setup` (and therefore its `rand` calls) only executes later, inside
# `run(suite)`, so seeding while the benchmarks are being defined has no effect
# on their inputs. Seed once here, after all definitions, to make the benchmark
# inputs reproducible.
Random.seed!(0)
# Run every benchmark (printing progress), then gather the trials into a
# num_params × num_x × 2 array whose last axis is (forward, reverse).
res = run(suite; verbose = true)
res_matrix =
    [res[mode][np][nx] for np in all_num_params, nx in all_num_x, mode in ("forward", "reverse")]
# Summarize each trial as mean ± std of its log10 sample times; Measurements
# carries these uncertainties through the subtraction below.
res_agg = map(res_matrix) do trial
    log_times = log10.(trial.times)
    mean(log_times) ± std(log_times)
end
# Difference of mean log-times = log10 of the (geometric-mean) time ratio
# forward / reverse; negative entries mean forward mode was faster.
ratio_forward_to_reverse = res_agg[:, :, 1] .- res_agg[:, :, 2]
# Use the GR backend; figures are 600×400 at 300 dpi.
gr(size = (600, 400), dpi = 300)

# Line plot of log10(Δt[ForwardDiff] / Δt[ReverseDiff]) against the number of
# parameters (log-scaled x-axis), with one line per input length nₓ.

# Reference line at y = 0 (a time ratio of 1). It is drawn first so the data
# series are drawn on top of it, which keeps it in the background.
plot(
    all_num_params,
    zeros(size(all_num_params)),
    label = "",
    color = :black,
    alpha = 0.1,
    line = (:dash, 2.0),
    xscale = :log10,
)
for (ix, nx) in enumerate(all_num_x)
    # Cycle the palette if there are more input lengths than colors.
    c = colors[mod1(ix, length(colors))]
    # Values are Measurement{Float64}, so they are plotted with error bars;
    # color the bars like the line. `[:, ix]` selects this nₓ's column as a
    # vector (the matrix is 2-D: num_params × num_x).
    plot!(
        all_num_params,
        ratio_forward_to_reverse[:, ix],
        label = "nₓ = $nx",
        xscale = :log10,
        color = c,
        markerstrokecolor = c,
    )
end
plot!(legend = :topleft)
# Region labels: a negative log-ratio means ForwardDiff was faster.
annotate!(3, -0.2, text("Forward better", :black, 8))
annotate!(3, 0.2, text("Reverse better", :black, 8))
# Explain the error bars (1σ, propagated from the std of the log10 times):
annotate!(10^2.5, -1.3, text("Errors show 1σ", :black, 8))
xlabel!("Number of parameters")
# Plotted values are log10 ratios; the custom y-ticks below label them with the
# plain ratio.
ylabel!("Δt[ForwardDiff] / Δt[ReverseDiff]")
# x-ticks at 10^0, 10^1, 10^2, 10^3:
xticks!(10.0 .^ (0:3))
new_ticks = [0.1, 1.0, 10.0]
yticks!((log10.(new_ticks), string.(new_ticks)))
# The resolution (300 dpi) was set on the backend above.
savefig("forward_vs_reverse.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment