Skip to content

Instantly share code, notes, and snippets.

@bfroehle
Created May 28, 2013 00:22

Revisions

  1. bfroehle created this gist May 28, 2013.
    425 changes: 425 additions & 0 deletions kenken.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,425 @@
    """
    A SAT-based KenKen (http://www.kenken.com/) solver.
    The implementation of this solver is based on the ideas contained in
    the paper:
    "A SAT-based Sudoku solver" by Tjark Weber
    https://www.lri.fr/~conchon/mpri/weber.pdf
    and a Python implementation of a SAT-based Sudoku solver which is found
    in the pycosat examples:
    https://pypi.python.org/pypi/pycosat
    This module requires Python 3.3 or later.
    """

    import itertools
    import pycosat

    ########################################################################
    # Exceptions
    #

    class InvalidPuzzle(Exception):
    """The puzzle specification is invalid."""
    pass

    class UnsatisfyablePuzzle(Exception):
    """The puzzle is unsatisfyable."""
    pass


    ########################################################################
    # Partitioners
    #
    # A partitioner yields all possible partitions of `result` into `nparts`
    # according to a specified operation, where each value is in the range
    # `1..dim`.
    #
    # See the docstring of `partition` for some examples.
    #

    def _partition_add(nparts, result, dim):
    if nparts == 0:
    raise ValueError("Expected nparts >= 1")
    elif nparts == 1:
    if 1 <= result <= dim:
    yield (result, )
    else:
    raise StopIteration
    else:
    for i1 in range(1, dim+1):
    for i2 in _partition_add(nparts-1, result-i1, dim):
    yield (i1,) + i2

    def _partition_mul(nparts, result, dim):
    if nparts == 0:
    raise ValueError("Expected nparts >= 1")
    elif nparts == 1:
    if 1 <= result <= dim:
    yield (result, )
    else:
    raise StopIteration
    else:
    for i1 in range(1, dim+1):
    if result % i1 != 0:
    continue
    for i2 in _partition_mul(nparts-1, int(result//i1), dim):
    yield (i1,) + i2

    def _partition_sub(nparts, result, dim):
    if nparts != 2:
    raise ValueError("Expected nparts = 2.")
    for i1 in range(1, dim+1):
    for i2 in (i1-result, i1+result):
    if 1 <= i2 <= dim:
    yield(i1, i2)

    def _partition_div(nparts, result, dim):
    if nparts != 2:
    raise ValueError("Expected nparts = 2.")
    for i1 in range(1, dim+1):
    if i1 % result == 0:
    i2 = int(i1 // result)
    if 1 <= i2 <= dim:
    yield (i1, i2)
    i2 = result * i1
    if 1 <= i2 <= dim:
    yield (i1, i2)

    def _partition_eq(nparts, result, dim):
    if nparts != 1:
    raise ValueError("Expected nparts = 1.")
    if 1 <= result <= dim:
    yield (result,)

    _partitioners = {
    '+': _partition_add,
    '*': _partition_mul,
    '-': _partition_sub,
    '/': _partition_div,
    '!': _partition_eq,
    '=': _partition_eq,
    }

    def partition(op, nparts, result, dim):
    """Partition `result` into `nparts` each in the range `1..dim`
    which can be obtained using the operation `op`.
    For example, all pairs of numbers in the range 1..6 whose ratio is 3:
    >>> sorted(partition('/', 2, 3, 6))
    [(1, 3), (2, 6), (3, 1), (6, 2)]
    All triples of numbers in the range 1..6 whose product is 4:
    >>> sorted(partition('*', 3, 4, 6))
    [(1, 1, 4), (1, 2, 2), (1, 4, 1), (2, 1, 2), (2, 2, 1), (4, 1, 1)]
    All single numbers which are exactly equal to 4:
    >>> sorted(partition('=', 1, 4, 6))
    [(4,)]
    """
    yield from _partitioners[op](nparts, result, dim)

    ########################################################################
    # A cage is a sub-region of the puzzle with a specified arithmetic
    # operation and resulting value.
    #

    class Cage:
    """A set of cells which achives a value by an arithmetic operation."""

    def __init__(self, op, value, cells):
    self.op = op
    self.value = value
    self.cells = cells

    def __str__(self):
    return "%s(%r, %s, %s)" % (
    self.__class__.__name__, self.op, self.value, self.cells)

    def dnf_clauses(self, dim, variable):
    """Yield clauses in disjunctive normal form which satisfy the
    cage constraints.
    Essentially this returns clauses which correspond to
    (digit(cell[0]) == sol1[0] AND digit(cell[1]) == sol1[1] AND ...)
    OR (digit(cell[0]) == sol2[0] AND digit(cell[1]) == sol2[1] AND ...)
    OR (digit(cell[0]) == sol3[0] AND digit(cell[1]) == sol3[1] AND ...)
    OR (digit(cell[0]) == sol4[0] AND digit(cell[1]) == sol4[1] AND ...)
    OR ...
    where sol1, sol2, sol3, sol4, ... are all possible local solutions
    to the cage constraints, ignoring any global constraints (like
    no duplicates in rows or columns).
    Parameters
    ----------
    dim : int
    Puzzle dimension
    variable : callable, as variable(i, j, d)
    Returns the variable number corresponding to `d` in cell `(i,j)`.
    """
    for vals in partition(self.op, len(self.cells), self.value, dim):
    yield tuple(variable(cell[0], cell[1], val)
    for cell, val in zip(self.cells, vals))

    ########################################################################
    # Main Entry Point
    #

    class KenKenPuzzle:
    """A KenKen puzzle."""

    def __init__(self, size, cages=None):
    self.size = size
    self._cages = cages or []

    @classmethod
    def from_text(cls, text):
    """Instantiate a puzzle from a text description."""
    lines = iter(text.splitlines())
    for line in lines:
    line = line.strip()
    if not line:
    continue
    op, value = line.split()
    if op != '#':
    raise InvalidPuzzleSpecification
    size = int(value)
    break
    else:
    raise InvalidPuzzleSpecification

    puzzle = cls(size)
    for line in lines:
    line = line.strip()
    if not line:
    continue
    puzzle.add_text_cage(line)

    puzzle.assert_valid()

    return puzzle

    def assert_valid(self):
    """Checks that each cell belongs to exactly one cage."""
    visited = set()
    for cage in self._cages:
    frontier = set(cage.cells)
    if not visited.isdisjoint(frontier):
    raise InvalidPuzzle("Duplicate cells: %s"
    % (visited.intersection(frontier),))
    visited.update(frontier)

    expected = set(itertools.product(range(1, self.size+1),
    range(1, self.size+1)))
    missing = expected.difference(visited)
    unknown = visited.difference(expected)
    if missing or unknown:
    messages = []
    if missing:
    messages.append("Missing cells: %s" % (missing,))
    if unknown:
    messages.append("Unknown cells: %s" % (unknown,))
    raise InvalidPuzzle(" ".join(messages))

    def add_cage(self, op, result, cells):
    """Add a cage."""
    self._cages.append(Cage(op, result, cells))

    def _cell_as_tuple(self, cell):
    """Convert cell notation from text form to tuple form.
    For example, 'B7' -> (2,7).
    """
    row, col = cell[:1], cell[1:]
    row = ord(row)-ord('A')+1
    col = int(col)
    return row, col

    def add_text_cage(self, text):
    """Add a textual representation of a cage."""
    op, result, *cells = text.split()
    result = int(result)

    cells = tuple(self._cell_as_tuple(cell) for cell in cells)
    self.add_cage(op, result, cells)

    def variable(self, i, j, d):
    """Return the number of the variable which corresponds to cell (i, j)
    containing digit d.
    """
    n = self.size
    assert 1 <= i <= n
    assert 1 <= j <= n
    assert 1 <= d <= n
    return n*n * (i - 1) + n * (j - 1) + d

    def _unvariable(self, v):
    """Return the cell and digit corresponding to v."""
    n = self.size
    i = v // (n*n) + 1
    v = v - n*n*(i-1)
    j = v // n + 1
    v = v - n*(j-1)
    d = v
    return ((i,j), d)

    def clauses(self):
    """Yield all clauses for the puzzle."""
    n = self.size
    v = self.variable

    # For all cells, ensure that the each cell:
    for i in range(1, n+1):
    for j in range(1, n+1):
    # Denotes (at least) one of the n digits.
    yield [v(i,j,d) for d in range(1, n+1)]
    # Does not denote two different digits at once.
    for d in range(1, n+1):
    for dp in range(d+1, n+1):
    yield [-v(i,j,d), -v(i,j,dp)]

    def valid(cells):
    # Ensure that the cells contain distinct values.
    for i, xi in enumerate(cells):
    for j, xj in enumerate(cells):
    if i < j:
    for d in range(1, n+1):
    yield [-v(xi[0],xi[1],d), -v(xj[0],xj[1],d)]

    # Ensure rows and columns have distinct values.
    for i in range(1, n+1):
    yield from valid([(i, j) for j in range(1, n+1)])
    yield from valid([(j, i) for j in range(1, n+1)])

    # The cages return their clauses in disjunctive normal form,
    # but our SAT solver needs the clauses in conjunctive normal
    # form. To convert from DNF to CNF without exponential growth
    # in the number of clauses we introduce additional variables.
    auxiliary_vars = itertools.count(n**3+1)

    # For each cage:
    for cage in self._cages:
    dnf = list(cage.dnf_clauses(n, v))
    if not dnf:
    raise ValueError("Invalid cage: %s" % (cage,))

    # yield the clauses in conjunctive normal form,
    # adding auxiliary variables as necessary
    yield from self._dnf_to_cnf(dnf, auxiliary_vars)

    @staticmethod
    def _dnf_to_cnf(dnf, auxiliary_vars):
    """Convert dnf to cnf.
    Parameters
    ----------
    dnf : list of lists of int
    Clauses in disjunctive normal form
    auxiliary_variables: iterator, returning int
    Auxiliary variables, of which len(dnf) will be taken.
    """
    # Take the first `len(dnf)` entries from auxiliary_vars.
    auxs = list(itertools.islice(auxiliary_vars, len(dnf)))
    yield auxs
    for v, clause in zip(auxs, dnf):
    for c in clause:
    yield [-v, c]
    yield [v] + [-c for c in clause]

    def solve(self):
    """Return a solution to the KenKen puzzle."""
    sol = pycosat.solve(self.clauses())
    if sol == 'UNSAT':
    raise InvalidPuzzle
    return self._sol_to_grid(set(sol))

    def itersolve(self):
    """Return an iterator to all solutions of the KenKen puzzle."""
    for sol in pycosat.itersolve(self.clauses()):
    yield self._sol_to_grid(set(sol))

    def _sol_to_grid(self, sol):
    """Convert a solution to a grid format."""
    def read_cell(i, j):
    # return the digit of cell i, j according to the solution
    for d in range(1, self.size+1):
    if self.variable(i, j, d) in sol:
    return d

    grid = [[None]*self.size for _ in range(self.size)]
    for i in range(1, self.size+1):
    for j in range(1, self.size+1):
    grid[i-1][j-1] = read_cell(i, j)

    return grid


    if __name__ == "__main__":

    # Apologies if there is a standard format for representing
    # KenKen puzzles which I am not aware of.

    # The puzzle format here is the following:
    #
    # # <dim>
    # <cage1>
    # ...
    # <cageN>
    #
    # where dim is the dimension of the puzzle (i.e., 6 if
    # the grid is 6-by-6). Each cage is of the form
    #
    # <op> <value> <cell1> ... <cellN>
    #
    # where op is one of +, -, *, /, and = (or !), value
    # is the resulting number, and the cells are specified
    # as <row><column> where the rows are labeled A, B, C,
    # etc and the columns are labeled 1, 2, 3, etc.
    #

    print("Solving the 6x6 KenKen Puzzle from Wikipedia.")
    print("See http://en.wikipedia.org/wiki/File:KenKenProblem.svg")

    puzzle = KenKenPuzzle.from_text(
    """
    # 6
    + 11 A1 B1
    / 2 A2 A3
    * 20 A4 B4
    * 6 A5 A6 B6 C6
    - 3 B2 B3
    / 3 B5 C5
    * 240 C1 C2 D1 D2
    * 6 C3 C4
    * 6 D3 E3
    + 7 D4 E4 E5
    * 30 D5 D6
    * 6 E1 E2
    + 9 E6 F6
    + 8 F1 F2 F3
    / 2 F4 F5
    """)
    grid = puzzle.solve()

    # Compare to the known solution at
    # http://en.wikipedia.org/wiki/File:KenKenSolution.svg

    assert grid == [[5,6,3,4,1,2],
    [6,1,4,5,2,3],
    [4,5,2,3,6,1],
    [3,4,1,2,5,6],
    [2,3,6,1,4,5],
    [1,2,5,6,3,4]]

    for row in grid:
    print(row)
    print()