Last active
May 13, 2018 00:30
-
-
Save fanjin-z/b0233efb09602559519e64c805b507f6 to your computer and use it in GitHub Desktop.
A Python implementation of the Hungarian algorithm, based on http://csclab.murraystate.edu/~bob.pilgrim/445/munkres.html
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class hungrary:
    """Minimum-cost linear assignment via the Munkres (Hungarian) algorithm.

    Follows the step structure described at
    http://csclab.murraystate.edu/~bob.pilgrim/445/munkres.html.

    Usage::

        H = hungrary(cost)
        H.run_hungrary()
        total = np.sum(H.w * H.m)

    Attributes:
        n: Side of the square working matrix, ``max(rows, cols)`` of the input.
        w: Copy of the original cost matrix, shape ``(rows, cols)``.
        c: Working cost matrix, shape ``(n, n)``. Rectangular inputs are
            padded with zero-cost dummy rows/columns.
        m: Mask matrix. While running, 1 marks a starred zero and 2 a primed
            zero. After ``run_hungrary`` it has the input's shape and holds 1
            at every chosen (row, col) pair and 0 elsewhere.
        RowCover, ColCover: Boolean cover vectors of length ``n``.
        path: Scratch buffer for the alternating path built in step 5.
    """

    def __init__(self, weight):
        """Prepare the solver for the cost matrix ``weight``.

        Args:
            weight: 2-D array-like of costs. Rectangular matrices are
                supported: every row (or column) of the smaller dimension is
                assigned to exactly one column (or row).

        Raises:
            ValueError: if ``weight`` is not 2-D or contains NaN. NaN defeats
                the comparisons the algorithm relies on and would keep it
                from terminating.
        """
        weight = np.asarray(weight)
        if weight.ndim != 2:
            raise ValueError(
                "weight must be a 2-D matrix, got shape {}".format(weight.shape))
        if np.issubdtype(weight.dtype, np.inexact) and np.isnan(weight).any():
            raise ValueError("weight must not contain NaN")
        self.n = max(weight.shape)
        self.w = np.copy(weight)
        self._reset()

    def _reset(self):
        """(Re)build the padded cost matrix, mask, covers and path buffer."""
        rows, cols = self.w.shape
        # Dummy rows/columns cost 0 everywhere, so every way of using them
        # costs the same and the optimal pairing of real rows/columns is kept.
        self.c = np.zeros((self.n, self.n), dtype=self.w.dtype)
        self.c[:rows, :cols] = self.w
        self.m = np.zeros((self.n, self.n), dtype=int)
        self.clear_covers()
        # Sized 2n: enough for any alternating prime/star path.
        self.path = np.zeros((2 * self.n, 2), dtype=int)
        self.path_count = 0

    def run_hungrary(self):
        """Run all steps to completion, leaving the assignment in ``self.m``.

        The working state is rebuilt from ``self.w`` first. Calling this
        again therefore reproduces the same result instead of adding stars on
        top of an earlier solution.
        """
        self._reset()
        if self.n == 0:
            return
        steps = {1: self.step1, 2: self.step2, 3: self.step3,
                 4: self.step4, 5: self.step5, 6: self.step6}
        step = 1
        while step != 7:
            step = steps[step]()
        rows, cols = self.w.shape
        if self.m.shape != (rows, cols):
            # Drop the dummy rows/columns introduced by padding.
            self.m = self.m[:rows, :cols]

    def step1(self):
        """Subtract each row's minimum so every row contains a zero."""
        self.c -= np.min(self.c, axis=1, keepdims=True)
        return 2

    def step2(self):
        """Greedily star zeros, at most one per row and per column (row-major)."""
        for u in range(self.n):
            # ColCover is used as scratch: columns that already hold a star.
            free = np.flatnonzero((self.c[u] == 0) & ~self.ColCover)
            if free.size:
                v = free[0]
                self.m[u, v] = 1
                self.ColCover[v] = True
        self.clear_covers()
        return 3

    def step3(self):
        """Cover every column holding a star; finish (7) once all n are covered."""
        self.ColCover |= (self.m == 1).any(axis=0)
        return 7 if np.sum(self.ColCover) >= self.n else 4

    def step4(self):
        """Prime uncovered zeros until one has no star in its row.

        Returns:
            5 when such a primed zero is found. Its position is recorded in
            ``path_row_0``/``path_col_0``.
            6 when no uncovered zero remains.
        """
        while True:
            row, col = self.find_a_zero()
            if row == -1:
                return 6
            self.m[row, col] = 2
            star_col = self.find_star_in_row(row)
            if star_col == -1:
                self.path_row_0 = row
                self.path_col_0 = col
                return 5
            # Cover this row and uncover the star's column so the search
            # can continue in the newly exposed column.
            self.RowCover[row] = True
            self.ColCover[star_col] = False

    def step5(self):
        """Build the alternating prime/star path from step 4's prime and flip it.

        Flipping turns each star on the path into a plain zero and each prime
        into a star, which adds one star overall. Covers and remaining primes
        are then cleared.
        """
        self.path_count = 1
        self.path[0] = (self.path_row_0, self.path_col_0)
        while True:
            # A star in the column of the most recent prime.
            last_col = self.path[self.path_count - 1, 1]
            row = self.find_star_in_col(last_col)
            if row == -1:
                break
            self.path[self.path_count] = (row, last_col)
            self.path_count += 1
            # That star's row was covered in step 4 right after being primed,
            # so a prime exists in it.
            col = self.find_prime_in_row(row)
            self.path[self.path_count] = (row, col)
            self.path_count += 1
        self.augment_path()
        self.clear_covers()
        self.erase_prime()
        return 3

    def step6(self):
        """Shift costs by the smallest uncovered value to expose a new zero.

        ``minval`` is added to covered rows and subtracted from uncovered
        columns. Net effect: uncovered entries drop by ``minval``, doubly
        covered entries rise by it, and singly covered entries are unchanged.
        """
        minval = self.find_smallest()
        self.c[self.RowCover, :] += minval
        self.c[:, ~self.ColCover] -= minval
        return 4

    def find_a_zero(self):
        """Return (row, col) of the first uncovered zero in row-major order, or (-1, -1)."""
        hits = np.argwhere(
            (self.c == 0) & ~self.RowCover[:, None] & ~self.ColCover[None, :])
        if hits.size == 0:
            return -1, -1
        return int(hits[0, 0]), int(hits[0, 1])

    @staticmethod
    def _first(mask):
        """Index of the first True in a 1-D boolean array, or -1 if none."""
        idx = np.flatnonzero(mask)
        return int(idx[0]) if idx.size else -1

    def star_in_row(self, row):
        """Return True if ``row`` contains a starred zero."""
        return bool((self.m[row] == 1).any())

    def find_star_in_row(self, row):
        """Column of the star in ``row``, or -1."""
        return self._first(self.m[row] == 1)

    def find_star_in_col(self, col):
        """Row of the star in ``col``, or -1."""
        return self._first(self.m[:, col] == 1)

    def find_prime_in_row(self, row):
        """Column of the prime in ``row``, or -1."""
        return self._first(self.m[row] == 2)

    def augment_path(self):
        """Flip the path cells: stars (1) become 0, all others (primes) become stars."""
        for u, v in self.path[:self.path_count]:
            self.m[u, v] = 0 if self.m[u, v] == 1 else 1

    def clear_covers(self):
        """Uncover every row and column."""
        self.RowCover = np.zeros((self.n), dtype=bool)
        self.ColCover = np.zeros((self.n), dtype=bool)

    def erase_prime(self):
        """Turn every primed zero (2) back into an unmarked entry (0)."""
        self.m[self.m == 2] = 0

    def find_smallest(self):
        """Smallest value of ``c`` lying in an uncovered row and an uncovered column.

        Falls back to the overall maximum of ``c`` if no entry is uncovered.
        """
        uncovered = self.c[np.ix_(~self.RowCover, ~self.ColCover)]
        if uncovered.size == 0:
            return np.max(self.c)
        return np.min(uncovered)
def _check_against_scipy(size=10, seed=0):
    """Compare ``hungrary`` with SciPy's solver on a random square matrix.

    Prints both total assignment costs and raises ``RuntimeError`` if they
    differ. SciPy is imported here so that importing this module does not
    require it.
    """
    from scipy.optimize import linear_sum_assignment

    distance_mat = np.random.RandomState(seed).rand(size, size)

    # Total cost of the assignment found by this implementation.
    H = hungrary(distance_mat)
    H.run_hungrary()
    ours = np.sum(H.w * H.m)

    # Total cost of the assignment found by SciPy.
    row_ind, col_ind = linear_sum_assignment(distance_mat)
    reference = distance_mat[row_ind, col_ind].sum()

    print("hungrary cost:", ours)
    print("scipy cost:   ", reference)
    if not np.isclose(ours, reference):
        raise RuntimeError(
            "cost mismatch: hungrary={} scipy={}".format(ours, reference))


if __name__ == "__main__":
    _check_against_scipy()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment