Skip to content

Instantly share code, notes, and snippets.

@py-in-the-sky
Last active January 29, 2018 22:19
Show Gist options
  • Save py-in-the-sky/dbbbf282fb00d417d911416a2497bbe8 to your computer and use it in GitHub Desktop.
Save py-in-the-sky/dbbbf282fb00d417d911416a2497bbe8 to your computer and use it in GitHub Desktop.
Knapsack Branch-and-bound
"""
My implementation of a branch-and-bound solution to the Knapsack problem.
See section 3.8 of the book Computer Science Distilled by Wladston Ferreira Filho.
"""
from my_abstract_data_types import *
def knapsack_branch_and_bound(items, weight_limit):
    """
    Solve the 0/1 knapsack problem with a best-first branch-and-bound search.

    Items are considered in decreasing value-to-weight order. Every search
    state fixes a take/skip decision for a prefix of that order and records:

    * ``upper_bound``: packed value plus ``knapsack_greedy_fractional_items``
      applied to the undecided items and the remaining capacity;
    * ``lower_bound``: packed value plus the value of
      ``knapsack_greedy_whole_items`` on the same inputs (an achievable packing).

    The state with the largest upper bound is expanded first. The search stops
    as soon as the popped state has ``upper_bound == lower_bound``; the answer is
    that state's packed items plus its greedy whole-item completion. The result
    is exact provided the fractional helper really is an upper bound.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio`` (e.g. ``Item`` tuples built by ``item``).
        weight_limit: total weight the knapsack can hold.

    Returns:
        A ``KnapsackSolution`` (``items`` frozenset, total ``value``).
    """
    sorted_items = Items(sorted(items, key=lambda item: item.value_to_weight_ratio, reverse=True))
    lower = lambda *args: knapsack_greedy_whole_items(*args).value
    upper = knapsack_greedy_fractional_items

    # The root's bounds are unknown: anything between -inf and +inf is possible.
    # (SearchState's field order is upper_bound first, then lower_bound.)
    initial_state = SearchState(float('inf'), float('-inf'), sorted_items, frozenset(), weight_limit, 0)

    # Heap pops the smallest key, so negate the bounds to pop the most promising
    # state first; ties are broken in favour of states with fewer undecided items.
    priority_queue = Heap(
        [initial_state],
        key=lambda state: (-state.upper_bound, -state.lower_bound, len(state.remaining_items)))

    while priority_queue:
        # Questions:
        #   * Will this process end? That is, will we always eventually get
        #     `state.upper_bound == state.lower_bound`?
        #     Yes! Eventually, remaining_items.rest() will be empty and we'll get
        #     right_upper == right_lower == right_value below.
        state = priority_queue.pop()
        remaining = state.remaining_items

        # `remaining.empty()` can only be true here with unequal bounds for the
        # root of an empty problem; there is nothing to branch on, so the greedy
        # completion (an empty set) is the answer.
        if state.upper_bound == state.lower_bound or remaining.empty():
            completion = knapsack_greedy_whole_items(remaining, state.weight_limit).items
            return knapsack_solution(state.packed_items | completion)

        first = remaining.first()
        rest = remaining.rest()

        # branch "left": take first item (only if it still fits, so no bounds
        # are computed for an infeasible state)
        left_weight = state.weight_limit - first.weight
        if left_weight >= 0:
            left_items = state.packed_items | {first}
            left_value = state.value + first.value
            left_upper = upper(rest, left_weight) + left_value
            left_lower = lower(rest, left_weight) + left_value
            priority_queue.push(
                SearchState(left_upper, left_lower, rest, left_items, left_weight, left_value))

        # branch "right": throw away first item
        right_items = state.packed_items
        right_weight = state.weight_limit
        right_value = state.value
        right_upper = upper(rest, right_weight) + right_value
        right_lower = lower(rest, right_weight) + right_value
        priority_queue.push(
            SearchState(right_upper, right_lower, rest, right_items, right_weight, right_value))
def knapsack_greedy_whole_items(items, weight_limit):
    """
    Greedily pack whole items in decreasing value-to-weight order.

    Each item is taken if it still fits in the remaining capacity and skipped
    otherwise; skipping does not stop the scan, so a later, lighter item may
    still be packed.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio``.
        weight_limit: capacity available for packing.

    Returns:
        ``KnapsackSolution(items, value)`` where ``items`` is a frozenset of the
        chosen items and ``value`` is the sum of their values.
    """
    by_ratio = list(items)
    by_ratio.sort(key=lambda candidate: candidate.value_to_weight_ratio, reverse=True)

    capacity_left = weight_limit
    chosen = set()
    for candidate in by_ratio:
        if candidate.weight > capacity_left:
            continue  # does not fit; keep looking for lighter items
        capacity_left -= candidate.weight
        chosen.add(candidate)

    # The value is computed from the set, so items that compare equal are
    # counted once.
    total_value = sum(candidate.value for candidate in chosen)
    return KnapsackSolution(frozenset(chosen), total_value)
def knapsack_greedy_fractional_items(items, weight_limit):
    """
    Return the optimal value of the *fractional* knapsack relaxation.

    Items are taken whole in decreasing value-to-weight order while they fit;
    the first item that does not fit contributes the fraction of its value
    that fills the remaining capacity. For items with positive weight and
    non-negative value this is an upper bound on the best whole-item packing.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio``.
        weight_limit: capacity available for packing; a non-positive capacity
            admits no positive-weight item.

    Returns:
        The accumulated value (a float once a fractional item has been added,
        otherwise whatever numeric type the item values sum to).
    """
    sorted_items_list = sorted(items, key=lambda item: item.value_to_weight_ratio, reverse=True)
    remaining = weight_limit
    value = 0
    for item in sorted_items_list:
        if item.weight <= remaining:
            remaining -= item.weight
            value += item.value
        elif remaining > 0:
            # Fill the leftover capacity with a slice of this item. The slice is
            # ADDED to the running total; assigning it would discard everything
            # packed so far.
            value += float(remaining) / item.weight * item.value
            remaining = 0
        # else: no capacity left, so this item contributes nothing.
    return value
from collections import namedtuple
# A packable item. `value_to_weight_ratio` is stored alongside value and weight
# so that sorting by ratio does not have to recompute it.
Item = namedtuple('Item', ['value', 'weight', 'value_to_weight_ratio'])

# A packing: the chosen items together with their total value.
KnapsackSolution = namedtuple('KnapsackSolution', ['items', 'value'])

# One node of the branch-and-bound search tree.
SearchState = namedtuple(
    'SearchState',
    [
        'upper_bound',
        'lower_bound',
        'remaining_items',
        'packed_items',
        'weight_limit',
        'value',
    ],
)
def knapsack_solution(items):
    """Wrap `items` in a KnapsackSolution whose value is the sum of the items' values."""
    total_value = sum(packed.value for packed in items)
    return KnapsackSolution(items=items, value=total_value)
def item(value, weight):
    """
    Build an Item from a value and a strictly positive weight.

    The value-to-weight ratio is computed with float division and stored on
    the returned tuple. A non-positive weight trips the assertion.
    """
    assert weight > 0
    ratio = float(value) / weight
    return Item(value=value, weight=weight, value_to_weight_ratio=ratio)
class Items(object):
    """
    A read-only view of the tail of an indexable sequence.

    `rest()` returns a new view that starts one position later and shares the
    same underlying sequence, so taking the tail costs O(1) and copies nothing.
    """

    def __init__(self, items, i=0):
        # items: indexable sequence supporting len(); i: index of the view's first element.
        self._items = items
        self._i = i

    def empty(self):
        """Return True when the view contains no elements."""
        return self._i >= len(self._items)

    def first(self):
        """Return the first element of the view (IndexError if the view is empty)."""
        return self._items[self._i]

    def rest(self):
        """Return a view of everything after the first element."""
        return Items(self._items, self._i + 1)

    def __iter__(self):
        # `range` rather than the Python-2-only `xrange`, so iteration works on
        # both Python 2 and Python 3.
        for index in range(self._i, len(self._items)):
            yield self._items[index]

    def __len__(self):
        # Clamp at zero: calling rest() on an empty view pushes the start index
        # past the end of the sequence.
        return max(0, len(self._items) - self._i)
class Heap(object):
    """
    A binary min-heap ordered by `key(element)`.

    "heap property": a binary tree has the heap property established if
    the root node has no descendants or is smaller than or equal to all
    descendants and the left and right subtrees also have the heap property
    established.

    The tree is stored in a list: the children of index i live at 2*i + 1 and
    2*i + 2, and the parent of index i > 0 lives at (i - 1) // 2.

    For all big-oh notations below, N = len(self._heap).
    """

    def __init__(self, elements, key=lambda x: x):
        self._key = key
        self._heap = list(elements)  # private copy; heap property not necessarily established
        self._heapify()  # heap property established

    def __bool__(self):
        return len(self._heap) > 0

    # Python 2 looks up __nonzero__ (not __bool__) for truth testing; without
    # this alias `while heap:` would never terminate there.
    __nonzero__ = __bool__

    def __len__(self):
        return len(self._heap)

    def __repr__(self):
        return repr(self._heap)

    def push(self, element):
        """Add `element` to the heap. O(logN) time."""
        # precondition: heap property exists
        self._heap.append(element)  # heap property possibly violated; O(1) time
        self._bubble_up()  # heap property restored; O(logN) time

    def pop(self):
        """Remove and return the element with the smallest key. O(logN) time."""
        # precondition: heap property exists
        assert len(self._heap) > 0
        if len(self._heap) == 1:
            return self._heap.pop()  # heap property maintained; O(1) time
        head_element = self._heap[0]
        self._heap[0] = self._heap.pop()  # heap property possibly violated; O(1) time
        self._bubble_down(0)  # heap property restored; O(logN) time
        return head_element

    def _bubble_up(self):
        """Move the last element up until its parent's key is not larger. O(logN) time."""
        assert len(self._heap) > 0
        heap, key = self._heap, self._key
        element_index = len(heap) - 1
        element = heap[element_index]
        element_key = key(element)
        while element_index != 0:
            parent_index = (element_index - 1) // 2
            if not element_key < key(heap[parent_index]):
                break
            # Shift the parent down instead of swapping; the element itself is
            # written once, at its final position.
            heap[element_index] = heap[parent_index]
            element_index = parent_index
        heap[element_index] = element

    def _bubble_down(self, element_index):
        """Move the element at `element_index` down until no child has a smaller key. O(logN) time."""
        # precondition: the left and right subtrees of element have the heap property
        assert 0 <= element_index < len(self._heap)
        heap, key = self._heap, self._key
        size = len(heap)
        element = heap[element_index]
        element_key = key(element)
        while True:
            left = 2 * element_index + 1
            if left >= size:
                break  # no children
            right = left + 1
            smaller = left
            # Prefer the left child on ties.
            if right < size and key(heap[right]) < key(heap[left]):
                smaller = right
            if not element_key > key(heap[smaller]):
                break
            heap[element_index] = heap[smaller]
            element_index = smaller
        heap[element_index] = element

    def _heapify(self):
        """Establish the heap property over the whole list."""
        # O(N) time. Why? The analysis is subtle. See the following:
        #   https://www.cs.umd.edu/~meesh/351/mount/lectures/lect14-heapsort-analysis-part.pdf
        # Here's another interesting perspective to check out:
        #   https://math.stackexchange.com/questions/181022/worst-case-analysis-of-max-heapify-procedure
        #
        # Leaf nodes already have the heap property established, so only the
        # non-leaf nodes need to be bubbled down. The last non-leaf node is the
        # parent of the very last element (index n-1), i.e. index
        # ((n-1) - 1) // 2 = n // 2 - 1. We therefore iterate over
        # [n//2 - 1, n//2 - 2, ..., 1, 0], bottom to top. Visiting the leaves
        # as well would be a no-op for each of them and would not change the
        # asymptotic runtime.
        n = len(self._heap)
        for i in range(n // 2 - 1, -1, -1):
            self._bubble_down(i)
def tests():
    """Run knapsack_branch_and_bound on two known instances and report success."""
    # test cases TODO:
    #   https://lagunita.stanford.edu/courses/course-v1:Engineering+Algorithms2+SelfPaced/courseware/adfcef80aaab466aafedce7fa4b23495/24de6a5b1a94465fb517616fee39abd1/

    def pack_items(value_weight_pairs):
        # Build a real list (not a lazy `map` object) so the items can be
        # iterated more than once on Python 3 as well.
        return [item(value, weight) for value, weight in value_weight_pairs]

    # from http://www.spoj.com/problems/KNAPSACK/
    weight_limit = 4
    items = pack_items([
        (8, 1),  # (value, weight)
        (4, 2),
        (0, 3),
        (5, 2),
        (3, 2),
    ])
    expected_value = 13
    assert knapsack_branch_and_bound(items, weight_limit).value == expected_value

    # from https://www.geeksforgeeks.org/knapsack-problem/
    weight_limit = 50
    items = pack_items([
        (60, 10),
        (100, 20),
        (120, 30),
    ])
    expected_value = 220
    assert knapsack_branch_and_bound(items, weight_limit).value == expected_value

    # Parenthesised single-argument form prints the same text on Python 2 and 3.
    print('Tests pass!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment