Last active
January 29, 2018 22:19
-
-
Save py-in-the-sky/dbbbf282fb00d417d911416a2497bbe8 to your computer and use it in GitHub Desktop.
Knapsack Branch-and-bound
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
My implementation of a branch-and-bound solution to the Knapsack problem. | |
See section 3.8 of the book Computer Science Distilled by Wladston Ferreira Filho. | |
""" | |
from my_abstract_data_types import * | |
def knapsack_branch_and_bound(items, weight_limit):
    """Solve the 0/1 knapsack problem with best-first branch and bound.

    Items are ordered by decreasing value-to-weight ratio.  Every search state
    fixes a take/skip decision for a prefix of that order and carries two
    bounds on the best total value reachable from it:

    * upper bound: value already packed plus the fractional-greedy value of
      the undecided items in the remaining capacity;
    * lower bound: value already packed plus the whole-item greedy value of
      the undecided items in the remaining capacity.

    States are popped in order of decreasing upper bound.  The first popped
    state whose two bounds coincide is returned: its greedy completion reaches
    its own upper bound, and no queued state has a larger upper bound, so the
    answer is optimal provided the fractional value really is an upper bound.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio``.
        weight_limit: capacity of the knapsack.

    Returns:
        A ``KnapsackSolution`` with the chosen items and their total value.
    """
    sorted_items = Items(sorted(items, key=lambda item: item.value_to_weight_ratio, reverse=True))

    def make_state(remaining, packed, capacity, value):
        # Both bounds are computed from the undecided items only and then
        # shifted by the value that is already packed.
        return SearchState(
            upper_bound=knapsack_greedy_fractional_items(remaining, capacity) + value,
            lower_bound=knapsack_greedy_whole_items(remaining, capacity).value + value,
            remaining_items=remaining,
            packed_items=packed,
            weight_limit=capacity,
            value=value,
        )

    # The root gets real bounds (not +/- infinity placeholders): with no items
    # left to decide both bounds equal the packed value, so an empty input
    # terminates immediately instead of reaching the assertion below.
    initial_state = make_state(sorted_items, frozenset(), weight_limit, 0)

    # Min-heap on the negated bounds == max-heap on (upper bound, lower bound);
    # ties are broken in favour of states with fewer undecided items.
    priority_queue = Heap(
        [initial_state],
        key=lambda state: (-state.upper_bound, -state.lower_bound, len(state.remaining_items)))

    while priority_queue:
        # Questions:
        #   * Will this process end? That is, will we always eventually get
        #     `state.upper_bound == state.lower_bound`?
        #     Yes! Eventually remaining_items.rest() will be empty, and then
        #     upper bound == lower bound == packed value for that state.
        state = priority_queue.pop()

        if state.upper_bound == state.lower_bound:
            completion = knapsack_greedy_whole_items(state.remaining_items, state.weight_limit)
            return knapsack_solution(state.packed_items | completion.items)

        # The bounds differ, so at least one item is still undecided.
        assert not state.remaining_items.empty()
        first = state.remaining_items.first()
        rest = state.remaining_items.rest()

        # branch "left": take first item -- only if it fits.  Checking before
        # building the state avoids computing bounds for a negative capacity.
        left_weight = state.weight_limit - first.weight
        if left_weight >= 0:
            priority_queue.push(
                make_state(rest, state.packed_items | {first}, left_weight, state.value + first.value))

        # branch "right": throw away first item (always feasible)
        priority_queue.push(
            make_state(rest, state.packed_items, state.weight_limit, state.value))
def knapsack_greedy_whole_items(items, weight_limit):
    """Greedily pack whole items in order of decreasing value-to-weight ratio.

    Walks the items from the best ratio to the worst and takes every item
    that still fits in the capacity that is left; items that do not fit are
    skipped and the walk continues.

    The chosen items are collected in a set, so items that compare equal are
    stored (and counted in the total value) only once.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio``.
        weight_limit: capacity available for packing.

    Returns:
        A ``KnapsackSolution`` holding a frozenset of the chosen items and the
        sum of their values.
    """
    ranked = list(items)
    ranked.sort(key=lambda candidate: candidate.value_to_weight_ratio, reverse=True)

    capacity_left = weight_limit
    picked = set()
    for candidate in ranked:
        if candidate.weight > capacity_left:
            continue
        capacity_left -= candidate.weight
        picked.add(candidate)

    total_value = sum(candidate.value for candidate in picked)
    return KnapsackSolution(frozenset(picked), total_value)
def knapsack_greedy_fractional_items(items, weight_limit):
    """Return the optimal value of the fractional relaxation of the knapsack.

    Items are taken whole in order of decreasing value-to-weight ratio for as
    long as they fit; the first item that does not fit is taken fractionally
    so that it fills the remaining capacity exactly.  Because any fraction of
    an item may be taken, the result is never smaller than the best value
    reachable with whole items only, which makes it usable as an upper bound.

    Args:
        items: iterable of objects exposing ``value``, ``weight`` and
            ``value_to_weight_ratio``.
        weight_limit: capacity available for packing.

    Returns:
        The total (possibly fractional) value packed; 0 when nothing fits.
    """
    sorted_items_list = sorted(items, key=lambda item: item.value_to_weight_ratio, reverse=True)
    value = 0
    for item in sorted_items_list:
        if item.weight <= weight_limit:
            weight_limit -= item.weight
            value += item.value
        elif weight_limit > 0:
            # Accumulate (+=) the partial item: plain assignment would throw
            # away the value of every item that was already packed whole.
            fraction = float(weight_limit) / item.weight
            value += fraction * item.value
            weight_limit = 0
        # With no capacity left, an item of positive weight contributes
        # nothing and is skipped.
    return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
# One knapsack candidate: its value, its weight and the precomputed ratio value / weight.
Item = namedtuple('Item', ['value', 'weight', 'value_to_weight_ratio'])

# Result of a packing: the collection of chosen items and the sum of their values.
KnapsackSolution = namedtuple('KnapsackSolution', ['items', 'value'])

# One node of the branch-and-bound search: bounds on the reachable total value,
# the items still undecided, the items already packed, the capacity left and
# the value packed so far.
SearchState = namedtuple(
    'SearchState',
    ['upper_bound', 'lower_bound', 'remaining_items', 'packed_items', 'weight_limit', 'value'])
def knapsack_solution(items):
    """Wrap a collection of items in a ``KnapsackSolution``.

    Args:
        items: collection of objects exposing a ``value`` attribute.

    Returns:
        A ``KnapsackSolution`` whose ``items`` is the given collection and
        whose ``value`` is the sum of the items' values.
    """
    total_value = sum(entry.value for entry in items)
    return KnapsackSolution(items, total_value)
def item(value, weight):
    """Build an ``Item`` together with its value-to-weight ratio.

    Args:
        value: value of the item.
        weight: weight of the item; must be greater than zero, since the
            ratio divides by it.

    Returns:
        ``Item(value, weight, float(value) / weight)``.
    """
    assert weight > 0
    ratio = float(value) / weight
    return Item(value=value, weight=weight, value_to_weight_ratio=ratio)
class Items:
    """Read-only view of the tail of a sequence, starting at index ``i``.

    ``rest()`` returns a new view over the same underlying sequence with the
    start index advanced by one, so dropping the first element never copies
    the sequence.
    """

    def __init__(self, items, i=0):
        """Create a view of ``items[i:]``; ``items`` must support indexing and ``len``."""
        self._items = items
        self._i = i

    def empty(self):
        """Return True when the view contains no elements."""
        return self._i >= len(self._items)

    def first(self):
        """Return the first element of the view (indexes the underlying sequence)."""
        return self._items[self._i]

    def rest(self):
        """Return a view of everything after the first element."""
        return Items(self._items, self._i + 1)

    def __iter__(self):
        """Iterate over the elements of the view, from first to last."""
        # `range` exists on both Python 2 and 3; `xrange` is Python 2 only.
        underlying = self._items
        return (underlying[index] for index in range(self._i, len(underlying)))

    def __len__(self):
        """Return the number of elements in the view (never negative)."""
        return max(0, len(self._items) - self._i)
class Heap:
    """Binary min-heap ordered by ``key(element)`` and stored in a flat list.

    "heap property": a binary tree has the heap property established if the
    root node has no descendants or is smaller than or equal to all
    descendants and the left and right subtrees also have the heap property
    established.

    The children of the node at index i live at indices 2*i + 1 and 2*i + 2.
    For all big-oh notations below, N = len(self._heap).
    """

    def __init__(self, elements, key=lambda x: x):
        """Build a heap from ``elements`` (any iterable) in O(N) time.

        The input is copied, so the caller's collection is never modified.
        """
        self._key = key
        self._heap = list(elements)  # heap property not necessarily established
        self._heapify()              # heap property established

    def __len__(self):
        """Return the number of elements currently stored."""
        return len(self._heap)

    def __bool__(self):
        """Return True while the heap holds at least one element."""
        return len(self._heap) > 0

    # Python 2 looks up __nonzero__ rather than __bool__ for truth testing;
    # without this alias an empty heap would be truthy there.
    __nonzero__ = __bool__

    def __repr__(self):
        return repr(self._heap)

    def push(self, element):
        """Insert ``element``; O(logN) time."""
        # precondition: heap property exists
        self._heap.append(element)  # heap property possibly violated; O(1) time
        self._bubble_up()           # heap property restored; O(logN) time

    def pop(self):
        """Remove and return the element with the smallest key; O(logN) time.

        The heap must not be empty.
        """
        # precondition: heap property exists
        assert len(self._heap) > 0
        if len(self._heap) == 1:
            return self._heap.pop()  # heap property maintained; O(1) time
        head_element = self._heap[0]
        self._heap[0] = self._heap.pop()  # heap property possibly violated; O(1) time
        self._bubble_down(0)              # heap property restored; O(logN) time
        return head_element

    def _bubble_up(self):
        """Move the last element up until its parent's key is not larger; O(logN) time."""
        assert len(self._heap) > 0
        heap, key = self._heap, self._key
        index = len(heap) - 1
        element = heap[index]
        element_key = key(element)
        while index != 0:
            parent_index = (index + 1) // 2 - 1
            if not element_key < key(heap[parent_index]):
                break
            # Shift the parent down; the element itself is written once, at the end.
            heap[index] = heap[parent_index]
            index = parent_index
        heap[index] = element

    def _bubble_down(self, element_index):
        """Move the element at ``element_index`` down to restore the heap property; O(logN) time."""
        # precondition: the left and right subtrees of element have the heap property
        heap, key = self._heap, self._key
        size = len(heap)
        assert 0 <= element_index < size
        element = heap[element_index]
        element_key = key(element)
        index = element_index
        while True:
            # A list (not `filter`) so the emptiness test works on Python 2 and 3.
            children = [child for child in (index * 2 + 1, index * 2 + 2) if child < size]
            if not children:
                break
            smaller_child = min(children, key=lambda child: key(heap[child]))
            if not element_key > key(heap[smaller_child]):
                break
            # Shift the smaller child up; the element itself is written once, at the end.
            heap[index] = heap[smaller_child]
            index = smaller_child
        heap[index] = element

    def _heapify(self):
        """Establish the heap property over the whole list."""
        # O(N) time. Why? The analysis is subtle. See the following:
        # https://www.cs.umd.edu/~meesh/351/mount/lectures/lect14-heapsort-analysis-part.pdf
        # Here's another interesting perspective to check out:
        # https://math.stackexchange.com/questions/181022/worst-case-analysis-of-max-heapify-procedure
        #
        # Leaf nodes already have the heap property established, so only the
        # non-leaf nodes need to be bubbled down. The last non-leaf node is the
        # parent of the very last element (index n-1), and the parent index of
        # n-1 is ((n-1) + 1) // 2 - 1 = n // 2 - 1. We therefore iterate over
        # the indices [n//2 - 1, n//2 - 2, ..., 1, 0], bottom to top.
        n = len(self._heap)
        for i in reversed(range(n // 2)):
            self._bubble_down(i)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tests():
    """Check the branch-and-bound solver on small instances with known optima.

    Raises:
        AssertionError: if the solver's total value differs from the expected
            value for any instance.
    """
    # test cases TODO:
    # https://lagunita.stanford.edu/courses/course-v1:Engineering+Algorithms2+SelfPaced/courseware/adfcef80aaab466aafedce7fa4b23495/24de6a5b1a94465fb517616fee39abd1/

    def pack_items(value_weight_pairs):
        # A real list (not a lazy `map`) so the items can be iterated more
        # than once on Python 3 as well.
        return [item(value, weight) for value, weight in value_weight_pairs]

    # from http://www.spoj.com/problems/KNAPSACK/
    weight_limit = 4
    items = pack_items([
        (8, 1),  # (value, weight)
        (4, 2),
        (0, 3),
        (5, 2),
        (3, 2),
    ])
    expected_value = 13
    assert knapsack_branch_and_bound(items, weight_limit).value == expected_value

    # from https://www.geeksforgeeks.org/knapsack-problem/
    weight_limit = 50
    items = pack_items([
        (60, 10),
        (100, 20),
        (120, 30),
    ])
    expected_value = 220
    assert knapsack_branch_and_bound(items, weight_limit).value == expected_value

    # Parenthesised single-argument form prints the same text on Python 2 and 3.
    print('Tests pass!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment