Last active
November 23, 2021 05:45
-
-
Save Kienyew/01b469d5da26e3656d6c49dfebe3aaaa to your computer and use it in GitHub Desktop.
DPLL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional | |
from pprint import pprint | |
# A Literal is P or ¬P | |
class Literal:
    """A propositional literal: a symbol ``P`` or its negation ``¬P``.

    ``symbol`` holds the variable name and ``negate`` is True when the
    literal stands for the negation of that variable.
    """

    @classmethod
    def from_string(cls, s: str) -> 'Literal':
        """Parse ``'P'`` or ``'-P'``; a leading ``-`` marks a negated literal."""
        # Construct through ``cls`` (not the hard-coded class name) so that
        # subclasses calling this alternate constructor get their own type.
        if s.startswith('-'):
            return cls(s[1:], True)
        return cls(s, False)

    def __init__(self, symbol: str, negate: bool):
        self.symbol = symbol
        self.negate = negate

    def __repr__(self) -> str:
        return '¬' + self.symbol if self.negate else self.symbol

    def value(self, truth_value: bool) -> bool:
        """Return the literal's truth value given the truth value of its symbol."""
        return (not truth_value) if self.negate else truth_value
# A Clause is a disjunction of literals, e.g. P ∨ ¬Q
class Clause:
    """A disjunction of literals, printed as ``L1∨L2∨...``."""

    def __init__(self, literals: list[Literal]):
        self.literals = literals

    def evaluable(self, assignments) -> bool:
        """Return True when every symbol in the clause has an assigned value."""
        for member in self.literals:
            if member.symbol not in assignments:
                return False
        return True

    def evaluate(self, assignments) -> bool:
        """Return True when at least one literal is true under ``assignments``.

        Looks up each literal's symbol in ``assignments``, so a missing
        symbol raises ``KeyError`` unless an earlier literal was already true.
        """
        for member in self.literals:
            if member.value(assignments[member.symbol]) is True:
                return True
        return False

    def __repr__(self) -> str:
        return '∨'.join(map(str, self.literals))
# `cnf` is a formula in Conjunctive Normal Form, given as a list of clauses
# `records` is used to record the steps of the algorithm
def dpll(cnf: list[Clause], records: list) -> Optional[dict[str, bool]]:
    """Search for a satisfying assignment of ``cnf`` with the DPLL procedure.

    Args:
        cnf: The clauses of the formula; all of them must be satisfied.
        records: A list that receives a trace of the search, appended in
            order. Entries are ``('deduce', symbol, truth)``,
            ``('guess', symbol, truth)``, ``('undo', [symbols])`` and
            ``('complete', assignments)``.

    Returns:
        A dict mapping every symbol to its truth value when the formula is
        satisfiable (an empty dict for an empty formula), otherwise ``None``.
    """
    # Collect symbols in order of first appearance. A plain set would make
    # the guessing order (and therefore ``records``) vary between runs
    # because of string hash randomisation.
    symbols = list(dict.fromkeys(
        literal.symbol for clause in cnf for literal in clause.literals
    ))

    # Try to deduce the truth value of a symbol from a unit clause
    def deduce(assignments) -> Optional[tuple[str, bool]]:
        for clause in cnf:
            unassigned = [
                literal for literal in clause.literals
                if literal.symbol not in assignments
            ]
            if len(unassigned) != 1:
                continue
            unit = unassigned[0]
            # The clause forces `unit` only if every other literal is false.
            others_all_false = all(
                not literal.value(assignments[literal.symbol])
                for literal in clause.literals
                if literal.symbol in assignments
            )
            if others_all_false:
                return (unit.symbol, not unit.negate)
        return None

    def backtrack(assignments: dict[str, bool]):
        # Deduction (unit propagation)
        deductions = []
        while (deduction := deduce(assignments)) is not None:
            symbol, truth = deduction
            assignments[symbol] = truth
            deductions.append(symbol)
            records.append(('deduce', symbol, truth))

        # Terminating condition
        all_evaluable = True
        for clause in cnf:
            if not clause.evaluable(assignments):
                all_evaluable = False
            elif not clause.evaluate(assignments):
                # Conflict: roll back everything deduced at this level.
                for symbol in deductions:
                    assignments.pop(symbol)
                records.append(('undo', deductions))
                return None
        if all_evaluable:
            # Every clause is fully assigned and none is false.
            records.append(('complete', assignments))
            return assignments

        # Guessing: branch on the first unassigned symbol only. If both truth
        # values fail, the formula is unsatisfiable under the current
        # assignments, so branching on further symbols cannot succeed.
        branch = next((s for s in symbols if s not in assignments), None)
        if branch is not None:
            for guess in (True, False):
                assignments[branch] = guess
                records.append(('guess', branch, guess))
                result = backtrack(assignments)
                if result is not None:
                    return result
                # undo
                assignments.pop(branch)
                records.append(('undo', [branch]))

        # undo the deductions made at this level
        for symbol in deductions:
            assignments.pop(symbol)
            records.append(('undo', [symbol]))
        return None

    return backtrack({})
# Building cnf: each tuple is one clause, a leading '-' negates the symbol
cnf_s = [
    ('a', 'c', 'd'),
    ('a', '-c', '-d'),
    ('-a', '-c', '-d'),
    ('-a', 'b', '-f'),
    ('a', 'b', 'f'),
    ('-a', '-b', '-g'),
    ('-a', '-c', 'd'),
    ('-b', 'c', '-e'),
    ('b', 'd', 'e'),
    ('b', 'd', '-e'),
    ('b', 'c', '-d'),
    ('-b', 'c', 'e'),
]
cnf = []
for clause_s in cnf_s:
    clause = Clause([Literal.from_string(literal_s) for literal_s in clause_s])
    cnf.append(clause)

# Show the whole formula as (clause)∧(clause)∧...
print('∧'.join(f'({clause})' for clause in cnf))

# Run the solver and dump the recorded steps
records = []
dpll(cnf, records)
pprint(records, width=50)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment