Skip to content

Instantly share code, notes, and snippets.

@Trebor-Huang
Created December 11, 2024 13:42
Show Gist options
  • Save Trebor-Huang/6095d3cf797dabf4e21490729ac9a56e to your computer and use it in GitHub Desktop.
Save Trebor-Huang/6095d3cf797dabf4e21490729ac9a56e to your computer and use it in GitHub Desktop.
A simple verifier for Metamath databases
from collections import OrderedDict
import re
def subst(substitution, expr):
    """Apply *substitution* to the token sequence *expr* and return a tuple.

    Each token that is a key of *substitution* is replaced by the tokens it
    maps to; every other token is copied through unchanged.
    """
    out = []
    for token in expr:
        if token in substitution:
            out.extend(substitution[token])
        else:
            out.append(token)
    return tuple(out)
def decompress(mand, idents, cproof):
    """Decompresses a proof.

    Decodes the letter part of a Metamath compressed proof.

    Args:
        mand: labels of the theorem's mandatory hypotheses, numbered
            1..len(mand).
        idents: labels listed in the proof's parentheses, numbered next.
        cproof: the proof letters.

    Returns:
        A tuple of proof steps. Each step is one of:
        - a label from *mand* or *idents*;
        - a non-negative int ``k``, meaning "reuse the k-th step saved with Z";
        - ``-1`` for a ``Z`` marker (save the previous step).

    Raises:
        ValueError: if a number is left unfinished -- the letters end after a
            ``U``-``Y`` digit, or a ``Z`` follows one.
    """
    n_mand = len(mand)
    n_labels = n_mand + len(idents)
    proof = []
    # U..Y accumulate high-order base-5 digits (values 1..5); A..T supply the
    # final base-20 digit (values 1..20) and complete the number.  Because
    # every digit adds at least 1, ``number`` is non-zero exactly while a
    # number is still pending.
    number = 0
    for c in cproof:
        if 'U' <= c <= 'Y':
            number = number * 5 + (ord(c) - ord('U') + 1)
        elif 'A' <= c <= 'T':
            number = number * 20 + (ord(c) - ord('A') + 1)
            if number <= n_mand:
                proof.append(mand[number - 1])
            elif number <= n_labels:
                proof.append(idents[number - n_mand - 1])
            else:
                # Index into the steps previously saved with Z.
                proof.append(number - n_labels - 1)
            number = 0
        elif c == 'Z':
            if number:
                raise ValueError("Bad compressed proof: Z inside a number.")
            proof.append(-1)
        # Any other character is skipped.
    if number:
        raise ValueError("Bad compressed proof: unfinished number at end.")
    return tuple(proof)
class Verifier:
    """Incremental checker for a parsed Metamath database.

    Statements are added in file order, usually through ``verify``.  Scoped
    state is kept on stacks so that ``pop`` discards everything declared since
    the matching ``push``.  Scoped state means variables, ``$d`` cliques, and
    ``$f`` and ``$e`` hypotheses.  Constants, assertions and the set of used
    labels are never popped.
    """

    def __init__(self):
        # None-separated stacks: a None marks the start of an inner scope.
        self.variables = []
        self.constants = []
        self.distinct = []  # list of cliques (sequences of variable names)
        # label -> (floats, essentials, distinct, statement); see add_theorem
        self.theorems = {}
        self.floats = OrderedDict()  # label: (typecode, variable)
        self.floats_name = OrderedDict()  # variable: typecode
        self.float_stack = [0]  # number of $f added in each open scope
        self.hypoth = OrderedDict()  # name: statement
        self.hypoth_stack = [0]  # number of $e added in each open scope
        # Every label accepted so far.  Never popped, so a label cannot be
        # reused after the scope that declared it has closed.
        self.labels = set()

    def push(self):
        """Open a new scope (``${``)."""
        self.variables.append(None)
        self.distinct.append(None)
        self.float_stack.append(0)
        self.hypoth_stack.append(0)

    def pop(self):
        """Close the innermost scope (``$}``) and drop what it declared."""
        # Pop entries down to, and including, the scope's None marker.
        while self.variables and self.variables.pop() is not None:
            pass
        while self.distinct and self.distinct.pop() is not None:
            pass
        # Floats and essentials are counted per scope instead of marked.
        for _ in range(self.float_stack.pop()):
            self.floats.popitem()
            self.floats_name.popitem()
        for _ in range(self.hypoth_stack.pop()):
            self.hypoth.popitem()

    def _check_label(self, name):
        """Raise ValueError if *name* has already been used as a label."""
        if name in self.labels:
            raise ValueError("Label already used: " + name)

    def add_float(self, name, statement):
        """Add the floating hypothesis ``name $f typecode variable $.``.

        *statement* is the ``(typecode, variable)`` pair.

        Raises:
            ValueError: if the statement is malformed, the typecode or the
                variable is unknown, the variable already has an active
                ``$f``, or the label was used before.
        """
        if len(statement) != 2:
            raise ValueError("Bad floating hypothesis: " + ' '.join(statement))
        typecode, variable = statement
        if typecode not in self.constants:
            raise ValueError("Unrecognized typecode: " + typecode)
        if variable not in self.variables:
            raise ValueError("Undeclared variable: " + variable)
        if variable in self.floats_name:
            raise ValueError("Duplicate variable in floating hypothesis: " + variable)
        self._check_label(name)
        self.labels.add(name)
        self.floats[name] = statement
        self.floats_name[variable] = typecode
        self.float_stack[-1] += 1

    def check_statement(self, statement):
        """Check that *statement* is a constant typecode followed by active symbols.

        Raises:
            ValueError: if the statement is empty, or it contains an unknown
                typecode or token.
        """
        if len(statement) == 0:
            raise ValueError("Empty statement")
        if statement[0] not in self.constants:
            raise ValueError("Unrecognized typecode: " + statement[0])
        for token in statement[1:]:
            if token not in self.variables and token not in self.constants:
                raise ValueError("Unrecognized token: " + token)

    def add_variables(self, variables):
        """Declare *variables* (``$v``) in the current scope."""
        for v in variables:
            if v in self.variables or v in self.constants:
                raise ValueError(f"Duplicate variable name {v}.")
            self.variables.append(v)

    def add_constants(self, constants):
        """Declare *constants* (``$c``)."""
        for v in constants:
            if v in self.variables or v in self.constants:
                raise ValueError(f"Duplicate constant name {v}.")
            self.constants.append(v)

    def add_distinct(self, variables):
        """Record a ``$d`` clique of active *variables* in the current scope."""
        for v in variables:
            if v not in self.variables:
                raise ValueError("Unrecognized variable name.")
        self.distinct.append(variables)

    def add_essential(self, name, statement):
        """Add the essential hypothesis ``name $e statement $.``.

        Raises:
            ValueError: if the label was used before or the statement is bad.
        """
        self._check_label(name)
        self.check_statement(statement)
        self.labels.add(name)
        self.hypoth[name] = statement
        self.hypoth_stack[-1] += 1

    def check_proof(self, proof):
        """Run an RPN proof and return the statement it proves.

        *proof* is a sequence of steps.  A step is one of:
        - the label of an active hypothesis or of an earlier assertion;
        - in decompressed proofs, ``-1``: save the top of the stack (a ``Z``);
        - in decompressed proofs, a non-negative int ``k``: push the k-th
          saved statement.

        Raises:
            ValueError: if the proof is malformed or a step does not unify.
            RuntimeError: if a step names an unknown label.
        """
        stack = []
        saved = []  # statements tagged with Z in a compressed proof
        for token in proof:
            if token == -1:
                if not stack:
                    raise ValueError("Bad Z reference in compressed proof.")
                saved.append(stack[-1])
            elif isinstance(token, int):
                if not 0 <= token < len(saved):
                    raise ValueError("Bad index in compressed proof.")
                stack.append(saved[token])
            elif token in self.floats:
                stack.append(self.floats[token])
            elif token in self.hypoth:
                stack.append(self.hypoth[token])
            elif token in self.theorems:
                stack.append(self._apply_assertion(token, stack))
            else:
                raise RuntimeError(f"Unrecognized element in parse tree: {token}")
        if len(stack) != 1:
            raise ValueError("RPN proof incomplete.")
        return stack[0]

    def _apply_assertion(self, label, stack):
        """Apply assertion *label* to the top of the proof *stack*.

        The steps are:
        1. Pop the assertion's arguments: its ``$f`` arguments, then its
           ``$e`` arguments.
        2. Build the substitution from the ``$f`` arguments and check their
           type codes.
        3. Check the ``$d`` restrictions and the ``$e`` hypotheses.
        4. Return the substituted conclusion.

        *stack* is modified in place.
        """
        # NOTE(review): the Metamath spec orders mandatory hypotheses by their
        # appearance in the database ($f and $e interleaved); this verifier
        # assumes every $f precedes every $e.  Confirm that the checked
        # databases declare hypotheses in that order.
        fs, es, ds, conclusion = self.theorems[label]
        n_args = len(fs) + len(es)
        if len(stack) < n_args:
            raise ValueError("Bad RPN proof: too few arguments")
        args = stack[len(stack) - n_args:]
        del stack[len(stack) - n_args:]
        fsh, esh = args[:len(fs)], args[len(fs):]
        substitution = {}
        for (ty, var), (ty0, *expr) in zip(fs, fsh):
            if ty != ty0:
                raise ValueError(f"Incorrect type code, expected {ty}, got {ty0}")
            substitution[var] = expr
        self._check_distinct(ds, substitution)
        for hyp, actual in zip(es, esh):
            expected = subst(substitution, hyp)
            if expected != actual:
                raise ValueError(
                    f"Statement mismatch: {' '.join(expected)} is not {' '.join(actual)}"
                )
        return subst(substitution, conclusion)

    def _check_distinct(self, ds, substitution):
        """Enforce the ``$d`` cliques *ds* of an assertion under *substitution*.

        For every pair of variables in a clique, the expressions substituted
        for them must share no variable.  Each pair of their variables must
        also appear together in an active ``$d`` clique.
        """
        # Skip the None scope markers that push() stores in self.distinct.
        active = [clique for clique in self.distinct if clique is not None]
        for clique in ds:
            for i, v1 in enumerate(clique):
                # Lists rather than filter iterators: each is traversed many times.
                vars1 = [t for t in substitution[v1] if t in self.floats_name]
                for v2 in clique[:i]:
                    vars2 = [t for t in substitution[v2] if t in self.floats_name]
                    for t1 in vars1:
                        for t2 in vars2:
                            if t1 == t2 or not any(
                                t1 in cq and t2 in cq for cq in active
                            ):
                                raise ValueError(
                                    "Distinct variable requirement violated: "
                                    f"{' '.join(vars1)} and {' '.join(vars2)}"
                                )

    def add_theorem(self, name, theorem, proof=None):
        """Add an assertion: an axiom (``proof is None``) or a theorem.

        The assertion's frame is built from the active scope:
        - every active ``$e``;
        - the ``$f`` of each variable that occurs in those ``$e`` or in
          *theorem*;
        - the active ``$d`` cliques, restricted to those variables.

        When *proof* is given, it must prove exactly *theorem*.  It is either
        a sequence of labels or ``("$compressed", labels, letters)``.

        Raises:
            ValueError: on a reused label, a malformed statement, a variable
                without an active ``$f``, or a proof that fails or proves
                something else.
        """
        self._check_label(name)
        self.check_statement(theorem)
        essentials = tuple(self.hypoth.values())
        # Variables occurring in the essentials or in the assertion itself;
        # each must have an active $f.
        used_floats = set()
        for h in essentials + (theorem,):
            for token in h:
                if token in self.constants:
                    continue
                if token not in self.floats_name:
                    raise ValueError("Variable without typecode: " + token)
                used_floats.add(token)
        # TODO more optimizations of distinct
        distinct = tuple(
            tuple(v for v in clique if v in used_floats)
            for clique in self.distinct
            if clique is not None
        )
        floats = tuple((k, v) for k, v in self.floats.values() if v in used_floats)
        if proof is not None:
            if proof and proof[0] == "$compressed":
                # Mandatory hypotheses are numbered first: the used $f labels
                # in declaration order, then every active $e label.
                float_labels = tuple(
                    label for label, (_, v) in self.floats.items() if v in used_floats
                )
                proof = decompress(float_labels + tuple(self.hypoth), proof[1], proof[2])
            result = self.check_proof(proof)
            if result != theorem:
                raise ValueError(
                    f"The proof proved {' '.join(result)} but we want {' '.join(theorem)}"
                )
        self.labels.add(name)
        self.theorems[name] = (floats, essentials, distinct, theorem)

    def verify(self, tree):
        """Feed a parse tree, as produced by ``parse``, into the verifier.

        *tree* is a list of ``(kind, result)`` pairs.  Nested ``scope`` entries
        are verified recursively between ``push`` and ``pop``.  Each accepted
        axiom and theorem is logged with ``print``.
        """
        for statement, result in tree:
            if statement == 'scope':
                self.push()
                try:
                    self.verify(result[0])
                finally:
                    # Restore the enclosing scope even if verification fails.
                    self.pop()
            elif statement == 'var':
                self.add_variables(result[0])
            elif statement == 'const':
                self.add_constants(result[0])
            elif statement == 'distinct':
                self.add_distinct(result[0])
            elif statement == 'float':
                self.add_float(result[0], tuple(result[1:]))
            elif statement == 'essential':
                self.add_essential(result[0], tuple(result[1]))
            elif statement == 'axiom':
                self.add_theorem(result[0], tuple(result[1]))
                print("[LOG] Added axiom", result[0])
            elif statement == 'theorem':
                self.add_theorem(result[0], tuple(result[1]), result[2])
                print("[LOG] Added theorem", result[0])
def preprocess(tokens: list[str]):
    """Yield *tokens* with ``$( ... $)`` comments removed.

    Comments do not nest: inside a comment every token is skipped until the
    first ``$)``, including any further ``$(``.

    Raises:
        ParseError: if the input ends inside a comment.  Without this check,
            an unterminated ``$(`` would silently drop the rest of the file.
    """
    in_comment = False
    for token in tokens:
        if in_comment:
            if token == "$)":
                in_comment = False
        elif token == "$(":
            in_comment = True
        else:
            # TODO include commands
            yield token
    if in_comment:
        raise ParseError("Unterminated comment: missing $)")
# ==== Parser combinators ====
class ParseError(Exception):
    """Raised when the token stream does not fit the grammar being tried."""

    # TODO better errors
# TODO: implement a peekable token-stream class; returning `tokens[1:]` on every consume makes parsing quadratic in the number of tokens.
def parse_any(parser):
    """Repeat *parser* zero or more times.

    The returned parser applies *parser* until it raises ParseError.  It then
    returns the tokens left after the last successful match, together with
    the list of results collected so far.
    """
    def _repeat(tokens):
        collected = []
        while True:
            try:
                tokens, item = parser(tokens)
            except ParseError:
                return tokens, collected
            collected.append(item)
    return _repeat
def parse_conj(*parsers):
    """Sequence *parsers*, threading the token stream through each in turn.

    The combined result is the list of every non-None result, in order, so
    keywords parsed with ``expect`` drop out.
    """
    def _sequence(tokens):
        values = []
        rest = tokens
        for step in parsers:
            rest, value = step(rest)
            values.append(value)
        return rest, [v for v in values if v is not None]
    return _sequence
def parse_disj(**parsers):
    """Ordered choice between named *parsers*.

    The alternatives are tried in keyword order.  The first that does not
    raise ParseError wins, and its result is tagged as ``(name, result)``.
    """
    alternatives = tuple(parsers.items())

    def _choice(tokens):
        for tag, alternative in alternatives:
            try:
                rest, value = alternative(tokens)
            except ParseError:
                continue
            return rest, (tag, value)
        raise ParseError("All disjoint branches failed")
    return _choice
def consume(r):
    """Consumes a regex or a string literal.

    The returned parser takes the first token when it equals *r* (a str) or
    fully matches *r* (a compiled pattern).  It returns
    ``(rest, token)`` and raises ParseError otherwise.
    """
    if isinstance(r, str):
        def matches(token):
            return token == r
    elif isinstance(r, re.Pattern):
        def matches(token):
            return r.fullmatch(token) is not None
    else:
        def matches(token):
            return True

    def _take(tokens):
        if not tokens:
            raise ParseError(f"Expecting {r}, got EOF")
        head = tokens[0]
        if not matches(head):
            raise ParseError(f"Expecting {r}, got {head}.")
        return tokens[1:], head
    return _take
def process(parser, f):
    """Wrap *parser* so that *f* is applied to its result.

    The token stream is left as *parser* returned it.
    """
    def _mapped(tokens):
        remaining, raw = parser(tokens)
        return (remaining, f(raw))
    return _mapped
def expect(r):
    """Consume a token matching *r* and discard it.

    The result is ``None``, so ``parse_conj`` leaves the token out of its
    result list.
    """
    def _discard(_):
        return None
    return process(consume(r), _discard)
def parse_proof(tokens):
    """Parses the part after $= (not including) and consumes the $.

    Returns ``(remaining tokens, proof)``.  For a normal proof, ``proof`` is
    the tuple of step labels.  For a compressed proof it is
    ``("$compressed", labels, letters)``.

    Raises:
        ParseError: if the proof is not terminated by ``$.``.  A proof step
            never contains ``$``, so the first ``$`` token must be the
            terminator; otherwise a missing ``$.`` would swallow the
            following statements.
    """
    if not tokens:
        raise ParseError("Expecting proof, got EOF")
    if tokens[0] == "(":
        # Compressed proof
        tokens, (idents, proof) = parse_conj(
            expect("("),
            parse_any(consume(re.compile("[^)$]+"))),
            expect(")"),
            parse_any(consume(re.compile("[A-Z]+"))),
            expect("$.")
        )(tokens)
        return tokens, ("$compressed", idents, "".join(proof))
    # Uncompressed proof
    for end, token in enumerate(tokens):
        if token == "$.":
            return tokens[end + 1:], tokens[:end]
        if "$" in token:
            raise ParseError(f"Expecting $. to end proof, got {token}")
    raise ParseError("Proof not finished with $.")
# Any token without a "$" is a label or a math symbol.
consume_label = consume(re.compile("[^$]+"))


def _symbol_list(keyword):
    """Parser for ``keyword sym ... $.``; the result is ``[[sym, ...]]``."""
    return parse_conj(expect(keyword), parse_any(consume_label), expect("$."))


def _labelled(keyword, *body):
    """Parser for ``label keyword <body...>``.

    The result is ``[label, *non-None body results]``.
    """
    return parse_conj(consume_label, expect(keyword), *body)


parse_var = _symbol_list("$v")
parse_const = _symbol_list("$c")
parse_distinct = _symbol_list("$d")
parse_float = _labelled("$f", consume_label, consume_label, expect("$."))
parse_essential = _labelled("$e", parse_any(consume_label), expect("$."))
parse_axiom = _labelled("$a", parse_any(consume_label), expect("$."))
parse_theorem = _labelled("$p", parse_any(consume_label), expect("$="), parse_proof)
def parse_scope(tokens):
    """Parse a run of statements: a whole file, or the body of ``${ ... $}``.

    Returns ``(remaining tokens, [(kind, result), ...])`` and stops at the
    first token sequence that no statement parser accepts.
    """
    block = parse_conj(expect("${"), parse_scope, expect("$}"))
    statement = parse_disj(
        scope=block,
        var=parse_var,
        const=parse_const,
        distinct=parse_distinct,
        float=parse_float,
        essential=parse_essential,
        axiom=parse_axiom,
        theorem=parse_theorem,
    )
    return parse_any(statement)(tokens)
def parse(str: str):
    """Tokenize *str*, the text of a Metamath database, and parse it.

    Returns the list of ``(kind, result)`` statements consumed by
    ``Verifier.verify``.

    Raises:
        ParseError: if some tokens could not be parsed.  The message names
            the token where parsing stopped and how many tokens were left.
    """
    # NOTE: the parameter name shadows the builtin ``str``; it is kept so
    # that keyword callers keep working.
    tokens, result = parse_scope(tuple(preprocess(str.split())))
    if tokens:
        raise ParseError(
            f"No parse: could not continue at token {tokens[0]!r} "
            f"({len(tokens)} tokens left)"
        )
    return result
if __name__ == "__main__":
    import sys

    # Database to verify: the first command-line argument, default ql.mm.
    path = sys.argv[1] if len(sys.argv) > 1 else "ql.mm"
    v = Verifier()
    # Context manager closes the file; explicit encoding avoids
    # locale-dependent decoding.
    with open(path, "r", encoding="utf-8") as f:
        source = f.read()
    v.verify(parse(source))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment