Last active
December 21, 2016 05:15
-
-
Save ellenhp/688fc80db01cfd5a1d97d1e634ed0773 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright 2016 Ellen Poe | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
from random import shuffle, randint | |
import itertools | |
import operator | |
import functools | |
def subsetsum(array, num):
    """Find a subset of ``array`` summing to ``num`` by exhaustive recursion.

    Elements are examined in order. For each element the search first tries
    to include it and, only if that fails, tries to skip it, so the subset
    returned is the first one found in that order.

    Args:
        array: Indexable sequence of numbers to choose from.
        num: Target sum. Any target below 1 is rejected, so the empty subset
            is never returned.

    Returns:
        A list of the chosen elements in their original order, or ``None``
        when no subset is found.
    """
    size = len(array)

    def search(start, remaining):
        # Walk the sequence by index instead of slicing ``array[1:]`` so that
        # each recursive step does not copy the tail of the list.
        # ``remaining < 1`` also covers ``remaining == 0``.
        if remaining < 1 or start >= size:
            return None
        head = array[start]
        if head == remaining:
            return [head]
        with_head = search(start + 1, remaining - head)
        if with_head:
            return [head] + with_head
        return search(start + 1, remaining)

    return search(0, num)
def ncr(n, r):
    """Return the binomial coefficient "n choose r" using integer arithmetic.

    Args:
        n: Size of the pool being chosen from.
        r: Number of items chosen.

    Returns:
        The number of ways to choose ``r`` items out of ``n``. Returns 0 when
        ``r`` is negative or larger than ``n`` (there is no such selection).
    """
    # Without this guard the reductions below would run over an empty range
    # with no initial value and raise TypeError.
    if r < 0 or r > n:
        return 0
    # C(n, r) == C(n, n - r); use the smaller one to shorten the products.
    r = min(r, n - r)
    if r == 0:
        return 1
    numer = functools.reduce(operator.mul, range(n, n - r, -1))
    denom = functools.reduce(operator.mul, range(1, r + 1))
    return numer // denom
class SubsetSumSolver:
    """Subset-sum search that prunes candidates with residues modulo powers of two.

    For each modulus ``m`` in ``allMs`` the values are grouped by ``val % m``.
    A candidate is described by how many values it takes from each residue
    class. Only count vectors whose implied sum is congruent to the goal
    (mod ``m``) are kept; they are either checked exhaustively (when few
    enough concrete combinations remain) or refined at modulus ``2 * m``.
    """

    def __init__(self, s, maxM=7, directCheckThresh=5000):
        """Precompute the residue classes of ``s`` for every modulus used.

        Args:
            s: The collection of values to pick a subset from. It is iterated
                several times, so it must be re-iterable (e.g. a list).
            maxM: The moduli used are ``2**1 .. 2**(maxM - 1)``. ``solve``
                starts at modulus 2, so ``maxM`` must be at least 2.
            directCheckThresh: A count vector is searched exhaustively as soon
                as it admits fewer than this many concrete combinations.
        """
        self.maxM = maxM
        self.directCheckThresh = directCheckThresh
        self.allMs = [2**i for i in range(1, maxM)]
        self.set = s
        # setModKey[m][k] is the number of values that satisfy val === k (mod m).
        # It is only defined for m = power of two.
        self.setModKey = dict()
        # valEquivalenceClasses[m][k] is the list of all values that satisfy
        # val === k (mod m).
        self.valEquivalenceClasses = dict()
        # combinationCache[m][k][count] caches the sums of every way to take
        # `count` values from valEquivalenceClasses[m][k].
        self.combinationCache = dict()
        # Not used by any method in this class; kept so the attribute still exists.
        self.combinationSourceCache = dict()
        for m in self.allMs:
            self.setModKey[m] = []
            self.valEquivalenceClasses[m] = dict()
            self.combinationCache[m] = dict()
            for k in range(m):
                satisfyingVals = [val for val in self.set if val % m == k]
                self.valEquivalenceClasses[m][k] = satisfyingVals
                self.setModKey[m].append(len(satisfyingVals))

    def solveDirectly(self, goal, valCounts):
        """Exhaustively look for a subset matching ``valCounts`` that sums to ``goal``.

        Args:
            goal: Target sum.
            valCounts: Sequence of length ``m`` (one of the precomputed
                moduli); ``valCounts[k]`` is how many values to take from the
                class ``val === k (mod m)``.

        Returns:
            A list of the chosen values, or ``None`` if no selection with
            these counts reaches ``goal``.
        """
        m = len(valCounts)
        cacheForM = self.combinationCache[m]
        # For every residue class that contributes at least one value, collect
        # the sums of all ways to take the requested number of values from it.
        # The cartesian product of these lists covers every selection that
        # satisfies valCounts.
        activeClasses = []
        valLists = []
        for i in range(m):
            count = valCounts[i]
            if count == 0:
                continue
            sumsByCount = cacheForM.setdefault(i, dict())
            if count not in sumsByCount:
                sumsByCount[count] = [
                    sum(combo)
                    for combo in itertools.combinations(self.valEquivalenceClasses[m][i], count)
                ]
            activeClasses.append(i)
            valLists.append(sumsByCount[count])
        for candidate in itertools.product(*valLists):
            if sum(candidate) != goal:
                continue
            # The candidate only holds per-class sums; recover the actual
            # values behind each sum.
            deconstructedCandidate = []
            for i, classSum in zip(activeClasses, candidate):
                for combo in itertools.combinations(self.valEquivalenceClasses[m][i], valCounts[i]):
                    if sum(combo) == classSum:
                        deconstructedCandidate += combo
                        # Take exactly one combination per class: several
                        # combinations can share a sum, and appending all of
                        # them would overshoot the goal.
                        break
            return deconstructedCandidate
        return None

    def solve(self, goal):
        """Return a list of values from the set that sum to ``goal``, or ``None``.

        Args:
            goal: Target sum.

        Returns:
            A list of chosen values, or ``None`` when no subset is found.
        """
        # A simple optimization is to not consider candidate solutions that
        # cannot possibly have enough values to satisfy the goal: find the
        # smallest number of (largest) values whose sum reaches the goal.
        minToTake = None
        runningTotal = 0
        descending = sorted(self.set, reverse=True)
        for taken in range(len(descending) + 1):
            if taken:
                runningTotal += descending[taken - 1]
            if runningTotal >= goal:
                minToTake = taken
                break
        if minToTake is None:
            # Even the best prefix of the largest values falls short of the
            # goal, so no subset can reach it.
            return None
        solution = self.satisfySetForM(goal, None, 2, minToTake)
        if solution is not None:
            solution = list(solution)
        return solution

    def _candidateCounts(self, prevCounts, m):
        """Yield every per-residue count vector for modulus ``m`` allowed by ``prevCounts``."""
        sizes = self.setModKey[m]
        if prevCounts is None:
            # This is the first iteration! Just come up with every possible value.
            for counts in itertools.product(*[range(size + 1) for size in sizes]):
                yield counts
            return
        # Consider the count, k3_4, of values where n === 3 (mod 4).
        # The counts k3_8 and k7_8 of values where n === 3 (mod 8) and
        # n === 7 (mod 8) must sum to k3_4. This goes from a known value of
        # k3_4 to all possible (k3_8, k7_8) pairs.
        half = m // 2
        possiblePairs = []
        for i in range(half):
            # Need to reach prevCounts[i] with groups of sizeLow and sizeHigh.
            sizeLow, sizeHigh = sizes[i], sizes[i + half]
            minLow = max(0, prevCounts[i] - sizeHigh)
            maxLow = min(sizeLow, prevCounts[i])
            possiblePairs.append(
                [(low, prevCounts[i] - low) for low in range(minLow, maxLow + 1)]
            )
        # Interleave the pairs back into the layout that prevCounts uses:
        # residues 0..half-1 first, then residues half..m-1.
        for item in itertools.product(*possiblePairs):
            yield [pair[0] for pair in item] + [pair[1] for pair in item]

    def satisfySetForM(self, goal, prevCounts, m, minToTake):
        """Search for a solution among count vectors that are valid modulo ``m``.

        Args:
            goal: Target sum.
            prevCounts: Count vector accepted at modulus ``m // 2``, or
                ``None`` on the first call.
            m: Current modulus; must be a key of ``setModKey``.
            minToTake: Lower bound on how many values a solution must
                contain. ``None`` is treated as no bound beyond "at least one".

        Returns:
            The chosen values, or ``None`` if nothing is found.
        """
        goalModM = goal % m
        # The empty selection is never considered, hence the floor of 1.
        minCount = 1 if minToTake is None else max(1, minToTake)
        sizes = self.setModKey[m]
        # satisfyingCounts holds every count vector that leads to a sum
        # congruent with goal (mod m) but is too large to check directly.
        satisfyingCounts = []
        for item in self._candidateCounts(prevCounts, m):
            # There's no way this count could lead to a solution.
            if sum(item) < minCount:
                continue
            # Taking item[k] values that are each === k (mod m) contributes
            # k * item[k] to the sum modulo m.
            if sum(k * item[k] for k in range(m)) % m != goalModM:
                continue
            # Decide whether to look for solutions directly or to further
            # reduce the solution space.
            possibleSolutions = functools.reduce(
                operator.mul, [ncr(sizes[k], item[k]) for k in range(m)], 1
            )
            if possibleSolutions < self.directCheckThresh:
                solution = self.solveDirectly(goal, item)
                if solution is not None:
                    return solution
            else:
                satisfyingCounts.append(item)
        if not satisfyingCounts:
            return None
        # Recurse or bottom out and look for solutions with what we have.
        nextM = m * 2
        if nextM in self.setModKey:
            for possibleCounts in satisfyingCounts:
                solution = self.satisfySetForM(goal, possibleCounts, nextM, minToTake)
                if solution is not None:
                    return solution
        else:
            for possibleCounts in satisfyingCounts:
                solution = self.solveDirectly(goal, possibleCounts)
                if solution is not None:
                    return solution
        return None
# Demo / benchmark: run both solvers on the same instance and report timings.
s = [5323497, 1375142, 7914384, 6328621, 3197911, 3171174, 1041349, 393355, 6351908, 5438485, 6818284, 1688390, 9648209, 6947185, 1398236, 9830772, 6815854, 4714851, 4595166, 3748144, 1289680, 2434376, 4300956, 9292527, 6518238]
goal = 77016837
from time import time

print('Starting proposed solver.')
startTime = time()
solver = SubsetSumSolver(s)
proposedSolution = solver.solve(goal)
# solve() returns None when nothing is found; sorted(None) would raise TypeError.
print('Solution: {}'.format(sorted(proposedSolution) if proposedSolution is not None else None))
print('Elapsed time for proposed solver: {}\n'.format(time() - startTime))

print('Starting dynamic programming solver')
startTime = time()
referenceSolution = subsetsum(s, goal)
# subsetsum() likewise returns None when no subset is found.
print('Solution: {}'.format(sorted(referenceSolution) if referenceSolution is not None else None))
print('Elapsed time for dynamic programming solution: {}'.format(time() - startTime))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment