Created
December 15, 2020 21:23
-
-
Save Darkyenus/d8349307fcdb35fed769ef4081e6d313 to your computer and use it in GitHub Desktop.
Hungarian algorithm implementation for weighted bipartite matching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
MIT License

Copyright (c) 2017 James Payor
Copyright (c) 2020 Jan Polák (Java rewrite)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
import com.badlogic.gdx.utils.Array;
import com.badlogic.gdx.utils.Queue;

import java.util.Arrays;
/**
 * Hungarian algorithm implementation for minimum-weight perfect matching in a weighted bipartite graph.
 * <p>
 * Java rewrite of https://github.com/jamespayor/weighted-bipartite-perfect-matching/tree/259ce65f16b67390e0ae4e0d3519543085b9f109
 * <p>
 * Not particularly GC optimized; may produce a lot of garbage.
 */
public final class Hungarian {
	// Sentinel for "infinity"; only ever compared against, never used in arithmetic that could overflow here.
	private static final int oo = Integer.MAX_VALUE;
	// Sentinel marking a left or right node that currently has no matched partner (also used for "no index").
	private static final int UNMATCHED = -1;
/**
 * One edge of the input bipartite graph: connects node {@code left} (left side)
 * to node {@code right} (right side) with weight {@code cost}.
 * Immutable value holder; instances are only read by the matching algorithm.
 */
public static final class WeightedBipartiteEdge {
	final int left, right, cost;

	/**
	 * @param left  index of the endpoint on the left side
	 * @param right index of the endpoint on the right side
	 * @param cost  weight of this edge
	 */
	public WeightedBipartiteEdge(int left, int right, int cost) {
		this.left = left;
		this.right = right;
		this.cost = cost;
	}
}
private static final class LeftEdge implements Comparable<LeftEdge> { | |
final int right, cost; | |
public LeftEdge(int right, int cost) { | |
this.right = right; | |
this.cost = cost; | |
} | |
@Override | |
public int compareTo(Hungarian.LeftEdge otherEdge) { | |
if (right < otherEdge.right || (right == otherEdge.right && cost < otherEdge.cost)) { | |
return -1; | |
} else if (right == otherEdge.right && cost == otherEdge.cost) { | |
return 0; | |
} else { | |
return 1; | |
} | |
} | |
} | |
/** | |
* Given the number of nodes on each side of the bipartite graph and a list of edges, returns a minimum-weight perfect matching. | |
* If a matching is found, returns a length-n vector, giving the nodes on the right that the left nodes are matched to. | |
* | |
* (Note: Edges with endpoints out of the range [0, n) are ignored.) | |
* | |
* @param n total amount of nodes on each side of the bipartite graph | |
* @param allEdges all defined edges | |
* @return null if no matching exists or array where [a] == b where a is index to left node and b index to right node | |
*/ | |
public static int[] hungarianMinimumWeightPerfectMatching(final int n, WeightedBipartiteEdge[] allEdges) { | |
// Edge lists for each left node. | |
Array<LeftEdge>[] leftEdges = new Array[n]; | |
for (int i = 0; i < n; i++) { | |
leftEdges[i] = new Array<>(); | |
} | |
//region Edge list initialization | |
// Initialize edge lists for each left node, based on the incoming set of edges. | |
// While we're at it, we check that every node has at least one associated edge. | |
// (Note: We filter out the edges that invalidly refer to a node on the left or right outside [0, n).) | |
{ | |
int[] leftEdgeCounts = new int[n]; | |
int[] rightEdgeCounts = new int[n]; | |
for (final WeightedBipartiteEdge edge : allEdges) { | |
if (edge.left >= 0 && edge.left < n) { | |
++leftEdgeCounts[edge.left]; | |
} | |
if (edge.right >= 0 && edge.right < n) { | |
++rightEdgeCounts[edge.right]; | |
} | |
} | |
for (int i = 0; i < n; i++) { | |
if (leftEdgeCounts[i] == 0 || rightEdgeCounts[i] == 0) { | |
// No matching will be possible. | |
return null; | |
} | |
} | |
// Probably unnecessary, but reserve the required space for each node, just because? | |
for (int i = 0; i < n; i++) { | |
leftEdges[i].ensureCapacity(leftEdgeCounts[i]); | |
} | |
} | |
// Actually add to the edge lists now. | |
for (final WeightedBipartiteEdge edge : allEdges) { | |
if (edge.left >= 0 && edge.left < n && edge.right >= 0 && edge.right < n) { | |
leftEdges[edge.left].add(new LeftEdge(edge.right, edge.cost)); | |
} | |
} | |
// Sort the edge lists, and remove duplicate edges (keep the edge with smallest cost). | |
for (int i = 0; i < n; i++) { | |
final Array<LeftEdge> edges = leftEdges[i]; | |
edges.sort(); | |
int edgeCount = 0; | |
int lastRight = UNMATCHED; | |
for (int edgeIndex = 0; edgeIndex < edges.size; edgeIndex++) { | |
final LeftEdge edge = edges.get(edgeIndex); | |
if (edge.right == lastRight) { | |
continue; | |
} | |
lastRight = edge.right; | |
if (edgeIndex != edgeCount) { | |
edges.set(edgeCount, edge); | |
} | |
++edgeCount; | |
} | |
edges.size = edgeCount; | |
} | |
//endregion Edge list initialization | |
// These hold "potentials" for nodes on the left and nodes on the right, which reduce the costs of attached edges. | |
// We maintain that every reduced cost, cost[i][j] - leftPotential[i] - leftPotential[j], is greater than zero. | |
final int[] leftPotential = new int[n]; | |
final int[] rightPotential = new int[n]; | |
//region Node potential initialization | |
// Here, we seek good initial values for the node potentials. | |
// Note: We're guaranteed by the above code that at every node on the left and right has at least one edge. | |
// First, we raise the potentials on the left as high as we can for each node. | |
// This guarantees each node on the left has at least one "tight" edge. | |
for (int i = 0; i < n; i++) { | |
final Array<LeftEdge> edges = leftEdges[i]; | |
int smallestEdgeCost = edges.get(0).cost; | |
for (int edgeIndex = 1; edgeIndex < edges.size; edgeIndex++) { | |
if (edges.get(edgeIndex).cost < smallestEdgeCost) { | |
smallestEdgeCost = edges.get(edgeIndex).cost; | |
} | |
} | |
// Set node potential to the smallest incident edge cost. | |
// This is as high as we can take it without creating an edge with zero reduced cost. | |
leftPotential[i] = smallestEdgeCost; | |
} | |
// Second, we raise the potentials on the right as high as we can for each node. | |
// We do the same as with the left, but this time take into account that costs are reduced | |
// by the left potentials. | |
// This guarantees that each node on the right has at least one "tight" edge. | |
Arrays.fill(rightPotential, oo); | |
for (final WeightedBipartiteEdge edge : allEdges) { | |
int reducedCost = edge.cost - leftPotential[edge.left]; | |
if (rightPotential[edge.right] > reducedCost) { | |
rightPotential[edge.right] = reducedCost; | |
} | |
} | |
//endregion Node potential initialization | |
// Tracks how many edges for each left node are "tight". | |
// Following initialization, we maintain the invariant that these are at the start of the node's edge list. | |
int[] leftTightEdgesCount = new int[n]; | |
//region Tight edge initialization | |
// Here we find all tight edges, defined as edges that have zero reduced cost. | |
// We will be interested in the subgraph induced by the tight edges, so we partition the edge lists for | |
// each left node accordingly, moving the tight edges to the start. | |
for (int i = 0; i < n; i++) { | |
final Array<LeftEdge> edges = leftEdges[i]; | |
int tightEdgeCount = 0; | |
for (int edgeIndex = 0; edgeIndex < edges.size; edgeIndex++) { | |
final LeftEdge edge = edges.get(edgeIndex); | |
int reducedCost = edge.cost - leftPotential[i] - rightPotential[edge.right]; | |
if (reducedCost == 0) { | |
if (edgeIndex != tightEdgeCount) { | |
edges.swap(tightEdgeCount, edgeIndex); | |
} | |
++tightEdgeCount; | |
} | |
} | |
leftTightEdgesCount[i] = tightEdgeCount; | |
} | |
//endregion Tight edge initialization | |
// Now we're ready to begin the inner loop. | |
// We maintain an (initially empty) partial matching, in the subgraph of tight edges. | |
int currentMatchingCardinality = 0; | |
int[] leftMatchedTo = new int[n]; | |
int[] rightMatchedTo = new int[n]; | |
Arrays.fill(leftMatchedTo, UNMATCHED); | |
Arrays.fill(rightMatchedTo, UNMATCHED); | |
//region Initial matching (speedup?) | |
// Because we can, let's make all the trivial matches we can. | |
for (int i = 0; i < n; i++) { | |
final Array<LeftEdge> edges = leftEdges[i]; | |
for (int edgeIndex = 0; edgeIndex < leftTightEdgesCount[i]; edgeIndex++) { | |
int j = edges.get(edgeIndex).right; | |
if (rightMatchedTo[j] == UNMATCHED) { | |
++currentMatchingCardinality; | |
rightMatchedTo[j] = i; | |
leftMatchedTo[i] = j; | |
break; | |
} | |
} | |
} | |
if (currentMatchingCardinality == n) { | |
// Well, that's embarassing. We're already done! | |
return leftMatchedTo; | |
} | |
//endregion Initial matching (speedup?) | |
// While an augmenting path exists, we add it to the matching. | |
// When an augmenting path doesn't exist, we update the potentials so that an edge between the area | |
// we can reach and the unreachable nodes on the right becomes tight, giving us another edge to explore. | |
// | |
// We proceed in this fashion until we can't find more augmenting paths or add edges. | |
// At that point, we either have a min-weight perfect matching, or no matching exists. | |
//region Inner loop state variables | |
// One point of confusion is that we're going to cache the edges between the area we've explored | |
// that are "almost tight", or rather are the closest to being tight. | |
// This is necessary to achieve our O(N^3) runtime. | |
// | |
// rightMinimumSlack[j] gives the smallest amount of "slack" for an unreached node j on the right, | |
// considering the edges between j and some node on the left in our explored area. | |
// | |
// rightMinimumSlackLeftNode[j] gives the node i with the corresponding edge. | |
// rightMinimumSlackEdgeIndex[j] gives the edge index for node i. | |
int[] rightMinimumSlack = new int[n]; | |
int[] rightMinimumSlackLeftNode = new int[n]; | |
int[] rightMinimumSlackEdgeIndex = new int[n]; | |
Queue<Integer> leftNodeQueue = new Queue<>(); | |
boolean[] leftSeen = new boolean[n]; | |
int[] rightBacktrack = new int[n]; | |
// Note: the above are all initialized at the start of the loop. | |
//endregion Inner loop state variables | |
while (currentMatchingCardinality < n) { | |
//region Loop state initialization | |
// Clear out slack caches. | |
// Note: We need to clear the nodes so that we can notice when there aren't any edges available. | |
Arrays.fill(rightMinimumSlack, oo); | |
Arrays.fill(rightMinimumSlackLeftNode, UNMATCHED); | |
// Clear the queue. | |
leftNodeQueue.clear(); | |
// Mark everything "unseen". | |
Arrays.fill(leftSeen, false); | |
Arrays.fill(rightBacktrack, UNMATCHED); | |
//endregion Loop state initialization | |
int startingLeftNode = UNMATCHED; | |
//region Find unmatched starting node | |
// Find an unmatched left node to search outward from. | |
// By heuristic, we pick the node with fewest tight edges, giving the BFS an easier time. | |
// (The asymptotics don't care about this, but maybe it helps. Eh.) | |
{ | |
int minimumTightEdges = oo; | |
for (int i = 0; i < n; i++) { | |
if (leftMatchedTo[i] == UNMATCHED && leftTightEdgesCount[i] < minimumTightEdges) { | |
minimumTightEdges = leftTightEdgesCount[i]; | |
startingLeftNode = i; | |
} | |
} | |
} | |
//endregion Find unmatched starting node | |
assert (startingLeftNode != UNMATCHED); | |
assert leftNodeQueue.isEmpty(); | |
leftNodeQueue.addLast(startingLeftNode); | |
leftSeen[startingLeftNode] = true; | |
int endingRightNode = UNMATCHED; | |
while (endingRightNode == UNMATCHED) { | |
//region BFS until match found or no edges to follow | |
while (endingRightNode == UNMATCHED && leftNodeQueue.notEmpty()) { | |
// Implementation note: this could just as easily be a DFS, but a BFS probably | |
// has less edge flipping (by my guess), so we're using a BFS. | |
final int i = leftNodeQueue.removeFirst(); | |
final Array<LeftEdge> edges = leftEdges[i]; | |
// Note: Some of the edges might not be tight anymore, hence the awful loop. | |
for (int edgeIndex = 0; edgeIndex < leftTightEdgesCount[i]; ++edgeIndex) { | |
final LeftEdge edge = edges.get(edgeIndex); | |
final int j = edge.right; | |
assert (edge.cost - leftPotential[i] - rightPotential[j] >= 0); | |
if (edge.cost > leftPotential[i] + rightPotential[j]) { | |
// This edge is loose now. | |
--leftTightEdgesCount[i]; | |
edges.swap(edgeIndex, leftTightEdgesCount[i]); | |
--edgeIndex; | |
continue; | |
} | |
if (rightBacktrack[j] != UNMATCHED) { | |
continue; | |
} | |
rightBacktrack[j] = i; | |
int matchedTo = rightMatchedTo[j]; | |
if (matchedTo == UNMATCHED) { | |
// Match found. This will terminate the loop. | |
endingRightNode = j; | |
} else if (!leftSeen[matchedTo]) { | |
// No match found, but a new left node is reachable. Track how we got here and extend BFS queue. | |
leftSeen[matchedTo] = true; | |
leftNodeQueue.addLast(matchedTo); | |
} | |
} | |
//region Update cached slack values | |
// The remaining edges may be to nodes that are unreachable. | |
// We accordingly update the minimum slackness for nodes on the right. | |
if (endingRightNode == UNMATCHED) { | |
final int potential = leftPotential[i]; | |
for (int edgeIndex = leftTightEdgesCount[i]; edgeIndex < edges.size; edgeIndex++) { | |
final LeftEdge edge = edges.get(edgeIndex); | |
int j = edge.right; | |
if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) { | |
// This edge is to a node on the right that we haven't reached yet. | |
int reducedCost = edge.cost - potential - rightPotential[j]; | |
assert (reducedCost >= 0); | |
if (reducedCost < rightMinimumSlack[j]) { | |
// There should be a better way to do this backtracking... | |
// One array instead of 3. But I can't think of something else. So it goes. | |
rightMinimumSlack[j] = reducedCost; | |
rightMinimumSlackLeftNode[j] = i; | |
rightMinimumSlackEdgeIndex[j] = edgeIndex; | |
} | |
} | |
} | |
} | |
//endregion Update cached slack values | |
} | |
//endregion BFS until match found or no edges to follow | |
//region Update node potentials to add edges, if no match found | |
if (endingRightNode == UNMATCHED) { | |
// Out of nodes. Time to update some potentials. | |
int minimumSlackRightNode = UNMATCHED; | |
//region Find minimum slack node, or abort if none exists | |
int minimumSlack = oo; | |
for (int j = 0; j < n; j++) { | |
if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) { | |
// This isn't a node reached by our BFS. Update minimum slack. | |
if (rightMinimumSlack[j] < minimumSlack) { | |
minimumSlack = rightMinimumSlack[j]; | |
minimumSlackRightNode = j; | |
} | |
} | |
} | |
if (minimumSlackRightNode == UNMATCHED || rightMinimumSlackLeftNode[minimumSlackRightNode] == UNMATCHED) { | |
// The caches are all empty. There was no option available. | |
// This means that the node the BFS started at, which is an unmatched left node, cannot reach the | |
// right - i.e. it will be impossible to find a perfect matching. | |
return null; | |
} | |
//endregion Find minimum slack node, or abort if none exists | |
assert minimumSlackRightNode != UNMATCHED; | |
// Adjust potentials on left and right. | |
for (int i = 0; i < n; i++) { | |
if (leftSeen[i]) { | |
leftPotential[i] += minimumSlack; | |
if (leftMatchedTo[i] != UNMATCHED) { | |
rightPotential[leftMatchedTo[i]] -= minimumSlack; | |
} | |
} | |
} | |
// Downward-adjust slackness caches. | |
for (int j = 0; j < n; j++) { | |
if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) { | |
rightMinimumSlack[j] -= minimumSlack; | |
// If the slack hit zero, then we just found ourselves a new tight edge. | |
if (rightMinimumSlack[j] == 0) { | |
final int i = rightMinimumSlackLeftNode[j]; | |
final int edgeIndex = rightMinimumSlackEdgeIndex[j]; | |
//region Update leftEdges[i] and leftTightEdgesCount[i] | |
// Move it in the relevant edge list. | |
if (edgeIndex != leftTightEdgesCount[i]) { | |
final Array<LeftEdge> edges = leftEdges[i]; | |
edges.swap(edgeIndex, leftTightEdgesCount[i]); | |
} | |
++leftTightEdgesCount[i]; | |
//endregion Update leftEdges[i] and leftTightEdgesCount[i] | |
// If we haven't already encountered a match, we follow the edge and update the BFS queue. | |
// It's possible this edge leads to a match. If so, we'll carry on updating the tight edges, | |
// but won't follow them. | |
if (endingRightNode == UNMATCHED) { | |
// We're contemplating the consequences of following (i, j), as we do in the BFS above. | |
rightBacktrack[j] = i; | |
int matchedTo = rightMatchedTo[j]; | |
if (matchedTo == UNMATCHED) { | |
// Match found! | |
endingRightNode = j; | |
} else if (!leftSeen[matchedTo]) { | |
// No match, but new left node found. Extend BFS queue. | |
leftSeen[matchedTo] = true; | |
leftNodeQueue.addLast(matchedTo); | |
} | |
} | |
} | |
} | |
} | |
} | |
//endregion Update node potentials to add edges, if no match found | |
} | |
// At this point, we've found an augmenting path between startingLeftNode and endingRightNode. | |
// We'll just use the backtracking info to update our match information. | |
++currentMatchingCardinality; | |
// Backtrack and flip augmenting path | |
int currentRightNode = endingRightNode; | |
while (currentRightNode != UNMATCHED) { | |
final int currentLeftNode = rightBacktrack[currentRightNode]; | |
final int nextRightNode = leftMatchedTo[currentLeftNode]; | |
rightMatchedTo[currentRightNode] = currentLeftNode; | |
leftMatchedTo[currentLeftNode] = currentRightNode; | |
currentRightNode = nextRightNode; | |
} | |
} | |
return leftMatchedTo; | |
} | |
/**
 * Specialized version of {@link #hungarianMinimumWeightPerfectMatching(int, WeightedBipartiteEdge[])} for complete bipartite graphs.
 *
 * @param allEdges all defined edges in n*n format, where [a][b] = c (a = left index, b = right index, cost)
 * @return null if no matching exists or array where [a] == b where a is index to left node and b index to right node
 */
public static int[] hungarianMinimumWeightPerfectMatching(final int[][] allEdges) {
	final int n = allEdges.length;
	// NOTE(review): assumes allEdges is a square matrix (every row has length n) - confirm at call
	// sites; a shorter row would cause ArrayIndexOutOfBoundsException in the loops below.
	// Edge lists for each left node. Since the graph is complete, each left node has exactly n edges.
	final LeftEdge[][] leftEdges = new LeftEdge[n][n];
	for (int left = 0; left < n; left++) {
		final int[] leftCosts = allEdges[left];
		final LeftEdge[] lefts = leftEdges[left];
		for (int right = 0; right < n; right++) {
			lefts[right] = new LeftEdge(right, leftCosts[right]);
		}
	}
	// These hold "potentials" for nodes on the left and nodes on the right, which reduce the costs of attached edges.
	// We maintain that every reduced cost, cost[i][j] - leftPotential[i] - leftPotential[j], is greater than zero.
	final int[] leftPotential = new int[n];
	final int[] rightPotential = new int[n];
	//region Node potential initialization
	// Here, we seek good initial values for the node potentials.
	// Note: We're guaranteed by the above code that every node on the left and right has at least one edge.
	// First, we raise the potentials on the left as high as we can for each node.
	// This guarantees each node on the left has at least one "tight" edge.
	for (int i = 0; i < n; i++) {
		// Set node potential to the smallest incident edge cost.
		// This is as high as we can take it without creating an edge with zero reduced cost.
		leftPotential[i] = min(allEdges[i]);
	}
	// Second, we raise the potentials on the right as high as we can for each node.
	// We do the same as with the left, but this time take into account that costs are reduced
	// by the left potentials.
	// This guarantees that each node on the right has at least one "tight" edge.
	Arrays.fill(rightPotential, oo);
	for (int left = 0; left < n; left++) {
		final int leftPotentialCost = leftPotential[left];
		for (int right = 0; right < n; right++) {
			final int reducedCost = allEdges[left][right] - leftPotentialCost;
			rightPotential[right] = Math.min(rightPotential[right], reducedCost);
		}
	}
	//endregion Node potential initialization
	// Tracks how many edges for each left node are "tight".
	// Following initialization, we maintain the invariant that these are at the start of the node's edge list.
	int[] leftTightEdgesCount = new int[n];
	//region Tight edge initialization
	// Here we find all tight edges, defined as edges that have zero reduced cost.
	// We will be interested in the subgraph induced by the tight edges, so we partition the edge lists for
	// each left node accordingly, moving the tight edges to the start.
	for (int i = 0; i < n; i++) {
		final LeftEdge[] edges = leftEdges[i];
		int tightEdgeCount = 0;
		for (int edgeIndex = 0; edgeIndex < edges.length; edgeIndex++) {
			final LeftEdge edge = edges[edgeIndex];
			int reducedCost = edge.cost - leftPotential[i] - rightPotential[edge.right];
			if (reducedCost == 0) {
				if (edgeIndex != tightEdgeCount) {
					swap(edges, tightEdgeCount, edgeIndex);
				}
				++tightEdgeCount;
			}
		}
		leftTightEdgesCount[i] = tightEdgeCount;
	}
	//endregion Tight edge initialization
	// Now we're ready to begin the inner loop.
	// We maintain an (initially empty) partial matching, in the subgraph of tight edges.
	int currentMatchingCardinality = 0;
	int[] leftMatchedTo = new int[n];
	int[] rightMatchedTo = new int[n];
	Arrays.fill(leftMatchedTo, UNMATCHED);
	Arrays.fill(rightMatchedTo, UNMATCHED);
	//region Initial matching (speedup?)
	// Because we can, let's make all the trivial matches we can.
	for (int i = 0; i < n; i++) {
		final LeftEdge[] edges = leftEdges[i];
		for (int edgeIndex = 0; edgeIndex < leftTightEdgesCount[i]; edgeIndex++) {
			int j = edges[edgeIndex].right;
			if (rightMatchedTo[j] == UNMATCHED) {
				++currentMatchingCardinality;
				rightMatchedTo[j] = i;
				leftMatchedTo[i] = j;
				break;
			}
		}
	}
	if (currentMatchingCardinality == n) {
		// Well, that's embarrassing. We're already done!
		return leftMatchedTo;
	}
	//endregion Initial matching (speedup?)
	// While an augmenting path exists, we add it to the matching.
	// When an augmenting path doesn't exist, we update the potentials so that an edge between the area
	// we can reach and the unreachable nodes on the right becomes tight, giving us another edge to explore.
	//
	// We proceed in this fashion until we can't find more augmenting paths or add edges.
	// At that point, we either have a min-weight perfect matching, or no matching exists.
	//region Inner loop state variables
	// One point of confusion is that we're going to cache the edges between the area we've explored
	// that are "almost tight", or rather are the closest to being tight.
	// This is necessary to achieve our O(N^3) runtime.
	//
	// rightMinimumSlack[j] gives the smallest amount of "slack" for an unreached node j on the right,
	// considering the edges between j and some node on the left in our explored area.
	//
	// rightMinimumSlackLeftNode[j] gives the node i with the corresponding edge.
	// rightMinimumSlackEdgeIndex[j] gives the edge index for node i.
	int[] rightMinimumSlack = new int[n];
	int[] rightMinimumSlackLeftNode = new int[n];
	int[] rightMinimumSlackEdgeIndex = new int[n];
	Queue<Integer> leftNodeQueue = new Queue<>();
	boolean[] leftSeen = new boolean[n];
	int[] rightBacktrack = new int[n];
	// Note: the above are all initialized at the start of the loop.
	//endregion Inner loop state variables
	while (currentMatchingCardinality < n) {
		//region Loop state initialization
		// Clear out slack caches.
		// Note: We need to clear the nodes so that we can notice when there aren't any edges available.
		Arrays.fill(rightMinimumSlack, oo);
		Arrays.fill(rightMinimumSlackLeftNode, UNMATCHED);
		// Clear the queue.
		leftNodeQueue.clear();
		// Mark everything "unseen".
		Arrays.fill(leftSeen, false);
		Arrays.fill(rightBacktrack, UNMATCHED);
		//endregion Loop state initialization
		int startingLeftNode = UNMATCHED;
		//region Find unmatched starting node
		// Find an unmatched left node to search outward from.
		// By heuristic, we pick the node with fewest tight edges, giving the BFS an easier time.
		// (The asymptotics don't care about this, but maybe it helps. Eh.)
		{
			int minimumTightEdges = oo;
			for (int i = 0; i < n; i++) {
				if (leftMatchedTo[i] == UNMATCHED && leftTightEdgesCount[i] < minimumTightEdges) {
					minimumTightEdges = leftTightEdgesCount[i];
					startingLeftNode = i;
				}
			}
		}
		//endregion Find unmatched starting node
		assert (startingLeftNode != UNMATCHED);
		assert leftNodeQueue.isEmpty();
		leftNodeQueue.addLast(startingLeftNode);
		leftSeen[startingLeftNode] = true;
		int endingRightNode = UNMATCHED;
		while (endingRightNode == UNMATCHED) {
			//region BFS until match found or no edges to follow
			while (endingRightNode == UNMATCHED && leftNodeQueue.notEmpty()) {
				// Implementation note: this could just as easily be a DFS, but a BFS probably
				// has less edge flipping (by my guess), so we're using a BFS.
				final int i = leftNodeQueue.removeFirst();
				final LeftEdge[] edges = leftEdges[i];
				// Note: Some of the edges might not be tight anymore, hence the awful loop.
				for (int edgeIndex = 0; edgeIndex < leftTightEdgesCount[i]; ++edgeIndex) {
					final LeftEdge edge = edges[edgeIndex];
					final int j = edge.right;
					assert (edge.cost - leftPotential[i] - rightPotential[j] >= 0);
					if (edge.cost > leftPotential[i] + rightPotential[j]) {
						// This edge is loose now.
						--leftTightEdgesCount[i];
						swap(edges, edgeIndex, leftTightEdgesCount[i]);
						--edgeIndex;
						continue;
					}
					if (rightBacktrack[j] != UNMATCHED) {
						continue;
					}
					rightBacktrack[j] = i;
					int matchedTo = rightMatchedTo[j];
					if (matchedTo == UNMATCHED) {
						// Match found. This will terminate the loop.
						endingRightNode = j;
					} else if (!leftSeen[matchedTo]) {
						// No match found, but a new left node is reachable. Track how we got here and extend BFS queue.
						leftSeen[matchedTo] = true;
						leftNodeQueue.addLast(matchedTo);
					}
				}
				//region Update cached slack values
				// The remaining edges may be to nodes that are unreachable.
				// We accordingly update the minimum slackness for nodes on the right.
				if (endingRightNode == UNMATCHED) {
					final int potential = leftPotential[i];
					for (int edgeIndex = leftTightEdgesCount[i]; edgeIndex < edges.length; edgeIndex++) {
						final LeftEdge edge = edges[edgeIndex];
						int j = edge.right;
						if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) {
							// This edge is to a node on the right that we haven't reached yet.
							int reducedCost = edge.cost - potential - rightPotential[j];
							assert (reducedCost >= 0);
							if (reducedCost < rightMinimumSlack[j]) {
								// There should be a better way to do this backtracking...
								// One array instead of 3. But I can't think of something else. So it goes.
								rightMinimumSlack[j] = reducedCost;
								rightMinimumSlackLeftNode[j] = i;
								rightMinimumSlackEdgeIndex[j] = edgeIndex;
							}
						}
					}
				}
				//endregion Update cached slack values
			}
			//endregion BFS until match found or no edges to follow
			//region Update node potentials to add edges, if no match found
			if (endingRightNode == UNMATCHED) {
				// Out of nodes. Time to update some potentials.
				int minimumSlackRightNode = UNMATCHED;
				//region Find minimum slack node, or abort if none exists
				int minimumSlack = oo;
				for (int j = 0; j < n; j++) {
					if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) {
						// This isn't a node reached by our BFS. Update minimum slack.
						if (rightMinimumSlack[j] < minimumSlack) {
							minimumSlack = rightMinimumSlack[j];
							minimumSlackRightNode = j;
						}
					}
				}
				if (minimumSlackRightNode == UNMATCHED || rightMinimumSlackLeftNode[minimumSlackRightNode] == UNMATCHED) {
					// The caches are all empty. There was no option available.
					// This means that the node the BFS started at, which is an unmatched left node, cannot reach the
					// right - i.e. it will be impossible to find a perfect matching.
					return null;
				}
				//endregion Find minimum slack node, or abort if none exists
				assert minimumSlackRightNode != UNMATCHED;
				// Adjust potentials on left and right.
				for (int i = 0; i < n; i++) {
					if (leftSeen[i]) {
						leftPotential[i] += minimumSlack;
						if (leftMatchedTo[i] != UNMATCHED) {
							rightPotential[leftMatchedTo[i]] -= minimumSlack;
						}
					}
				}
				// Downward-adjust slackness caches.
				for (int j = 0; j < n; j++) {
					if (rightMatchedTo[j] == UNMATCHED || !leftSeen[rightMatchedTo[j]]) {
						rightMinimumSlack[j] -= minimumSlack;
						// If the slack hit zero, then we just found ourselves a new tight edge.
						if (rightMinimumSlack[j] == 0) {
							final int i = rightMinimumSlackLeftNode[j];
							final int edgeIndex = rightMinimumSlackEdgeIndex[j];
							//region Update leftEdges[i] and leftTightEdgesCount[i]
							// Move it in the relevant edge list.
							if (edgeIndex != leftTightEdgesCount[i]) {
								final LeftEdge[] edges = leftEdges[i];
								swap(edges, edgeIndex, leftTightEdgesCount[i]);
							}
							++leftTightEdgesCount[i];
							//endregion Update leftEdges[i] and leftTightEdgesCount[i]
							// If we haven't already encountered a match, we follow the edge and update the BFS queue.
							// It's possible this edge leads to a match. If so, we'll carry on updating the tight edges,
							// but won't follow them.
							if (endingRightNode == UNMATCHED) {
								// We're contemplating the consequences of following (i, j), as we do in the BFS above.
								rightBacktrack[j] = i;
								int matchedTo = rightMatchedTo[j];
								if (matchedTo == UNMATCHED) {
									// Match found!
									endingRightNode = j;
								} else if (!leftSeen[matchedTo]) {
									// No match, but new left node found. Extend BFS queue.
									leftSeen[matchedTo] = true;
									leftNodeQueue.addLast(matchedTo);
								}
							}
						}
					}
				}
			}
			//endregion Update node potentials to add edges, if no match found
		}
		// At this point, we've found an augmenting path between startingLeftNode and endingRightNode.
		// We'll just use the backtracking info to update our match information.
		++currentMatchingCardinality;
		// Backtrack and flip augmenting path
		int currentRightNode = endingRightNode;
		while (currentRightNode != UNMATCHED) {
			final int currentLeftNode = rightBacktrack[currentRightNode];
			final int nextRightNode = leftMatchedTo[currentLeftNode];
			rightMatchedTo[currentRightNode] = currentLeftNode;
			leftMatchedTo[currentLeftNode] = currentRightNode;
			currentRightNode = nextRightNode;
		}
	}
	return leftMatchedTo;
}
private static int min(int[] of) { | |
int min = of[0]; | |
for (int i = 1; i < of.length; i++) { | |
min = Math.min(min, of[i]); | |
} | |
return min; | |
} | |
private static <T> void swap(T[] array, int a, int b) { | |
T tmp = array[a]; | |
array[a] = array[b]; | |
array[b] = tmp; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment