Created
May 28, 2013 00:22
Revisions
-
bfroehle created this gist
May 28, 2013 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,425 @@ """ A SAT-based KenKen (http://www.kenken.com/) solver. The implementation of this solver is based on the ideas contained in the paper: "A SAT-based Sudoku solver" by Tjark Weber https://www.lri.fr/~conchon/mpri/weber.pdf and a Python implementation of a SAT-based Sudoku solver which is found in the pycosat examples: https://pypi.python.org/pypi/pycosat This module requires Python 3.3 or later. """ import itertools import pycosat ######################################################################## # Exceptions # class InvalidPuzzle(Exception): """The puzzle specification is invalid.""" pass class UnsatisfyablePuzzle(Exception): """The puzzle is unsatisfyable.""" pass ######################################################################## # Partitioners # # A partitioner yields all possible partitions of `result` into `nparts` # according to a specified operation, where each value is in the range # `1..dim`. # # See the docstring of `partition` for some examples. # def _partition_add(nparts, result, dim): if nparts == 0: raise ValueError("Expected nparts >= 1") elif nparts == 1: if 1 <= result <= dim: yield (result, ) else: raise StopIteration else: for i1 in range(1, dim+1): for i2 in _partition_add(nparts-1, result-i1, dim): yield (i1,) + i2 def _partition_mul(nparts, result, dim): if nparts == 0: raise ValueError("Expected nparts >= 1") elif nparts == 1: if 1 <= result <= dim: yield (result, ) else: raise StopIteration else: for i1 in range(1, dim+1): if result % i1 != 0: continue for i2 in _partition_mul(nparts-1, int(result//i1), dim): yield (i1,) + i2 def _partition_sub(nparts, result, dim): if nparts != 2: raise ValueError("Expected nparts = 2.") for i1 in range(1, dim+1): for i2 in (i1-result, i1+result): if 1 <= i2 <= dim: yield(i1, i2) def _partition_div(nparts, result, dim): if nparts != 2: raise ValueError("Expected nparts = 2.") for i1 in range(1, dim+1): if i1 % result == 0: i2 = int(i1 // result) if 1 <= i2 <= dim: yield (i1, i2) i2 = result * i1 if 1 <= i2 <= dim: yield (i1, i2) def _partition_eq(nparts, result, dim): if nparts != 1: raise ValueError("Expected nparts = 1.") if 1 <= result <= dim: yield (result,) _partitioners = { '+': _partition_add, '*': _partition_mul, '-': _partition_sub, '/': _partition_div, '!': _partition_eq, '=': _partition_eq, } def partition(op, nparts, result, dim): """Partition `result` into `nparts` each in the range `1..dim` which can be obtained using the operation `op`. For example, all pairs of numbers in the range 1..6 whose ratio is 3: >>> sorted(partition('/', 2, 3, 6)) [(1, 3), (2, 6), (3, 1), (6, 2)] All triples of numbers in the range 1..6 whose product is 4: >>> sorted(partition('*', 3, 4, 6)) [(1, 1, 4), (1, 2, 2), (1, 4, 1), (2, 1, 2), (2, 2, 1), (4, 1, 1)] All single numbers which are exactly equal to 4: >>> sorted(partition('=', 1, 4, 6)) [(4,)] """ yield from _partitioners[op](nparts, result, dim) ######################################################################## # A cage is a sub-region of the puzzle with a specified arithmetic # operation and resulting value. # class Cage: """A set of cells which achives a value by an arithmetic operation.""" def __init__(self, op, value, cells): self.op = op self.value = value self.cells = cells def __str__(self): return "%s(%r, %s, %s)" % ( self.__class__.__name__, self.op, self.value, self.cells) def dnf_clauses(self, dim, variable): """Yield clauses in disjunctive normal form which satisfy the cage constraints. Essentially this returns clauses which correspond to (digit(cell[0]) == sol1[0] AND digit(cell[1]) == sol1[1] AND ...) OR (digit(cell[0]) == sol2[0] AND digit(cell[1]) == sol2[1] AND ...) OR (digit(cell[0]) == sol3[0] AND digit(cell[1]) == sol3[1] AND ...) OR (digit(cell[0]) == sol4[0] AND digit(cell[1]) == sol4[1] AND ...) OR ... where sol1, sol2, sol3, sol4, ... are all possible local solutions to the cage constraints, ignoring any global constraints (like no duplicates in rows or columns). Parameters ---------- dim : int Puzzle dimension variable : callable, as variable(i, j, d) Returns the variable number corresponding to `d` in cell `(i,j)`. """ for vals in partition(self.op, len(self.cells), self.value, dim): yield tuple(variable(cell[0], cell[1], val) for cell, val in zip(self.cells, vals)) ######################################################################## # Main Entry Point # class KenKenPuzzle: """A KenKen puzzle.""" def __init__(self, size, cages=None): self.size = size self._cages = cages or [] @classmethod def from_text(cls, text): """Instantiate a puzzle from a text description.""" lines = iter(text.splitlines()) for line in lines: line = line.strip() if not line: continue op, value = line.split() if op != '#': raise InvalidPuzzleSpecification size = int(value) break else: raise InvalidPuzzleSpecification puzzle = cls(size) for line in lines: line = line.strip() if not line: continue puzzle.add_text_cage(line) puzzle.assert_valid() return puzzle def assert_valid(self): """Checks that each cell belongs to exactly one cage.""" visited = set() for cage in self._cages: frontier = set(cage.cells) if not visited.isdisjoint(frontier): raise InvalidPuzzle("Duplicate cells: %s" % (visited.intersection(frontier),)) visited.update(frontier) expected = set(itertools.product(range(1, self.size+1), range(1, self.size+1))) missing = expected.difference(visited) unknown = visited.difference(expected) if missing or unknown: messages = [] if missing: messages.append("Missing cells: %s" % (missing,)) if unknown: messages.append("Unknown cells: %s" % (unknown,)) raise InvalidPuzzle(" ".join(messages)) def add_cage(self, op, result, cells): """Add a cage.""" self._cages.append(Cage(op, result, cells)) def _cell_as_tuple(self, cell): """Convert cell notation from text form to tuple form. For example, 'B7' -> (2,7). """ row, col = cell[:1], cell[1:] row = ord(row)-ord('A')+1 col = int(col) return row, col def add_text_cage(self, text): """Add a textual representation of a cage.""" op, result, *cells = text.split() result = int(result) cells = tuple(self._cell_as_tuple(cell) for cell in cells) self.add_cage(op, result, cells) def variable(self, i, j, d): """Return the number of the variable which corresponds to cell (i, j) containing digit d. """ n = self.size assert 1 <= i <= n assert 1 <= j <= n assert 1 <= d <= n return n*n * (i - 1) + n * (j - 1) + d def _unvariable(self, v): """Return the cell and digit corresponding to v.""" n = self.size i = v // (n*n) + 1 v = v - n*n*(i-1) j = v // n + 1 v = v - n*(j-1) d = v return ((i,j), d) def clauses(self): """Yield all clauses for the puzzle.""" n = self.size v = self.variable # For all cells, ensure that the each cell: for i in range(1, n+1): for j in range(1, n+1): # Denotes (at least) one of the n digits. yield [v(i,j,d) for d in range(1, n+1)] # Does not denote two different digits at once. for d in range(1, n+1): for dp in range(d+1, n+1): yield [-v(i,j,d), -v(i,j,dp)] def valid(cells): # Ensure that the cells contain distinct values. for i, xi in enumerate(cells): for j, xj in enumerate(cells): if i < j: for d in range(1, n+1): yield [-v(xi[0],xi[1],d), -v(xj[0],xj[1],d)] # Ensure rows and columns have distinct values. for i in range(1, n+1): yield from valid([(i, j) for j in range(1, n+1)]) yield from valid([(j, i) for j in range(1, n+1)]) # The cages return their clauses in disjunctive normal form, # but our SAT solver needs the clauses in conjunctive normal # form. To convert from DNF to CNF without exponential growth # in the number of clauses we introduce additional variables. auxiliary_vars = itertools.count(n**3+1) # For each cage: for cage in self._cages: dnf = list(cage.dnf_clauses(n, v)) if not dnf: raise ValueError("Invalid cage: %s" % (cage,)) # yield the clauses in conjunctive normal form, # adding auxiliary variables as necessary yield from self._dnf_to_cnf(dnf, auxiliary_vars) @staticmethod def _dnf_to_cnf(dnf, auxiliary_vars): """Convert dnf to cnf. Parameters ---------- dnf : list of lists of int Clauses in disjunctive normal form auxiliary_variables: iterator, returning int Auxiliary variables, of which len(dnf) will be taken. """ # Take the first `len(dnf)` entries from auxiliary_vars. auxs = list(itertools.islice(auxiliary_vars, len(dnf))) yield auxs for v, clause in zip(auxs, dnf): for c in clause: yield [-v, c] yield [v] + [-c for c in clause] def solve(self): """Return a solution to the KenKen puzzle.""" sol = pycosat.solve(self.clauses()) if sol == 'UNSAT': raise InvalidPuzzle return self._sol_to_grid(set(sol)) def itersolve(self): """Return an iterator to all solutions of the KenKen puzzle.""" for sol in pycosat.itersolve(self.clauses()): yield self._sol_to_grid(set(sol)) def _sol_to_grid(self, sol): """Convert a solution to a grid format.""" def read_cell(i, j): # return the digit of cell i, j according to the solution for d in range(1, self.size+1): if self.variable(i, j, d) in sol: return d grid = [[None]*self.size for _ in range(self.size)] for i in range(1, self.size+1): for j in range(1, self.size+1): grid[i-1][j-1] = read_cell(i, j) return grid if __name__ == "__main__": # Apologies if there is a standard format for representing # KenKen puzzles which I am not aware of. # The puzzle format here is the following: # # # <dim> # <cage1> # ... # <cageN> # # where dim is the dimension of the puzzle (i.e., 6 if # the grid is 6-by-6). Each cage is of the form # # <op> <value> <cell1> ... <cellN> # # where op is one of +, -, *, /, and = (or !), value # is the resulting number, and the cells are specified # as <row><column> where the rows are labeled A, B, C, # etc and the columns are labeled 1, 2, 3, etc. # print("Solving the 6x6 KenKen Puzzle from Wikipedia.") print("See http://en.wikipedia.org/wiki/File:KenKenProblem.svg") puzzle = KenKenPuzzle.from_text( """ # 6 + 11 A1 B1 / 2 A2 A3 * 20 A4 B4 * 6 A5 A6 B6 C6 - 3 B2 B3 / 3 B5 C5 * 240 C1 C2 D1 D2 * 6 C3 C4 * 6 D3 E3 + 7 D4 E4 E5 * 30 D5 D6 * 6 E1 E2 + 9 E6 F6 + 8 F1 F2 F3 / 2 F4 F5 """) grid = puzzle.solve() # Compare to the known solution at # http://en.wikipedia.org/wiki/File:KenKenSolution.svg assert grid == [[5,6,3,4,1,2], [6,1,4,5,2,3], [4,5,2,3,6,1], [3,4,1,2,5,6], [2,3,6,1,4,5], [1,2,5,6,3,4]] for row in grid: print(row) print()