// Captured from the GitHub gist makenowjust/bfca1fc504e4780a06f4fb3ab2d710ca
// (last active February 20, 2025 05:51).
// GitHub flagged this file as containing hidden or bidirectional Unicode text that may be
// interpreted or compiled differently than it appears; review it in an editor that reveals
// hidden Unicode characters.
// This is an implementation of the Λ* algorithm in Scala 3.
//
// The Λ* algorithm is a learning algorithm for symbolic automata, proposed by
// Samuel Drews and Loris D'Antoni (2017), "Learning Symbolic Automata",
// https://doi.org/10.1007/978-3-662-54577-5_10.
import scala.collection.mutable
/** An effective Boolean algebra over the domain `D`.
  *
  * @tparam D the domain of concrete values
  * @tparam P the type used to represent predicates over `D`
  */
trait BoolAlg[D, P]:

  /** The predicate satisfied by every value. */
  def `true`: P

  /** The predicate satisfied by no value. */
  def `false`: P

  /** Conjunction: the predicate `p ∧ q`. */
  def and(p: P, q: P): P

  /** Disjunction: the predicate `p ∨ q`. */
  def or(p: P, q: P): P

  /** Negation: the predicate `¬p`. */
  def not(p: P): P

  /** Tests whether `d` belongs to the denotation of `p`. */
  def contains(p: P, d: D): Boolean

  /** The partitioning function.
    *
    * @param ds groups of concrete values to be separated
    * @return the sequence of separating predicates of `ds`
    */
  def partition(ds: Seq[Set[D]]): Seq[P]
/** A concrete syntax tree for predicates built over atomic propositions of type `A`. */
enum Pred[+A]:
  case True, False
  case Atom(a: A)
  case And(p: Pred[A], q: Pred[A])
  case Or(p: Pred[A], q: Pred[A])
  case Not(p: Pred[A])

  /** Conjunction that folds the constants: `True` is the unit, `False` is absorbing. */
  infix def and[AA >: A](that: Pred[AA]): Pred[AA] = this match
    case True => that
    case lhs =>
      that match
        case True                                => lhs
        case rhs if lhs == False || rhs == False => False
        case rhs                                 => And(lhs, rhs)

  /** Disjunction that folds the constants: `False` is the unit, `True` is absorbing. */
  infix def or[AA >: A](that: Pred[AA]): Pred[AA] =
    if this == True || that == True then True
    else if that == False then this
    else if this == False then that
    else Or(this, that)

  /** Negation that flips the constants and strips an outer `Not` instead of stacking one. */
  def not: Pred[A] = this match
    case Not(inner) => inner
    case True       => False
    case False      => True
    case other      => Not(other)

  /** Checks whether the denotation of this predicate contains `d`.
    *
    * @param d    the data value to test
    * @param atom decides whether an atomic proposition contains a data value
    */
  def contains[D](d: D)(atom: (A, D) => Boolean): Boolean =
    // Structural recursion over the tree; `&&` and `||` short-circuit left to right.
    def holds(pred: Pred[A]): Boolean = pred match
      case True            => true
      case False           => false
      case Atom(a)         => atom(a, d)
      case Not(inner)      => !holds(inner)
      case And(left, right) => holds(left) && holds(right)
      case Or(left, right)  => holds(left) || holds(right)
    holds(this)
object Pred:
  /** `EqualityAlgebra` is an instance of `BoolAlg` for the equality.
    *
    * Atoms are plain values of `A`; an atom contains a value when the two are equal (`==`).
    */
  given EqualityAlgebra[A]: BoolAlg[A, Pred[A]] with
    def `true`: Pred[A] = Pred.True
    def `false`: Pred[A] = Pred.False
    def and(p: Pred[A], q: Pred[A]): Pred[A] = Pred.And(p, q)
    def or(p: Pred[A], q: Pred[A]): Pred[A] = Pred.Or(p, q)
    def not(p: Pred[A]): Pred[A] = Pred.Not(p)
    def contains(p: Pred[A], d: A): Boolean = p.contains(d)(_ == _)

    /** Builds one predicate per group of `dss`.
      *
      * The largest group (the first one on ties) receives the negation of "any value of the
      * other groups"; every other group receives the disjunction of its own values.
      *
      * @return one predicate per input group, in input order; empty when `dss` is empty
      */
    def partition(dss: Seq[Set[A]]): Seq[Pred[A]] =
      // `maxBy` throws on an empty collection, so the empty partition is answered directly.
      if dss.isEmpty then Seq.empty
      else
        val maxIndex = dss.zipWithIndex.maxBy(_._1.size)._2
        val otherGroups = dss.iterator.zipWithIndex.collect:
          case (ds, i) if i != maxIndex => ds
        val largePred = anyOf(otherGroups.flatten).not
        dss.zipWithIndex.map: (ds, i) =>
          if i == maxIndex then largePred else anyOf(ds)

    /** The disjunction of the atoms for `ds`, folded left to right; `False` when `ds` is empty. */
    private def anyOf(ds: IterableOnce[A]): Pred[A] =
      ds.iterator.map(Pred.Atom(_)).foldLeft[Pred[A]](Pred.False)(_ or _)
/** A deterministic finite automaton over concrete characters.
  *
  * This is used for representing evidence automata (Def. 3).
  *
  * @param initialState       the start state
  * @param acceptStateSet     the accepting states
  * @param transitionFunction for each state, the target state of each character
  */
final case class Dfa[S, A](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[A, S]]
):

  /** Converts this evidence automaton to an SFA using the partitioning function of `P`. */
  def toSfa[P](using P: BoolAlg[A, P]): Sfa[S, P] =
    val symbolicTransitions = transitionFunction.map: (state, edges) =>
      // Collect, per target state, the set of characters that lead to it.
      val targetsWithChars = edges.toSeq.groupBy(_._2).toSeq.map: (target, hits) =>
        target -> hits.map(_._1).toSet
      // One guard per target, in the same order as `targetsWithChars`.
      val guards = P.partition(targetsWithChars.map(_._2))
      val guardedEdges = guards.zip(targetsWithChars.map(_._1)).toMap
      state -> guardedEdges
    Sfa(initialState, acceptStateSet, symbolicTransitions)
/** `Sfa` represents a symbolic finite automaton (Def. 1).
  *
  * In this implementation, SFAs are assumed to be deterministic and finite.
  *
  * @param initialState       the start state
  * @param acceptStateSet     the accepting states
  * @param transitionFunction for each state, the target state of each guard predicate
  */
final case class Sfa[S, P](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[P, S]]
):

  /** Follows the transition from `state` on `char`.
    *
    * @return the target of the first guard of `state` that contains `char`, or `None` when no
    *         guard matches or `state` has no outgoing edges recorded
    */
  def transition[A](state: S, char: A)(using P: BoolAlg[A, P]): Option[S] =
    // `get` instead of `apply`: a state without an edge map yields `None`, not an exception.
    transitionFunction
      .get(state)
      .flatMap(_.find((guard, _) => P.contains(guard, char)))
      .map(_._2)
/** Describes an alphabet of characters of type `A`. */
trait Alphabet[A]:

  /** An arbitrary character of the alphabet. */
  def arbChar: A

object Alphabet:

  /** Builds an `Alphabet` whose arbitrary character is `char`. */
  def apply[A](char: A): Alphabet[A] =
    new Alphabet[A]:
      def arbChar: A = char
/** The teacher side of the learning protocol: answers membership and equivalence queries. */
trait Oracle[A]:

  /** Tells whether `word` belongs to the target language. */
  def membershipQuery(word: Seq[A]): Boolean

  /** Compares `sfa` against the target language.
    *
    * @return a counterexample word when `sfa` is not equivalent to the target language,
    *         otherwise `None`
    */
  def equivalenceQuery[P](sfa: Sfa[?, P])(using BoolAlg[A, P]): Option[Seq[A]]
object Oracle:
  /** Creates an instance of `Oracle` with the given function.
    *
    * Membership queries are answered by `f` directly. Equivalence queries are approximated by
    * random testing: the empty word is compared first, then `numWords` random words over
    * `finiteAlphabet` are generated, and the hypothesis is compared against `f` after every
    * appended character. The first disagreeing word is returned as the counterexample.
    *
    * The random generator is re-created from `randomSeed` on every equivalence query, so each
    * query samples the same sequence of words.
    *
    * @param finiteAlphabet the characters random words are drawn from; when it is empty only
    *                       the empty word is compared
    * @param minWordLength  the minimum length of a sampled word (inclusive)
    * @param maxWordLength  the maximum length of a sampled word (inclusive)
    * @param numWords       the number of words sampled per equivalence query
    * @param randomSeed     the seed of the random generator
    * @param f              the characteristic function of the target language
    */
  def fromFunction[A](
      finiteAlphabet: Set[A],
      minWordLength: Int = 10,
      maxWordLength: Int = 100,
      numWords: Int = 100,
      randomSeed: Long = 0L
  )(f: Seq[A] => Boolean): Oracle[A] =
    val alphabetIndexedSeq = finiteAlphabet.toIndexedSeq
    new Oracle[A]:
      def membershipQuery(word: Seq[A]): Boolean = f(word)

      def equivalenceQuery[P](sfa: Sfa[?, P])(using BoolAlg[A, P]): Option[Seq[A]] =
        val rand = util.Random(randomSeed)
        util.boundary:
          // The empty word is never produced by the sampling loop below, so compare it here.
          val start = sfa.initialState
          if sfa.acceptStateSet.contains(start) != f(Seq.empty[A]) then
            util.boundary.break(Some(Seq.empty[A]))
          // With no characters to draw from, `rand.nextInt(0)` would throw; skip sampling.
          if alphabetIndexedSeq.nonEmpty then
            for _ <- 0 until numWords do
              val size = rand.between(minWordLength, maxWordLength + 1)
              var word = Seq.empty[A]
              var state = sfa.initialState
              for _ <- 0 until size do
                val char = alphabetIndexedSeq(rand.nextInt(alphabetIndexedSeq.size))
                word :+= char
                // The hypothesis is expected to have a transition for every character;
                // `.get` fails loudly if it does not.
                state = sfa.transition(state, char).get
                // Compare after each character so the shortest disagreeing prefix is reported.
                if sfa.acceptStateSet.contains(state) != f(word) then
                  util.boundary.break(Some(word))
          None
/** A word used as a prefix (a row label of the observation table). */
type Prefix[A] = Seq[A]

/** A word used as a suffix (a column label of the observation table). */
type Suffix[A] = Seq[A]

/** The signature of a prefix.
  *
  * It is a sequence of booleans: for a prefix `s`, `sig(i)` is true when
  * `s ++ suffices(i)` is in the target language and false otherwise.
  */
type Sig = Seq[Boolean]
/** `ObservationTable` represents an observation table (Def. 2).
  *
  * In this implementation, an observation table `(S, R, E, f)` is represented as
  * four values `prefixSet, boundarySet, suffices` and `rowMap`, where:
  *
  *   - `prefixSet` is corresponding to `S`,
  *   - `boundarySet` is corresponding to `R`,
  *   - `suffices` is corresponding to `E`, and
  *   - `rowMap` is corresponding to `f`.
  *
  * Every operation returns a new table; the receiver is never modified.
  */
final case class ObservationTable[A](
    prefixSet: Set[Prefix[A]],
    boundarySet: Set[Prefix[A]],
    suffices: Seq[Suffix[A]],
    rowMap: Map[Prefix[A], Sig]
):

  /** Returns the pre-computed signature of `prefix`; throws if `rowMap` has no entry for it. */
  private def row(prefix: Prefix[A]): Sig = rowMap(prefix)

  /** Computes the signature of `prefix` with one membership query per current suffix. */
  private def sig(prefix: Prefix[A])(using O: Oracle[A]): Sig =
    suffices.map(suffix => O.membershipQuery(prefix ++ suffix))

  /** Returns the set of prefixes and boundaries. */
  private def prefixAndBoundarySet: Set[Prefix[A]] =
    prefixSet ++ boundarySet

  /** Returns the set of extensions of the given prefix.
    *
    * An extension is a word of the table obtained by appending one character to the prefix.
    */
  private def extensionSet(prefix: Prefix[A]): Set[Prefix[A]] =
    def isExtension(word: Prefix[A]): Boolean =
      word.size == prefix.size + 1 && word.startsWith(prefix)
    prefixSet.filter(isExtension) ++ boundarySet.filter(isExtension)

  /** Returns a copy in which every prefix of `word` (including `word` itself and the empty
    * word) that is in neither `prefixSet` nor `boundarySet` is added as a boundary, together
    * with its freshly queried signature.
    */
  private def withInitsAsBoundaries(word: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    val known = prefixAndBoundarySet
    val missing = word.inits.filterNot(known).toSeq
    copy(
      boundarySet = boundarySet ++ missing,
      rowMap = rowMap ++ missing.iterator.map(prefix => prefix -> sig(prefix))
    )

  /** Finds a boundary whose row equals the row of no prefix, preferring a shortest one. */
  def findUnclosedBoundary(): Option[Prefix[A]] =
    // Hoisted so that `prefixSet` is not rescanned for every boundary.
    val closedRows = prefixSet.map(row)
    boundarySet.filterNot(boundary => closedRows(row(boundary))).minByOption(_.size)

  /** Does the "close" operation (§3.3) with the given prefix.
    *
    * `prefix` moves from the boundary set to the prefix set, and `prefix :+ A.arbChar`
    * becomes a boundary unless it is already a prefix.
    */
  def close(prefix: Prefix[A])(using A: Alphabet[A], O: Oracle[A]): ObservationTable[A] =
    val newPrefixSet = prefixSet + prefix
    val newBoundary = prefix :+ A.arbChar
    val remaining = boundarySet - prefix
    // A word that is already a prefix is not duplicated into the boundary set.
    val newBoundarySet =
      if newPrefixSet.contains(newBoundary) then remaining else remaining + newBoundary
    val newRowMap =
      if rowMap.contains(newBoundary) then rowMap
      else rowMap + (newBoundary -> sig(newBoundary))
    ObservationTable(newPrefixSet, newBoundarySet, suffices, newRowMap)

  /** Finds the evidence-unclosed words: every `prefix ++ suffix` (prefix from `prefixSet`)
    * that is neither a prefix nor a boundary. The result may contain duplicates.
    */
  def findEvidenceUnclosedWord(): Seq[Seq[A]] =
    val known = prefixAndBoundarySet
    prefixSet.iterator
      .flatMap(prefix => suffices.iterator.map(prefix ++ _))
      .filterNot(known)
      .toSeq

  /** Does the "evidence-close" operation (§3.3) with the given word.
    *
    * Words that are already in the table, whether as prefix or as boundary, are skipped, so
    * no membership query is repeated for them and prefixes are not copied into the boundary
    * set.
    */
  def evidenceClose(word: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    if prefixSet.contains(word) then this
    else withInitsAsBoundaries(word)

  /** Finds and returns an inconsistent suffix if it exists.
    *
    * Looks for two words with equal rows whose extensions by the same character have
    * different rows; the result is that character followed by the first suffix on which the
    * two extension rows differ.
    */
  def findInconsistentSuffix(): Option[Suffix[A]] =
    val prefixAndBoundarySet = this.prefixAndBoundarySet
    val triple = (
      for
        prefix1 <- prefixAndBoundarySet.iterator
        ext1 <- extensionSet(prefix1).iterator
        char = ext1.last
        prefix2 <- prefixAndBoundarySet.iterator
        ext2 = prefix2 :+ char
        if prefixAndBoundarySet.contains(ext2)
        if row(prefix1) == row(prefix2)
        if row(ext1) != row(ext2)
      yield (prefix1, prefix2, char)
    ).nextOption()
    triple.map: (prefix1, prefix2, char) =>
      // The two rows differ, so a differing column exists.
      val index = row(prefix1 :+ char)
        .zip(row(prefix2 :+ char))
        .indexWhere((b1, b2) => b1 != b2)
      char +: suffices(index)

  /** Does the "make-consistent" and "distribute" operations (§3.3) with the given suffix.
    *
    * The membership answer for `prefix ++ newSuffix` is requested at most once per prefix
    * and reused by both steps; this relies on the oracle answering the same word the same
    * way every time.
    */
  def makeConsistentAndDistribute(newSuffix: Suffix[A])(using O: Oracle[A]): ObservationTable[A] =
    val prefixAndBoundarySet = this.prefixAndBoundarySet
    val answers = mutable.Map.empty[Prefix[A], Boolean]
    def answer(prefix: Prefix[A]): Boolean =
      answers.getOrElseUpdate(prefix, O.membershipQuery(prefix ++ newSuffix))

    // First, the "distribute" operation: words that shared a row but are split by the new
    // suffix exchange their outgoing characters as new boundaries.
    var boundarySetToAdd = Set.empty[Prefix[A]]
    for
      prefix1 <- prefixAndBoundarySet
      prefix2 <- prefixAndBoundarySet
      if row(prefix1) == row(prefix2)
      if answer(prefix1) != answer(prefix2)
    do
      for
        char <- extensionSet(prefix1).map(_.last)
        if !prefixSet.contains(prefix2 :+ char)
      do boundarySetToAdd += prefix2 :+ char
      for
        char <- extensionSet(prefix2).map(_.last)
        if !prefixSet.contains(prefix1 :+ char)
      do boundarySetToAdd += prefix1 :+ char
    val newBoundarySet = boundarySet ++ boundarySetToAdd
    val rowsWithNewBoundaries = boundarySetToAdd.foldLeft(rowMap): (rows, newBoundary) =>
      if rows.contains(newBoundary) then rows
      else rows + (newBoundary -> sig(newBoundary))

    // Next, the "make-consistent" operation: append the new column to every row.
    val newSuffices = suffices :+ newSuffix
    val newRowMap = rowsWithNewBoundaries.map: (prefix, signature) =>
      prefix -> (signature :+ answer(prefix))
    ObservationTable(prefixSet, newBoundarySet, newSuffices, newRowMap)

  /** Processes the given counterexample: every prefix of `cex` that is not yet in the table
    * is added as a boundary.
    */
  def process(cex: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    withInitsAsBoundaries(cex)

  /** Builds an evidence automaton from the observation table (§3.1).
    *
    * States are the prefixes chosen to represent each distinct row of `prefixSet`. The row
    * lookups throw if some word of the table has a row that no prefix has, so the table is
    * expected to be closed first.
    */
  def buildEvidence(): Dfa[Prefix[A], A] =
    // When several prefixes share a row, the one kept last by `toMap` represents it.
    val rowToPrefix = rowMap.view.filterKeys(prefixSet).map(_.swap).toMap
    val transitionFunction = mutable.Map.empty[Prefix[A], Map[A, Prefix[A]]]
    for
      prefix0 <- prefixAndBoundarySet
      ext0 <- extensionSet(prefix0)
    do
      val source = rowToPrefix(row(prefix0))
      val target = rowToPrefix(row(ext0))
      val edges = transitionFunction.getOrElse(source, Map.empty[A, Prefix[A]])
      transitionFunction(source) = edges + (ext0.last -> target)
    val initialState = rowToPrefix(rowMap(Seq.empty))
    // The first column of a row is taken as the acceptance flag of the state.
    val acceptSet = prefixSet.map(row).filter(_.head).map(rowToPrefix)
    Dfa(initialState, acceptSet, transitionFunction.toMap)
object ObservationTable:

  /** Creates the initial observation table.
    *
    * The empty word is the only prefix and the only suffix, and the single boundary is the
    * one-character word made of `A.arbChar`. Both rows are filled in with membership queries.
    */
  def empty[A](using A: Alphabet[A], O: Oracle[A]): ObservationTable[A] =
    val epsilon = Seq.empty[A]
    ObservationTable(
      prefixSet = Set(epsilon),
      boundarySet = Set(Seq(A.arbChar)),
      suffices = Seq(epsilon),
      rowMap = Map(
        epsilon -> Seq(O.membershipQuery(epsilon)),
        Seq(A.arbChar) -> Seq(O.membershipQuery(Seq(A.arbChar)))
      )
    )
/** Entry point of the Λ* algorithm. */
object Learner:

  /** Infers an SFA from the given alphabet and oracle using the Λ* algorithm.
    *
    * Starting from the empty observation table, each round applies the first applicable
    * repair (close, then evidence-close, then make-consistent-and-distribute). When none
    * applies, the hypothesis built from the table is sent to the oracle, and a returned
    * counterexample is processed into the table. Every step is traced on standard output.
    *
    * @return the first hypothesis for which the equivalence query reports no counterexample
    */
  def learn[A, P](
      alphabet: Alphabet[A],
      oracle: Oracle[A],
  )(using P: BoolAlg[A, P]): Sfa[Prefix[A], P] =
    given Alphabet[A] = alphabet
    given Oracle[A] = oracle

    // One call corresponds to one round; the recursion carries the current table.
    @scala.annotation.tailrec
    def refine(obs: ObservationTable[A]): Sfa[Prefix[A], P] =
      println(s"obs = $obs")
      obs.findUnclosedBoundary() match
        case Some(boundary) =>
          println(s"close($boundary)")
          refine(obs.close(boundary))
        case None =>
          val words = obs.findEvidenceUnclosedWord()
          if words.nonEmpty then
            // All words found in this round are evidence-closed before the next round.
            val evidenceClosed = words.foldLeft(obs): (table, word) =>
              println(s"evidenceClose($word)")
              table.evidenceClose(word)
            refine(evidenceClosed)
          else
            obs.findInconsistentSuffix() match
              case Some(suffix) =>
                println(s"makeConsistentAndDistribute($suffix)")
                refine(obs.makeConsistentAndDistribute(suffix))
              case None =>
                val dfa = obs.buildEvidence()
                val sfa = dfa.toSfa[P]
                println(s"EQ($sfa)")
                oracle.equivalenceQuery(sfa) match
                  case Some(cex) =>
                    println(s"process($cex)")
                    refine(obs.process(cex))
                  case None =>
                    sfa

    refine(ObservationTable.empty[A])
// Demo: learn the language over {0, 1, 2, 3} of the words in which the number of 0s is a
// multiple of three and the number of 1s is even, then print the learned SFA.
val alphabet = Alphabet(0)
val oracle = Oracle.fromFunction(Set(0, 1, 2, 3)): word =>
  val zeroCountOk = word.count(_ == 0) % 3 == 0
  zeroCountOk && word.count(_ == 1) % 2 == 0
val sfa = Learner.learn(alphabet, oracle)(using Pred.EqualityAlgebra[Int])
println(sfa)
// (GitHub page footer captured with the gist: "Sign up for free to join this conversation
// on GitHub. Already have an account? Sign in to comment".)