Created January 22, 2017 18:44
Save lteu/46336816640957651ca46292b9329c7d to your computer and use it in GitHub Desktop.
An improved version of https://gist.github.com/lteu/e3658d6b135fc6c7cea8685f5a1da32e
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Binary search for a COP (constraint optimization problem) on item selection.
# ref: http://stackoverflow.com/questions/41708500/how-to-declare-constraints-with-variable-as-array-index-in-z3py
# z3 is star-imported first (as originally) so the stdlib imports below
# take precedence over any same-named attribute z3 might export.
from z3 import *
import sys
import time
import json

# ======================
# Instance
# ======================
# Load the instance description; the script reads the keys 'items',
# 'value' and 'bonusVals' from this dict.
with open('data.json') as data_file:
    dic = json.load(data_file)
start_time = time.time()

# Optimize is used here as a plain incremental solver: no objective is
# registered; the best metric is located by the binary search on `pivot`.
solver = Optimize()

items = dic['items']
valueVals = dic['value']
bonusVals = dic['bonusVals']
choices = [0, 1]

# Thresholds for which "metric > threshold" was found SAT (arr_low) or
# UNSAT (arr_high) by the search loop.
arr_low = []
arr_high = []
# Initial threshold: the value of the first item.
pivot = valueVals[0]
# Sentinel that can never equal a numeric pivot, so the search loop always
# runs at least once (the previous -1 skipped the loop when pivot == -1).
last_pivot = None

# final score
metric = Int('metric')

# selection variables: CHOICE[0] and CHOICE[1] are the two picked items
CHOICE = [Int('CHOICE_%s' % c) for c in choices]

litems = len(items)

# totalValue[i0 * litems + i1] holds the score of the pair (i0, i1), i0 < i1:
# both item values plus their pairwise bonus.
# NOTE(review): items are used directly as indices into valueVals/bonusVals
# and in the flattened index, so this assumes items are the integers
# 0..len(items)-1 -- confirm against data.json.
totalValue = Array('TotalValue', IntSort(), IntSort())
for i0 in items:
    for i1 in items:
        if i1 > i0:
            pair_score = valueVals[i0] + valueVals[i1] + bonusVals[i0][i1]
            solver.add(totalValue[i0 * litems + i1] == pair_score)

# Single-argument print() calls behave the same on Python 2 and 3.
print("array built in ...")
array_built_time = time.time()
print("--- %s seconds --- ..." % (array_built_time - start_time))
# The two chosen indices are in range and strictly ordered, so each
# unordered pair of items is selected at most once.
solver.add(0 <= CHOICE[0])
solver.add(CHOICE[0] < CHOICE[1])
# litems == len(items); reuse it so the bound matches the flattened index.
solver.add(CHOICE[1] < litems)
# metric is the score of the selected pair, looked up with a symbolic index.
solver.add(metric == totalValue[CHOICE[0] * litems + CHOICE[1]])

# Report whether the base selection problem is satisfiable at all.
print(solver.check())
# ======================
# Binary search on the objective
# ======================
# Each probe asks "is there a selection with metric > pivot?":
#   SAT   -> optimum is above pivot      -> pivot recorded in arr_low
#   UNSAT -> optimum is at most pivot    -> pivot recorded in arr_high
# While one side has no bound yet, pivot moves geometrically (doubling /
# halving for positive pivots); once both sides are bounded it bisects the
# gap with integer arithmetic. The loop stops when a probe leaves pivot
# unchanged.
print("iterative solving ...")

if solver.check() != sat:
    # With no feasible selection every probe is UNSAT and the downward
    # search below would never terminate, so skip it.
    print("no feasible selection: skipping search")
else:
    solver.push()
    solver.add(metric > pivot)
    while last_pivot != pivot:
        last_pivot = pivot
        if solver.check() == sat:
            arr_low.append(pivot)
            if not arr_high:
                # Equals pivot * 2 for positive pivots; max(..., 1) keeps
                # zero/negative pivots moving upward instead of stalling.
                pivot = pivot + max(abs(pivot), 1)
            else:
                # Floor division keeps pivot integral and identical on
                # Python 2 and 3 (unlike round(x / 2)).
                pivot = pivot + (min(arr_high) - pivot) // 2
            print("SATISFIABLE")
            m = solver.model()
            print([m.evaluate(CHOICE[i]) for i in choices])
            print(m.evaluate(metric))
        else:
            arr_high.append(pivot)
            if not arr_low:
                # Equals pivot // 2 for positive pivots and keeps moving
                # downward past zero for non-positive ones.
                pivot = pivot - max((abs(pivot) + 1) // 2, 1)
            else:
                low = max(arr_low)
                pivot = low + (pivot - low) // 2
            print("failed to solve")
        # Swap the previous threshold constraint for the new one.
        solver.pop()
        solver.push()
        solver.add(metric > pivot)
        z3_time = time.time()
        print("new pivot %s last %s" % (pivot, last_pivot))
        print("--- %s sec ---" % (z3_time - array_built_time))
    # Drop the final threshold so the solver's scope stack is balanced.
    solver.pop()
# Total wall-clock time since start_time; single-argument print() calls
# behave the same on Python 2 and 3.
print("all processes terminated ...")
print("--- %s seconds ---" % (time.time() - start_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.