Last active
October 2, 2020 15:06
-
-
Save theogf/354ef8709db81c8fd3806586067e3c59 to your computer and use it in GitHub Desktop.
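Compute the semi-discrete Wasserstein distance by stochastic optimization of the (optionally entropy-regularized) semi-dual, following "Stochastic Optimization for Large-scale Optimal Transport" (Genevay et al.): https://papers.nips.cc/paper/6566-stochastic-optimization-for-large-scale-optimal-transport.pdf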
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra
using Statistics
using StatsFuns
"""
    h(x, v, y, ν, ϵ::Real, c)

Entropy-regularized semi-dual integrand evaluated at a single source point `x`:

    ⟨v, ν⟩ - ϵ * log(Σⱼ νⱼ * exp((vⱼ - c(x, yⱼ)) / ϵ)) - ϵ

`y` holds the target support points, `ν` their weights, `v` the dual potential (one entry
per support point), and `c(x, yⱼ)` the ground cost. The log-sum is evaluated with
`logsumexp` on the shifted exponents `(vⱼ - c(x, yⱼ)) / ϵ + log(νⱼ)` for numerical
stability.
"""
function h(
    x::AbstractVector, v::AbstractVector, y::AbstractVector, ν::AbstractVector, ϵ::Real, c
)
    costs = c.(Ref(x), y)                     # c(x, yⱼ) for every support point
    exponents = @. (v - costs) / ϵ + log(ν)   # folds the weight νⱼ into the exponent
    return dot(v, ν) - ϵ * logsumexp(exponents) - ϵ
end
"""
    h(x, v, y, ν, ϵ::Int, c)

Unregularized (`ϵ = 0`) semi-dual integrand evaluated at a single source point `x`:

    ⟨v, ν⟩ + minⱼ (c(x, yⱼ) - vⱼ)

An integer `ϵ` selects this exact, unregularized case. Any nonzero integer is rejected
with an `ErrorException`.
"""
function h(
    x::AbstractVector, v::AbstractVector, y::AbstractVector, ν::AbstractVector, ϵ::Int, c
)
    ϵ == 0 || error("ϵ has to be 0")
    # c-transform of v at x: smallest reduced cost over the target support.
    return dot(v, ν) + minimum(c.(Ref(x), y) .- v)
end
"""
    _grad_v_h(x, v, y, ν, ϵ::Real, c)

Gradient with respect to `v` of the entropy-regularized semi-dual integrand:
`ν - softmax((v - c(x, y)) / ϵ + log(ν))`.
"""
function _grad_v_h(x, v, y, ν, ϵ::Real, c)
    costs = c.(Ref(x), y)
    z = @. (v - costs) / ϵ + log(ν)
    w = exp.(z .- maximum(z))   # shift by the max so exp cannot overflow
    return ν .- w ./ sum(w)
end

"""
    _grad_v_h(x, v, y, ν, ϵ::Int, c)

(Super)gradient with respect to `v` of the unregularized (`ϵ = 0`) semi-dual integrand:
`ν - e_{j*}`, where `j* = argminⱼ (c(x, yⱼ) - vⱼ)` (first minimizer on ties).
"""
function _grad_v_h(x, v, y, ν, ϵ::Int, c)
    ϵ == 0 || error("ϵ has to be 0")
    jstar = argmin(c.(Ref(x), y) .- v)
    return ν .- (eachindex(ν) .== jstar)
end

"""
    optim_v(μ, y, ν, η, N, ϵ, c)

Estimate the optimal semi-dual potential `v` (one entry per target support point `y`)
by averaged stochastic gradient ascent. Each of the `N` iterations draws one sample
`xₖ = rand(μ)`, takes a step of size `η / √k` along the gradient of the semi-dual
integrand with respect to `v`, and updates the running average of the iterates.

`ϵ::Int` (which must be 0) selects the unregularized problem. Any other real `ϵ` selects
the entropy-regularized one. Assumes `rand(μ)` returns a single point accepted by
`c(x, yⱼ)` — TODO confirm for the distributions used by callers.

Returns the averaged iterate.
"""
function optim_v(μ, y::AbstractVector, ν::AbstractVector, η::Real, N::Int, ϵ::Real, c)
    v = zero(ν)   # running average of the iterates (the returned estimate)
    ṽ = zero(ν)   # current SGD iterate
    for k in 1:N
        xₖ = rand(μ)
        # Ascend along ∇ᵥ h (the original differentiated w.r.t. ν by mistake).
        ṽ .+= (η / √k) .* _grad_v_h(xₖ, ṽ, y, ν, ϵ, c)
        @. v = ṽ / k + (k - 1) / k * v
    end
    return v
end
"""
    wasserstein_semidiscrete(μ, y, ν, ϵ, c=(x, y) -> norm(x - y), η=0.1, N=100, N_MC=2000)

Monte Carlo estimate of the semi-discrete (optionally entropy-regularized) optimal transport
cost between the sampleable source `μ` and the discrete target with support `y` and weights
`ν`.

First fits the dual potential with `optim_v`, using step size `η` and `N` iterations. Then
averages the semi-dual integrand `h` over `N_MC` fresh samples from `μ`. The samples are
drawn as `rand(μ, N_MC)` and iterated column-wise, so each column is assumed to be one
sample point — presumably `μ` is a multivariate distribution; verify against callers.
"""
function wasserstein_semidiscrete(
    μ, y, ν, ϵ, c=(x, y) -> norm(x - y), η::Real=0.1, N::Int=100, N_MC::Int=2000
)
    potential = optim_v(μ, y, ν, η, N, ϵ, c)
    samples = rand(μ, N_MC)
    integrand = x -> h(x, potential, y, ν, ϵ, c)
    return mean(integrand, eachcol(samples))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment