# Discrete OT
# GitHub gist devmotion/6b987174bce2bedf18b074358be80198, created May 24, 2021 19:09.
# The gist page notes that this file may contain hidden or bidirectional Unicode text
# that can be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
using Distributions | |
using SparseArrays | |
using LinearAlgebra | |
using StatsBase | |
"""
    _ot_cost_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; get=:plan)

Loop-based reference implementation of the north-west corner coupling of `μ` and `ν`,
walking both probability vectors in the order in which the distributions store them.
Once the last atom of `ν` is reached, every remaining mass of `μ` is assigned to it.

# Keywords
- `get=:plan`: with `:plan`, return the coupling `γ` as a `SparseMatrixCSC` of size
  `length(probs(μ)) × length(probs(ν))`; with `:cost`, return the transport cost
  `∑ᵢⱼ c(xᵢ, yⱼ) γᵢⱼ` of that coupling without materializing `γ`.

# Throws
- `ArgumentError` if `get` is neither `:plan` nor `:cost` (previously this silently
  returned `nothing`).
"""
function _ot_cost_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; get=:plan)
    get in (:plan, :cost) ||
        throw(ArgumentError("`get` must be `:plan` or `:cost`, got $(repr(get))"))
    p_μ, p_ν = probs(μ), probs(ν)
    x_μ, x_ν = support(μ), support(ν)
    len_μ = length(p_μ)
    len_ν = length(p_ν)
    # Remaining (not yet transported) mass of the current atoms μᵢ and νⱼ.
    wi = p_μ[1]
    wj = p_ν[1]
    if get == :plan
        γ = spzeros(Base.promote_eltype(p_μ, p_ν), len_μ, len_ν)
    else
        # The (1, 1) term is accounted for here; the loop skips it below.
        cost = c(x_μ[1], x_ν[1]) * min(wi, wj)
    end
    i, j = 1, 1
    while true
        if wi < wj || j == len_ν
            # μᵢ runs out first (or νⱼ is the last atom): move all of wi to (i, j).
            if get == :plan
                γ[i, j] = wi
            elseif i + j > 2  # skip the first case, already computed
                cost += c(x_μ[i], x_ν[j]) * wi
            end
            i += 1
            i == len_μ + 1 && break
            wj -= wi
            wi = p_μ[i]
        else
            # νⱼ runs out first: move all of wj to (i, j).
            if get == :plan
                γ[i, j] = wj
            elseif i + j > 2  # skip the first case, already computed
                cost += c(x_μ[i], x_ν[j]) * wj
            end
            j += 1
            j == len_ν + 1 && break
            wi -= wj
            wj = p_ν[j]
        end
    end
    return get == :plan ? γ : cost
end
"""
    DiscreteOTIterator(mu, nu)

Lazy iterator over the north-west corner coupling of two probability vectors `mu` and
`nu`.

Each element is a tuple `(i, j, w)`: mass `w` is moved from the `i`-th entry of `mu` to
the `j`-th entry of `nu`. Masses are converted to `T = Base.promote_eltype(mu, nu)`, so
every element has the declared `eltype` `Tuple{Int,Int,T}`. Starting at `(1, 1)`, the
smaller of the two remaining masses is emitted and the exhausted index is advanced. When
the remaining masses are equal, both indices advance, so no zero-mass pair is produced.
Iteration stops as soon as either vector is exhausted; any mass left in the other vector
is not emitted.
"""
struct DiscreteOTIterator{T,M,N}
    mu::M
    nu::N
end

function DiscreteOTIterator(mu, nu)
    T = Base.promote_eltype(mu, nu)
    return DiscreteOTIterator{T,typeof(mu),typeof(nu)}(mu, nu)
end

Base.IteratorEltype(::Type{<:DiscreteOTIterator}) = Base.HasEltype()
Base.IteratorSize(::Type{<:DiscreteOTIterator}) = Base.SizeUnknown()
Base.eltype(::Type{<:DiscreteOTIterator{T}}) where {T} = Tuple{Int,Int,T}

# State: `(i, j, mu_next, nu_next)`. `mu_next`/`nu_next` are either `nothing` (vector
# exhausted) or `(remaining_mass, inner_state)` for the current entry.
function Base.iterate(
    d::DiscreteOTIterator{T},
    (i, j, mu_next, nu_next)=(1, 1, iterate(d.mu), iterate(d.nu)),
) where {T}
    (mu_next === nothing || nu_next === nothing) && return nothing
    mu_val, mu_state = mu_next
    nu_val, nu_state = nu_next
    # Convert so that emitted masses and carried residuals always have type `T`.
    mu_mass = convert(T, mu_val)
    nu_mass = convert(T, nu_val)
    if mu_mass < nu_mass
        # `mu[i]` is fully transported; carry the rest of `nu[j]` forward.
        state = (i + 1, j, iterate(d.mu, mu_state), (nu_mass - mu_mass, nu_state))
        return (i, j, mu_mass), state
    elseif nu_mass < mu_mass
        # `nu[j]` is fully served; carry the rest of `mu[i]` forward.
        state = (i, j + 1, (mu_mass - nu_mass, mu_state), iterate(d.nu, nu_state))
        return (i, j, nu_mass), state
    else
        # Equal masses: both entries are exhausted, so advance both indices instead of
        # emitting a zero-mass pair on the next step.
        state = (i + 1, j + 1, iterate(d.mu, mu_state), iterate(d.nu, nu_state))
        return (i, j, mu_mass), state
    end
end
# Random test problem: two discrete measures on the real line with 200 and 250 atoms.
# The RNG is drawn from in the order: mu support, nu support, mu weights, nu weights.
mu_support = randn(200)
nu_support = randn(250)
mu_probs, nu_probs = rand(200), rand(250)
# Rescale each weight vector in place so that it sums to one.
for p in (mu_probs, nu_probs)
    p ./= sum(p)
end

# Ground cost: absolute distance between support points.
function c(x, y)
    return abs(x - y)
end
"""
    ot_plan(c, mu::DiscreteNonParametric, nu::DiscreteNonParametric)

Return the north-west corner coupling of `mu` and `nu` as a sparse
`length(probs(mu)) × length(probs(nu))` matrix, built from the pairs produced by
`DiscreteOTIterator(probs(mu), probs(nu))`. The cost function (first argument) is
ignored.
"""
function ot_plan(_, mu::DiscreteNonParametric, nu::DiscreteNonParametric)
    probs_mu = probs(mu)
    probs_nu = probs(nu)
    len_mu = length(probs_mu)
    len_nu = length(probs_nu)
    iter = DiscreteOTIterator(probs_mu, probs_nu)
    I = Int[]
    J = Int[]
    W = Vector{Base.promote_eltype(probs_mu, probs_nu)}(undef, 0)
    # Every emitted pair advances `i`, `j`, or both, so there are at most
    # `len_mu + len_nu - 1` pairs; `max(len_mu, len_nu)` under-reserves.
    m = max(len_mu + len_nu - 1, 0)
    sizehint!(I, m)
    sizehint!(J, m)
    sizehint!(W, m)
    for (i, j, w) in iter
        push!(I, i)
        push!(J, j)
        push!(W, w)
    end
    return sparse(I, J, W, len_mu, len_nu)
end
# Compare the iterator-based plan with the loop-based reference implementation.
plan, plan2 = let μ = DiscreteNonParametric(mu_support, mu_probs),
    ν = DiscreteNonParametric(nu_support, nu_probs)

    ot_plan(c, μ, ν), _ot_cost_plan(c, μ, ν)
end
plan ≈ plan2
"""
    ot_cost(c, mu::DiscreteNonParametric, nu::DiscreteNonParametric; plan=nothing)

Return the transport cost of moving `mu` onto `nu` under the ground cost `c`.
The computation is forwarded to `_ot_cost`, dispatching on the type of `plan`;
the default `plan=nothing` selects the method that does not take a precomputed plan.
"""
ot_cost(c, mu::DiscreteNonParametric, nu::DiscreteNonParametric; plan=nothing) =
    _ot_cost(c, mu, nu, plan)
# No plan given: accumulate `c(xᵢ, yⱼ) * w` over the pairs produced lazily by
# `DiscreteOTIterator`, without materializing a plan matrix.
function _ot_cost(c, mu, nu, ::Nothing)
    xs = support(mu)
    ys = support(nu)
    coupling = DiscreteOTIterator(probs(mu), probs(nu))
    return sum(coupling) do (i, j, w)
        c(xs[i], ys[j]) * w
    end
end
# Sparse plan: sum `c(xᵢ, yⱼ) * plan[i, j]` over the stored entries of `plan` only.
function _ot_cost(c, mu, nu, plan::SparseMatrixCSC)
    support_mu = support(mu)
    support_nu = support(nu)
    rows = rowvals(plan)
    vals = nonzeros(plan)
    # Walk the CSC storage column by column (the same order `findnz` returns), which
    # avoids the three temporary index/value arrays that `findnz` allocates.
    return sum(
        c(support_mu[rows[k]], support_nu[j]) * vals[k] for
        j in axes(plan, 2) for k in nzrange(plan, j)
    )
end
# Fallback for any other plan type (e.g. a dense `Matrix`): materialize the full cost
# matrix with `pairwise(c, support(mu), support(nu))` and return its `dot` product with
# the plan.
function _ot_cost(c, mu, nu, plan)
    cost_matrix = pairwise(c, support(mu), support(nu))
    return dot(plan, cost_matrix)
end
# Cross-check the transport cost: the lazy iterator vs. the loop-based reference, and
# the iterator cost vs. the cost evaluated from the sparse and the dense plan.
# The two implementations can disagree in the very last transported mass by a rounding
# error: the reference assigns all remaining mass of μ to the last atom of ν, while the
# iterator stops at the smaller of the two residuals. Compare with `≈`, not `==`.
cost = ot_cost(
    c,
    DiscreteNonParametric(mu_support, mu_probs),
    DiscreteNonParametric(nu_support, nu_probs),
)
cost2 = _ot_cost_plan(
    c,
    DiscreteNonParametric(mu_support, mu_probs),
    DiscreteNonParametric(nu_support, nu_probs);
    get=:cost,
)
cost ≈ cost2
cost ≈ ot_cost(
    c,
    DiscreteNonParametric(mu_support, mu_probs),
    DiscreteNonParametric(nu_support, nu_probs);
    plan=plan,
)
cost ≈ ot_cost(
    c,
    DiscreteNonParametric(mu_support, mu_probs),
    DiscreteNonParametric(nu_support, nu_probs);
    plan=Matrix(plan),
)
# Inspect type inference of the entry points (diagnostic output only).
let μ = DiscreteNonParametric(mu_support, mu_probs),
    ν = DiscreteNonParametric(nu_support, nu_probs)

    @code_warntype ot_plan(c, μ, ν)
    @code_warntype ot_cost(c, μ, ν)
    @code_warntype ot_cost(c, μ, ν; plan=plan)
    @code_warntype ot_cost(c, μ, ν; plan=Matrix(plan))
end
# Sign up for free to join this conversation on GitHub. Already have an account?
# Sign in to comment.