Created
November 15, 2016 17:40
-
-
Save evertheylen/4df34243635754b412615cac02a41119 to your computer and use it in GitHub Desktop.
AI CSP
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from collections import defaultdict | |
from itertools import combinations | |
from util import * | |
# Output format for Solution.__str__: True selects the RST ".. math::" /
# LaTeX "aligned" rendering, False the plain "var\t --> domain" listing.
# It is read once, when the Solution class body executes, so changing it
# afterwards has no effect on already-defined classes.
latex = False
class Constraint:
    """A constraint over an ordered list of variables.

    ``func`` receives one value per entry of ``variables``, positionally and
    in that order; its result is used as a truth value ("allowed or not").
    """

    def __init__(self, variables, func, doc=""):
        self.variables = variables  # ordered! func gets values in this order
        self.func = func
        self.doc = doc

    def valid(self, assignment):
        """Evaluate the constraint against ``assignment.domains``.

        The constraint can only be decided once every one of its variables is
        down to a single candidate value; before that it answers True.
        """
        domains = [assignment.domains[var] for var in self.variables]
        if any(len(dom) != 1 for dom in domains):
            return True  # Unsure, so continue
        return self.func(*(next(iter(dom)) for dom in domains))

    def is_binary(self):
        """True when the constraint links exactly two variables."""
        return len(self.variables) == 2

    def is_unary(self):
        """True when the constraint restricts a single variable."""
        return len(self.variables) == 1
# functional!
class Solution:
    """A search state: the unassigned variables plus every variable's domain.

    "Functional": ``assign`` never modifies ``self``; it returns a modified
    copy, so the backtracking search can simply discard a failed branch.
    """

    @classmethod
    def from_csp(cls, csp):
        """Initial state for ``csp``: all variables unassigned, domains copied."""
        return cls(set(csp.variables.keys()), dict_value_copy(csp.variables), csp)

    def __init__(self, unassigned, domains, csp):
        self.unassigned = unassigned  # variables not yet assigned
        self.domains = domains  # variable -> remaining candidate values
        self.csp = csp

    def is_complete(self):
        """True once every variable has been assigned."""
        return len(self.unassigned) == 0

    def can_continue(self):
        """True unless a constraint whose variables are all fixed is violated."""
        return self.csp.valid(self)

    def has_empty_domain(self):
        """True if some variable has no candidate values left."""
        return any(len(v) == 0 for v in self.domains.values())

    def copy(self):
        # type(self) so that subclasses copy as their own type.
        return type(self)(self.unassigned.copy(), dict_value_copy(self.domains), self.csp)

    def assign(self, x, x_val):
        """Return a copy with ``x`` fixed to ``x_val`` and no longer unassigned.

        With ``unassigned`` being a set (as built by from_csp), assigning an
        already-assigned variable raises KeyError.
        """
        new = self.copy()
        new.domains[x] = {x_val}
        new.unassigned.remove(x)
        return new

    def recheck_domains(self):
        """Propagation hook for subclasses; the base class prunes nothing."""
        return self

    def _latex_str(self):
        "RST/latex string"
        lines = [".. math::", "\t\\begin{aligned}"]
        for k, v in sorted(self.domains.items(), key=str):
            # Sort values so the output does not depend on set iteration order.
            values = ", ".join("\\text{{{}}}".format(i) for i in sorted(v, key=str))
            lines.append("\t{} & \\rightarrow \\{{ {} \\}} \\\\".format(k, values))
        # "\\end", not "\end": "\e" is an invalid escape sequence
        # (SyntaxWarning since Python 3.12).
        lines.append("\t\\end{aligned}")
        return "\n".join(lines)

    def _normal_str(self):
        l = []
        for k, v in sorted(self.domains.items(), key=str):
            l.append("{}\t --> {}".format(k, v))
        return "\n".join(l)

    __str__ = _latex_str if latex else _normal_str
class ForwardCheckingSolution(Solution):
    """Solution that prunes unassigned neighbours' domains on every assignment.

    Warning: only binary constraints are used for pruning.
    """

    def can_continue(self):
        # With only binary constraints, assign() has already removed every
        # value conflicting with the assignments so far, so an empty domain is
        # the only remaining sign of failure.
        if self.csp.all_binary:
            return not self.has_empty_domain()
        return super().can_continue()

    def assign(self, x, x_val):
        """Assign ``x = x_val``, then drop unsupported values of its neighbours."""
        new = super().assign(x, x_val)
        # forward checking
        for y in new.unassigned:
            for c in self.csp.constraint_info.get(x, y):
                if not c.is_binary():
                    continue
                for y_val in list(new.domains[y]):
                    fixed = {x: x_val, y: y_val}
                    # Call positionally in c.variables order, exactly like
                    # Constraint.valid, so Constraint objects whose func
                    # parameter names differ from the variable names work too.
                    if not c.func(*(fixed[v] for v in c.variables)):
                        new.domains[y].remove(y_val)
        return new
class AC3Solution(Solution):
    """Solution whose recheck_domains enforces arc consistency (AC-3).

    Only binary constraints are propagated, and only domains of unassigned
    variables are pruned.
    """

    def recheck_domains(self, *, Q=None):
        """Return a copy with arc-inconsistent values removed.

        Q: variables whose neighbours should be checked first; defaults to
        every variable. It is copied into a private list, so any iterable is
        accepted and the caller's container is never mutated.
        Stops as soon as some domain becomes empty; the returned solution then
        reports has_empty_domain().
        """
        new = self.copy()
        queue = list(self.domains) if Q is None else list(Q)
        while queue:
            x = queue.pop()
            for y, constraint in new.csp.constraint_info.related_to(x):
                if not (constraint.is_binary() and y in new.unassigned):
                    continue
                if new.remove_values(x, y, constraint):
                    if len(new.domains[y]) == 0:
                        return new  # contradiction: y has no consistent value
                    # y's domain shrank, so y's own neighbours may lose support.
                    queue.insert(0, y)
        return new

    # mutates
    def remove_values(self, x, y, constraint):
        """Remove every value of y that has no supporting value of x.

        Mutates self.domains[y] in place; returns True if anything was removed.
        """
        removed = False
        for y_val in list(self.domains[y]):
            # Positional call in constraint.variables order, like Constraint.valid.
            supported = any(
                constraint.func(*({x: x_val, y: y_val}[v] for v in constraint.variables))
                for x_val in self.domains[x])
            if not supported:
                self.domains[y].remove(y_val)
                removed = True
        return removed
class AC3ForwardSolution(AC3Solution, ForwardCheckingSolution):
    """Forward checking on every assignment, then AC-3 propagation."""

    def recheck_domains(self, *, Q=None):
        # Q can just be the unassigned variables if we do forward checking.
        # Delegate to AC3Solution. The old version called itself with an
        # unsupported Q argument and read a nonexistent `uninitialised` attribute.
        if Q is None:
            Q = list(self.unassigned)
        return super().recheck_domains(Q=Q)
class ConstraintInfo:
    """Indexes constraints by the variables they mention."""

    def __init__(self, constraints, max_len = 2):
        # Sorted tuple of 1..max_len variables -> constraints involving all of them.
        self.vars_to_cons = multimap()
        # var_to_var[x][y] -> constraints mentioning both x and y (x != y).
        self.var_to_var = defaultdict(multimap)
        for constraint in constraints:
            ordered = sorted(constraint.variables)
            for size in range(1, max_len + 1):
                for key in combinations(ordered, r=size):
                    self.vars_to_cons[key].add(constraint)
            pairs = ((a, b) for a in constraint.variables
                     for b in constraint.variables if a != b)
            for a, b in pairs:
                self.var_to_var[a][b].add(constraint)

    def get(self, *variables):
        """Constraints involving all of ``variables`` (argument order ignored).

        The index is a defaultdict, so an unknown combination is inserted and
        returned as an empty set.
        """
        key = tuple(sorted(variables))
        return self.vars_to_cons[key]

    def related_to(self, x):
        """Iterate (y, constraint) for every constraint linking x to another variable y."""
        neighbours = self.var_to_var[x]
        return neighbours.flat_items()
# Strategies
class SelectVar:
    """Variable-selection strategies with signature ``f(csp, unassigned)``."""

    @staticmethod
    def first(csp, unassigned):
        """Return whichever unassigned variable iteration yields first (None if empty)."""
        return next(iter(unassigned), None)


class OrderDomain:
    """Value-ordering strategies with signature ``f(csp, domain)``."""

    @staticmethod
    def same(csp, domain):
        """Try values in the domain's own iteration order."""
        return domain
class CSP(DotOutput):
    """A constraint satisfaction problem with a backtracking solver.

    ``problem`` supplies ``variables`` (variable -> domain; sets in this
    file's example) and ``constraints``. Each constraint is either a
    Constraint object or a plain callable whose positional parameter names
    are the variables it constrains.
    """

    def __init__(self, problem):
        self.problem = problem
        self.variables = problem.variables.copy()
        self.constraints = []
        for lam in self.problem.constraints:
            if isinstance(lam, Constraint):
                self.constraints.append(lam)
            else:
                # inspect.getargspec was removed in Python 3.11; getfullargspec
                # exposes the same positional parameter names in .args.
                spec = inspect.getfullargspec(lam)
                self.constraints.append(Constraint(spec.args, lam))
        self.all_binary = all(c.is_binary() for c in self.constraints)
        self.constraint_info = ConstraintInfo(self.constraints)

    def valid(self, assignment):
        """True unless a constraint whose variables are all fixed is violated."""
        return all(c.valid(assignment) for c in self.constraints)

    def solve(self, cls = ForwardCheckingSolution,
              select_var = SelectVar.first,
              order_domain = OrderDomain.same):
        """Search for a complete assignment.

        cls: Solution subclass used for the search states; it determines the
            propagation (plain backtracking, forward checking, AC-3, ...).
        select_var: ``f(csp, unassigned)`` picking the next variable.
        order_domain: ``f(csp, domain)`` giving the order values are tried in.
        Returns the first complete Solution found, or None.
        """
        # Build the start state with ``cls`` (it used to be hard-coded to
        # Solution, silently ignoring the requested strategy).
        return self._solve(cls.from_csp(self), select_var, order_domain)

    def _solve(self, A, select_var, order_domain):
        """Depth-first backtracking from state ``A``; returns a Solution or None."""
        if A.is_complete():
            assert self.valid(A)
            return A
        A = A.recheck_domains()
        # Propagation (e.g. AC-3) may empty a domain; then no extension exists.
        if A is None or A.has_empty_domain():
            return None
        x = select_var(self, A.unassigned)
        for val in order_domain(self, A.domains[x]):
            new_A = A.assign(x, val)
            if new_A.can_continue():
                result = self._solve(new_A, select_var, order_domain)
                if result is not None:
                    return result
        return None

    # Dot output methods
    def get_dot_nodes(self):
        for v in self.variables.keys():
            yield Node(v)

    def get_dot_transitions(self):
        for c in self.constraints:
            if c.is_binary():
                yield Transition(c.variables[0], c.variables[1])
            else:
                print("WARNING: can't draw a non-binary constraint")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from csp import * | |
def question(i):
    """Print an underlined "Question i" heading preceded by two blank lines."""
    title = "Question {0}".format(i)
    underline = "-" * len(title)
    print("\n".join(["", "", title, underline, ""]))
class Problem:
    """The example problem: four variables and three binary constraints.

    Each constraint's parameter names are the variables it constrains; CSP
    reads them with inspect, so names and their order matter.
    """
    variables = {
        "S": {"to", "ch", "pu"},
        "M": {"me", "co", "pi"},
        "V": {"qu", "ta", "sr"},
        "G": {"sa", "st", "tu"},
    }
    constraints = [
        # Unless S is "pu", M must be "me" or "co".
        lambda M, S: S == "pu" or M in ("me", "co"),
        # S = "to" rules out V = "ta" and V = "sr".
        lambda S, V: S != "to" or V not in ("ta", "sr"),
        # V = "qu" rules out G = "st" and G = "tu".
        lambda V, G: V != "qu" or G not in ("st", "tu"),
    ]
c = CSP(Problem())
print("Domains:")
print(Solution.from_csp(c))

# Question 1: draw the constraint graph.
question(1)
c.save_dot("constraint_graph.dot")
print("see constraint_graph.dot")

# Question 2: domains after assigning S = "to" with forward checking.
question(2)
A = ForwardCheckingSolution.from_csp(c)
print(A.assign("S", "to"))

# Question 3: domains after assigning S = "to" followed by AC-3.
question(3)
A = AC3Solution.from_csp(c)
print(A.assign("S", "to").recheck_domains())

# Question 4: solve with each strategy.
question(4)
print("Simple:")
# Pass Solution explicitly: CSP.solve defaults to ForwardCheckingSolution.
print(c.solve(Solution))
print("\nForward checking:")
print(c.solve(ForwardCheckingSolution))
print("\nAC3:")
print(c.solve(AC3Solution))
print("\nBoth:")
print(c.solve(AC3ForwardSolution))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict
from itertools import chain, combinations, groupby
from typing import Iterable, List
# Etc
# ===
def dict_value_copy(d):
    """Return a new dict whose values are ``.copy()`` of the originals (one level)."""
    copied = {}
    for key, value in d.items():
        copied[key] = value.copy()
    return copied


def pick_first(it):
    """Return the first item produced by ``it``, or None when it is empty."""
    return next(iter(it), None)
# Multimap
# ========
class NotFlat(ValueError):
    """Raised by multimap.flatten when a key maps to zero or several values."""


class multimap(defaultdict):
    """A ``defaultdict(set)``: every key maps to a set of values."""

    def __init__(self, *a, **kw):
        super().__init__(set, *a, **kw)

    def copy(self):
        """Shallow copy (value sets are shared) that keeps the multimap type.

        The inherited defaultdict.copy passes default_factory back into the
        constructor, which this __init__ would misread, raising TypeError.
        """
        return type(self)(self)

    __copy__ = copy

    def __reduce__(self):
        # defaultdict.__reduce__ also passes default_factory as a constructor
        # argument; rebuild from items instead so pickle/deepcopy work.
        return (type(self), (), None, None, iter(self.items()))

    @classmethod
    def from_pairs(cls, l: list):
        """Build a multimap from an iterable of (key, value) pairs."""
        dct = cls()
        for k, v in l:
            dct[k].add(v)
        return dct

    def flat_items(self):
        """Yield one (key, value) pair per stored value."""
        for k, values in self.items():
            for v in values:
                yield k, v

    def flat_values(self):
        """Yield every stored value across all keys."""
        for values in self.values():
            yield from values

    def flat_len(self):
        """Total number of stored values across all keys."""
        return sum(len(values) for values in self.values())

    def flatten(self):
        """Return a plain dict mapping each key to its single value.

        Raises NotFlat if some key maps to zero or several values. The
        multimap itself is left unchanged (it used to be emptied via pop()).
        """
        d = {}
        for k, v in self.items():
            if len(v) != 1:
                raise NotFlat(v)
            d[k] = next(iter(v))
        return d
# Dot helpers
# ===========
class Node:
    """A vertex of a DOT graph, drawn as a circle."""

    _TEMPLATE = 'node [shape=circle] "{0}";'

    def __init__(self, name):
        self.name = name

    def to_dot(self):
        """Return the DOT statement declaring this node."""
        return self._TEMPLATE.format(self.name)
class Transition:
    """An undirected DOT edge ``frm -- to`` with an optional label."""

    def __init__(self, frm, to, label: str = None):
        self.frm = frm
        self.to = to
        self.label = label

    def to_dot(self):
        """Render the edge as a DOT statement, always terminated by ';'."""
        s = '"{}" -- "{}"'.format(self.frm, self.to)
        # Previously `a if cond else '' + ';'` bound the ';' to the else branch
        # only, so labelled edges were emitted without a semicolon.
        if self.label:
            s += ' [label="{}"]'.format(self.label)
        return s + ";"

    def key(self):
        """Grouping key (frm, to) used to merge parallel edges."""
        return (self.frm, self.to)
def dot_str(obj):
    """Return ``obj`` as DOT label text; the empty string is shown as "ε"."""
    # NOTE: an earlier version also special-cased a `String` type that is not
    # defined anywhere in this module, so every call raised NameError.
    if isinstance(obj, str) and len(obj) == 0:
        return "ε"
    return str(obj)
# Template for a whole DOT file; "{{" / "}}" are literal braces for str.format.
dot_fmt = """
graph csp {{
rankdir=LR
{nodes}
{transitions}
}}
"""


class Dot:
    """An undirected DOT graph built from objects exposing ``to_dot()``."""

    def __init__(self, nodes = None, transitions = None):
        # None defaults: a shared mutable [] default would be aliased by every
        # Dot constructed without arguments.
        self.nodes = [] if nodes is None else nodes
        self.transitions = [] if transitions is None else transitions

    def to_dot(self):
        """Render the graph, merging parallel edges into one multi-line label."""
        nodes = "\n".join(" " + n.to_dot() for n in self.nodes)
        # groupby only merges adjacent items, so sort by the same key first.
        ordered = sorted(self.transitions, key=lambda t: t.key())
        merged = [Transition(k[0], k[1], "\\n".join(t.label for t in g if t.label is not None))
                  for k, g in groupby(ordered, key=lambda t: t.key())]
        transitions = "\n".join(" " + t.to_dot() for t in merged)
        return dot_fmt.format(nodes=nodes, transitions=transitions)
class DotOutput:
    """Mixin: override get_dot_nodes / get_dot_transitions to get save_dot."""

    def save_dot(self, fname):
        """Write this object's graph to ``fname`` in DOT format, UTF-8 encoded."""
        d = Dot(list(self.get_dot_nodes()), list(self.get_dot_transitions()))
        # Explicit encoding: the platform default may not encode non-ASCII names.
        with open(fname, "w", encoding="utf-8") as f:
            f.write(d.to_dot())

    # Any iterable works: save_dot wraps the results in list().
    def get_dot_nodes(self) -> "Iterable[Node]":
        return [Node("default")]

    def get_dot_transitions(self) -> "Iterable[Transition]":
        return []
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment