Skip to content

Instantly share code, notes, and snippets.

@rrika
Created January 19, 2020 16:35
Show Gist options
  • Save rrika/e676f4bd4ebec25a1017dca444c9f24c to your computer and use it in GitHub Desktop.
Save rrika/e676f4bd4ebec25a1017dca444c9f24c to your computer and use it in GitHub Desktop.
// a sat solver for instances with <= 32 variables
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
// Binary-literal spellings for 5-bit masks: bNNNNN equals the binary number
// NNNNN (most significant bit written first). Used to write clause masks in main.
enum {
    b00000 = 0,  b00001 = 1,  b00010 = 2,  b00011 = 3,
    b00100 = 4,  b00101 = 5,  b00110 = 6,  b00111 = 7,
    b01000 = 8,  b01001 = 9,  b01010 = 10, b01011 = 11,
    b01100 = 12, b01101 = 13, b01110 = 14, b01111 = 15,
    b10000 = 16, b10001 = 17, b10010 = 18, b10011 = 19,
    b10100 = 20, b10101 = 21, b10110 = 22, b10111 = 23,
    b11000 = 24, b11001 = 25, b11010 = 26, b11011 = 27,
    b11100 = 28, b11101 = 29, b11110 = 30, b11111 = 31
};
// Verbosity of the solver's output (checks are `loglevel >= N`, so levels are cumulative):
//   0    - silent
//   1    - solutions only
//   2    - log branching and propagation results
//   3    - additionally log early exits from branching
//   4    - log new clauses
//   5..7 - additionally log the derivation of the new clauses
//   8    - log propagation steps
//   9    - log propagation steps in detail
static const int loglevel = 1;

// A set of at most 32 variables, one bit per variable.
typedef uint32_t Z;

// One record serves both as a clause (disjunction of literals) and as a cube
// (conjunction of literals, e.g. a partial assignment):
//   mask - variables that occur
//   neg  - which of those occur negated
//   id   - clause identifier; in a cube returned by unit_propagation it is the
//          id of the falsified clause, or -1 when there is none
typedef struct C {
    Z mask;
    Z neg;
    int id;
} C, Cube, Clause;

// One entry of the assignment log: the assigned variable (as a single-bit mask)
// and the reason it was assigned.
struct Decision {
    Z var;
    int reason; // -1 for a branch decision, otherwise the id of the propagating clause
};

// Result of unit_propagation: the extended assignment, the new end of the
// active clause range, and the advanced log pointer.
struct UnitPropReturn {
    Cube assignment;
    Clause *pastFormula;
    struct Decision *assignLog;
};
// Returns 1 if exactly one bit of z is set, else 0.
// Clearing the lowest set bit (z & (z - 1)) must leave nothing behind.
static int one_bit_set(Z z) {
    if (z == 0)
        return 0;
    return (z & (z - 1)) == 0;
}
// Returns the lowest clear bit of z as a single-bit mask, or 0 if every bit is set.
// Adding 1 carries through the trailing one-bits and lands on the lowest zero;
// masking with ~z keeps only that bit.
// The result is returned as Z (not int): for z == 0x7fffffff it is 0x80000000,
// which does not fit in a 32-bit int.
static Z next_zero_bit(Z z) {
    return (z + 1) & ~z;
}
// Resolves clauses a and b on their unique clashing variable.
// A clashing variable occurs in both clauses with opposite polarity (its neg bit
// is set in exactly one of them). Resolution needs exactly one clash: with none
// there is nothing to resolve on, and with two or more the resolvent is a tautology.
// The resolvent holds every literal of a and b except the clashing pair; variables
// shared with the same polarity appear once.
// Returns (C){0, 0} when the clauses cannot be resolved. An empty resolvent
// (a = {x}, b = {-x}) is also mask 0, so it cannot be told apart from failure.
// The result's id is always 0.
static C resolve(C a, C b) {
    Z shared = a.mask & b.mask;
    Z pivot = shared & (a.neg ^ b.neg);
    if (!one_bit_set(pivot))
        return (C){ 0, 0 };
    Z cmask = (a.mask | b.mask) & ~pivot;
    // Only trust each clause's neg bits within its own mask.
    Z cneg = (a.neg & a.mask) | (b.neg & b.mask);
    return (C){ cmask, cneg & cmask };
}
// Exchanges the clauses pointed to by a and b (a == b is allowed).
static void swapc(C *a, C *b) {
    const C saved = *b;
    *b = *a;
    *a = saved;
}
// Unit propagation over the active clause range [formula, pastFormula).
//
// Each active clause is checked against `assignment`:
//  - if an assigned literal agrees with the assignment, the clause is satisfied
//    and is swapped out of the active range (to its back end);
//  - if all its literals are assigned and none agree, the clause is falsified:
//    propagation stops and the returned cube carries that clause's id;
//  - if exactly one literal is unassigned, that literal is forced, the forced
//    variable is appended to assignLog with the clause id as reason, and the
//    clause is swapped out;
//  - otherwise the clause is left in place.
// Passes repeat until one forces nothing or no active clause remains.
//
// Returns the extended assignment (id == -1 when no clause was falsified,
// otherwise the falsified clause's id), the new end of the active range, and
// the advanced log pointer.
static struct UnitPropReturn unit_propagation(
    Cube assignment,
    Clause *formula,
    Clause *pastFormula,
    struct Decision *assignLog)
{
    int progress = 1;
    while (formula != pastFormula && progress) {
        progress = 0;
        Clause *cursor = formula;
        while (cursor < pastFormula) {
            Z assigned = cursor->mask & assignment.mask;
            Z unassigned = cursor->mask & ~assignment.mask;
            Z agreement = ~(cursor->neg ^ assignment.neg);
            if (loglevel >= 8) printf("0 .. cursor=%td .. pastFormula=%td (ID=%d)\n", cursor - formula, pastFormula - formula, cursor->id);
            if (loglevel >= 9) printf("assignment->mask %08x\n", assignment.mask);
            if (loglevel >= 9) printf("assignment->neg %08x\n", assignment.neg);
            if (loglevel >= 9) printf("cursor->mask %08x\n", cursor->mask);
            if (loglevel >= 9) printf("cursor->neg %08x\n", cursor->neg);
            if (loglevel >= 9) printf("assigned %08x\n", assigned);
            if (loglevel >= 9) printf("unassigned %08x\n", unassigned);
            if (loglevel >= 9) printf("agreement %08x\n", agreement);
            if (assigned & agreement) {
                // this clause is satisfied, swap it to the back
                if (loglevel >= 8) printf("clause is sat\n");
            } else if (unassigned == 0) {
                // there is a contradiction
                if (loglevel >= 8) printf("clause is unsat\n");
                return (struct UnitPropReturn){
                    {assignment.mask, assignment.neg, cursor->id},
                    pastFormula,
                    assignLog
                };
            } else if (one_bit_set(unassigned)) {
                // propagate this
                if (loglevel >= 8) printf("clause does prop\n");
                assignment.mask |= unassigned;
                assignment.neg |= unassigned & cursor->neg;
                *assignLog++ = (struct Decision){ unassigned, cursor->id };
                progress = 1;
            } else {
                // nothing to do yet; keep it active and move on
                cursor++;
                continue;
            }
            // Swap to the back, out of the active range. The clause that was at the
            // back now sits at *cursor, so do not advance: it is examined next.
            swapc(cursor, --pastFormula);
        }
    }
    assignment.id = -1;
    return (struct UnitPropReturn){ assignment, pastFormula, assignLog };
}
// Prints clause/cube c as "(lit lit ...)" to stdout, without a newline.
// Variable bit i is named 'i' + i for bits 0..17 ('i'..'z'). Bits 18..31 are
// named 'A'..'N' so that every name stays a printable letter. A '-' prefix
// marks a negated literal.
void printc(C c) {
    const int lowerNames = 'z' - 'i' + 1; // bits 0..17 map to 'i'..'z'
    int n = 0;
    putchar('(');
    for (int i = 0; i < (int)(sizeof(Z) * 8); i++) {
        Z bit = 1u << i; // unsigned shift: 1 << 31 on int would be undefined
        if (!(c.mask & bit))
            continue;
        if (n++)
            putchar(' ');
        if (c.neg & bit)
            putchar('-');
        putchar(i < lowerNames ? 'i' + i : 'A' + (i - lowerNames));
    }
    putchar(')');
}
// Returns a pointer to the first clause at or after `formula` whose id equals `id`.
// The scan has no bound: the caller must guarantee that such a clause exists.
Clause *findClause(Clause *formula, int id) {
    Clause *p = formula;
    for (; p->id != id; ++p)
        ;
    return p;
}
// Returns 1 if `assignment` falsifies clause `cl`, else 0. A clause is falsified
// when every one of its variables is assigned the polarity opposite to its literal.
// A clause with mask 0 is reported as falsified.
int clauseFalsified(Cube assignment, Clause cl) {
    Z disagreeing = assignment.mask & (cl.neg ^ assignment.neg);
    Z notFalsified = cl.mask & ~disagreeing;
    return notFalsified == 0;
}
// Conflict analysis: determines which branch decisions a conflict depends on.
// It walks the assignment log backwards, from the newest entry to the sentinel
// entry whose var is 0. It starts from the variables of the falsified clause
// (conflictMask) and replaces each propagated variable it meets by the
// variables of its reason clause. Branch decisions (reason == -1) reached this
// way are collected.
// Parameters:
//   assignment   - cube whose mask restricts the result; its neg bits are used
//                  only for the log output below.
//   formula      - clause array searched (via findClause) for reason clauses.
//   assignLog    - one past the newest log entry; the log must begin with a
//                  sentinel entry whose var is 0.
//   conflictMask - variable mask of the falsified clause.
// Returns the mask of decision variables the conflict depends on, intersected
// with assignment.mask. The clause logged at loglevel >= 4 negates those
// decisions. Exits the process with status 1 if some traced variable has no
// log entry after the sentinel.
Z analyzeConflict(Cube assignment, Clause *formula, struct Decision *assignLog, Z conflictMask) {
// one day I'll learn about UIP and all that stuff
// but for now, just find out which branch decisions mattered and block those
Z decisionsThatMattered = 0;
// variables still to be explained
Z curious = conflictMask;
// variables whose log entry has already been processed
Z explained = 0;
// step back one entry at a time; stops at the var == 0 sentinel
while ((--assignLog)->var != 0) {
if (loglevel >= 5) {
printf(" backtracking assignment of %d by %d ", assignLog->var, assignLog->reason);
if (assignLog->reason == -1)
printf("decision\n");
else {
printc(*findClause(formula, assignLog->reason));
printf("\n");
}
}
// only entries for variables in the current explanation set matter
if (assignLog->var & curious) {
if (loglevel >= 5) printf(" which mattered!\n");
curious &= ~assignLog->var;
explained |= assignLog->var;
if (assignLog->reason == -1) {
// reached a branch decision
decisionsThatMattered |= assignLog->var;
} else {
// a propagated variable is explained by the variables of its reason clause;
// this re-adds var itself, which the final check excludes via `explained`
curious |= findClause(formula, assignLog->reason)->mask;
}
}
}
// anything traced but never found in the log means the explanation is incomplete
if (curious & ~explained) {
printf("couldn't explain ");
printc((C){ curious & ~explained, assignment.neg });
printf("\n");
exit(1);
}
if (loglevel >= 4) {
printf("formed new clause: ");
printc((C){ assignment.mask & decisionsThatMattered, (~assignment.neg) & decisionsThatMattered });
printf("\n");
}
return assignment.mask & decisionsThatMattered;
}
// Recursive search step with clause learning.
//
// First runs unit propagation of `assignment` over the active clauses
// [formula, pastFormula). Then one of three things happens:
//  - Conflict: analyzeConflict determines which branch decisions caused it.
//    A clause negating those decisions is stored in the spare slot just below
//    `formula`, moving `formula` down by one. This is skipped if the clause's
//    mask equals the current assignment's mask, or if it equals the falsified
//    clause.
//  - Every active clause is satisfied: the (possibly partial) assignment is
//    printed as a solution (loglevel >= 1), and the search continues.
//  - Otherwise: branches on the lowest unassigned variable, first false, then
//    true, recursing with the propagated assignment.
//
// beforeFormula marks the start of the spare slots; running out of them aborts
// the process with status 1. assignLog points one past the newest log entry.
// Returns the (possibly lowered) start of the clause array, so that clauses
// learned below the caller's `formula` become active for the caller's remaining
// work. After the false branch, if the clause at the returned start is already
// falsified by this call's incoming `assignment`, the true branch is skipped.
Clause *branch(
    Cube assignment,
    Clause *beforeFormula,
    Clause *formula,
    Clause *pastFormula,
    struct Decision *assignLog)
{
    if (loglevel >= 2) {
        printf(" ");
        printc(assignment);
        printf("\n");
    }
    struct UnitPropReturn upr = unit_propagation(assignment, formula, pastFormula, assignLog);
    assignLog = upr.assignLog;
    if (loglevel >= 2 && assignment.mask != upr.assignment.mask) {
        printf("prop\n ");
        printc(upr.assignment);
        printf("\n");
    }
    if (upr.assignment.id != -1) {
        // a contradiction: upr.assignment.id is the falsified clause
        Clause c = *findClause(beforeFormula, upr.assignment.id);
        if (loglevel >= 2) {
            printf("contradiction: assignment ");
            printc(upr.assignment);
            printf(" falsifies clause %d ", upr.assignment.id);
            printc(c);
            printf("\n");
        }
        Z newClauseMask = analyzeConflict(assignment, beforeFormula, assignLog, c.mask);
        if (newClauseMask == assignment.mask) {
            if (loglevel >= 4) printf(" new clause not helpful (no smaller than assignment)\n");
        } else if (newClauseMask == c.mask && (newClauseMask & ~assignment.neg) == c.neg) {
            if (loglevel >= 4) printf(" new clause not helpful (equal to existing clause)\n");
        } else {
            if (beforeFormula == formula) {
                printf("error: no more space for new clauses\n");
                exit(1);
            }
            // add the clause in the spare slot below the active range; the slot
            // keeps whatever id it already had
            formula--;
            formula->mask = newClauseMask;
            formula->neg = newClauseMask & ~assignment.neg;
        }
    } else if (formula == upr.pastFormula) {
        // a solution: every active clause is satisfied
        if (loglevel >= 1) {
            printf("solution: ");
            printc(upr.assignment);
            printf("\n");
        }
    } else {
        Z decisionVar = next_zero_bit(upr.assignment.mask);
        // one log entry serves both branches; both recursive calls start after it
        *assignLog++ = (struct Decision){ decisionVar, -1 };
        // decisionVar is an unsigned bit mask: print it with %u
        if (loglevel >= 2) printf("branch on %u (false)\n", (unsigned)decisionVar);
        Cube ass0 = { upr.assignment.mask | decisionVar, upr.assignment.neg | decisionVar, -1 };
        formula = branch(ass0, beforeFormula, formula, upr.pastFormula, assignLog);
        if (clauseFalsified(assignment, *formula)) {
            if (loglevel >= 3) printf("leaving early\n");
            return formula; // short cut
        }
        if (loglevel >= 2) printf("branch on %u (true)\n", (unsigned)decisionVar);
        Cube ass1 = { upr.assignment.mask | decisionVar, upr.assignment.neg, -1 };
        formula = branch(ass1, beforeFormula, formula, upr.pastFormula, assignLog);
    }
    return formula;
}
// Runs the search over a clause array laid out as
//   [zero-mask spare slots ...][clauses ...][zero-mask terminator]
// Clauses learned during the search are written into the spare slots. At
// loglevel >= 4 they are printed afterwards.
// Preconditions (not checked): at least one clause with a nonzero mask, and a
// zero-mask terminator after the clauses.
void solve(Clause *beforeFormula) {
    // sentinel entry (var 0) followed by room for one entry per variable bit
    struct Decision assignLog[1 + sizeof(Z) * 8];
    assignLog[0] = (struct Decision){ .var = 0, .reason = -1 };

    Clause *firstClause = beforeFormula;
    for (; !firstClause->mask; ++firstClause)
        ;
    Clause *endClause = firstClause;
    for (; endClause->mask; ++endClause)
        ;

    const Cube nothingAssigned = { 0, 0, -1 };
    Clause *learned = branch(nothingAssigned, beforeFormula, firstClause, endClause, assignLog);

    if (loglevel < 4)
        return;
    for (; learned < firstClause; ++learned) {
        printf("derived clause: ");
        printc(*learned);
        printf("\n");
    }
}
int main(int argc, char **argv) {
Clause formula[] = {
{b00000, b00000, 100},
{b00000, b00000, 101},
{b00000, b00000, 102},
{b00000, b00000, 103},
{b00000, b00000, 104},
{b00000, b00000, 105},
{b00000, b00000, 106},
{b00000, b00000, 107},
{b11000, b10000, 1}, // 10---
{b01100, b01000, 2}, // -10--
{b00110, b00100, 3}, // --10-
{b00011, b00010, 4}, // ---10
{b11001, b00000, 5}, // 00---
{b00110, b00110, 6}, // --11-
{b00000, b00000}
};
solve(formula);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment