Skip to content

Instantly share code, notes, and snippets.

@blogle
Created December 4, 2014 23:02

Revisions

  1. blogle created this gist Dec 4, 2014.
    269 changes: 269 additions & 0 deletions gistfile1.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,269 @@
    import numpy as np
    import networkx as nx
    import heapq
    from collections import defaultdict

    def distance(u, v):
    return np.sum((u - v)**2)

    class KDTree(object):

    def __init__(self, data, index=None, depth=0):
    """
    Creates a recursive space partitioning DataStructure where each node
    splits the dimension at the median of that axis. Similar to a BST,
    provides O(n log n) creation and O(log n) queries.
    Args:
    data (np.array): dataset with shape n, k (n obs, k dim).
    Optional
    index (np.array): Index corresponding to each node in data
    if left empty, the data is zero indexed.
    depth (int) : This determines the axis in which to first
    partition on e.g 0 -> x, 1 -> y, 2 -> z
    Notes:
    http://en.wikipedia.org/wiki/K-d_tree
    http://en.wikipedia.org/wiki/K-d_tree#mediaviewer/File:KDTree-animation.gif
    """
    # Build index at top level
    if type(index) == type(None):
    index = np.arange(data.shape[0])

    self.n = None
    self.k = None
    self.idx = None
    self.node = None
    self.axis = None
    self.left = None
    self.right = None
    self.children = None

    self._build(data, index, depth)

    def _build(self, data, index, depth):
    """Recursively builds the child nodes of the KDTree"""
    # If there is data to partition create nodes
    if data[index].size:

    # Store the dimensions of the data and the axis to partition on
    self.n, self.k = data[index].shape
    self.axis = (self.k + depth) % self.k

    # list of nodes beneath this node
    self.children = index

    # Find the index of the data sorted on the current axis
    # and the midpoint in which to partition
    idx_data = np.column_stack((data[index], index))
    sort_ax = idx_data[np.argsort(idx_data[:, self.axis]), -1].astype(int)
    partition = sort_ax.size / 2

    # Node index and data
    self.idx = sort_ax[partition]
    self.node = data[self.idx]

    # Build the branches, partitioning on the next axis
    self.left = KDTree(data, sort_ax[ : partition], depth+1)
    self.right = KDTree(data, sort_ax[partition+1:], depth+1)

    def near_branch(self, point):
    """Returns the branch nearest the input point"""
    if point[self.axis] < self.node[self.axis]:
    return self.left
    return self.right

    def far_branch(self, point):
    """Returns the branch furthest the input point"""
    if self.near_branch(point) == self.left:
    return self.right
    return self.left

    def orthogonal_dist(self, point):
    """computes the distance from a point to the partition"""
    orth_point = np.copy(point)
    orth_point[self.axis] = self.node[self.axis]
    return distance(point, self.node)

    def query(self, point, best=None):
    """Find the nearest neighbor of point in KDTree"""

    # Dead end backtrack up the tree
    if self.node is None:
    return best

    # Initialize best
    if best is None:
    best = (self.idx, self.node)

    # check if current node is closer than best
    if distance(self.node, point) < distance(best[1], point):
    best = (self.idx, self.node)

    # continue traversing the tree
    best = self.near_branch(point).query(point, best)

    # traverse the away branch if the orthogonal distance is less than best
    if self.orthogonal_dist(point) < distance(best[1], point):
    best = self.far_branch(point).query(point, best)
    return best

    def query_subset(self, point, subset):
    """Find the nearest neighbor of point in subset"""
    subset_vec = np.zeros(self.n)
    subset_vec[subset] = 1

    return self._query_subset(point, subset_vec, None)

    def _query_subset(self, point, subset, best=None):
    """Recursively implements constrained nearest neighbor search"""

    # Dead end backtrack up the tree
    if np.all(self.node == None):
    return best

    # Initialize node vectors
    idx_vec = np.empty_like(subset)
    child_vec = np.empty_like(subset)
    idx_vec[:] = child_vec[:] = 0
    idx_vec[self.idx] = child_vec[self.children] = 1

    # if point in subset, try to update best
    if np.dot(idx_vec, subset) != 0:
    # if closer than current best, or best is none update
    # is_closer is a thunk to prevent '__getitem__' error
    is_closer = lambda: distance(self.node, point) < distance(best[1], point)
    if np.all(best == None) or is_closer():
    best = (self.idx, self.node)

    near = self.near_branch(point)
    far = self.far_branch(point)

    # check the near branch, if its nodes intersect with the queried subset
    # otherwise move to the away branch
    if np.dot(child_vec, subset) > 0:
    best = near._query_subset(point, subset, best)
    else:
    best = far._query_subset(point, subset, best)

    # validate best, by ensuring closer point doesn't exist just beyond
    # partition if best still has yet to be found also look
    # into this further branch
    if (np.all(best != None) and self.orthogonal_dist(point) <
    distance(best[1], point)) or np.all(best == None):
    best = far._query_subset(point, subset, best)

    return best

    class PriorityQueue(object):

    def __init__(self):
    """
    Queue implementing highest-priority-in first-out.
    Note:
    Priority is cost based, therefore smaller values are prioritized
    over larger values.
    """
    self._queue = []
    self._index = 0

    def push(self, item, priority):
    """
    Push an item into the queue.
    Args:
    item (obj): Item to be stored in the queue
    priority (Num): Priority in which item will be retrieved from the queue
    """
    heapq.heappush(self._queue, (priority, self._index, item))
    self._index += 1

    def pop(self):
    """
    Removes the highest priority item from the queue
    Returns:
    obj: item with highest priority
    """
    return heapq.heappop(self._queue)[-1]

    def merge(self, other):
    """
    Given another queue, consumes each item in it
    and pushes the item and its priority into its own queue
    Args:
    other (PriorityQueue): Queue to be merged
    """
    while other._queue:
    priority,i,item = heapq.heappop(other._queue)
    self.push(item, priority)

    def top(self):
    """
    Allows peek at top item in the queue without removing it
    Returns:
    obj: if the queue is not empty otherwise None
    """
    try:
    return self._queue[0][-1]
    except:
    return None

    def bvka_mst_edges(G, assume_connected=False, pos='coords'):

    V = set(G.nodes(data=False))
    pos = np.row_stack(nx.get_node_attributes(G, pos).values())

    kdtree = KDTree(pos)
    subgraphs = nx.utils.UnionFind()

    # This could be swapped for a defaultdict if preferred
    queues = defaultdict(PriorityQueue)

    for v in V:
    # Todo restrict this further to connected edges
    vm, _ = kdtree.query_subset(pos[v], list(V - {v}))
    dm = distance(pos[v], pos[vm])
    root = subgraphs[v]
    queues[root].push((v, vm), dm)

    Et = []
    while len(Et) != len(V) - 1:

    Ep = PriorityQueue()
    for C in set(map(subgraphs.__getitem__, subgraphs.parents.values())):

    (v, vm) = queues[C].top()
    component_set = [child for child, parent
    in subgraphs.parents.iteritems()
    if parent == C]
    disjoint_nodes = list(V - set(component_set))

    while vm in component_set:
    queues[C].pop()
    um, _ = kdtree.query_subset(pos[v], disjoint_nodes)
    dm = distance(pos[v], pos[vm])
    queues[C].push((v, um), dm)
    (v, vm) = queues[C].top()

    dm = distance(pos[v], pos[vm])
    Ep.push((v, vm, dm), dm)

    while Ep._queue:
    (um, vm, dm) = Ep.pop()

    component_i, component_j = subgraphs[um], subgraphs[vm]
    if component_i != component_j:
    # add the edge and merge the queues
    Et += [(um, vm)]

    subgraphs.union(um, vm)
    if component_i == subgraphs[um]:
    major, minor = component_i, component_j
    else:
    minor, major = component_i, component_j

    queues[major].merge(queues[minor])
    del(queues[minor])
    return Et