Skip to content

Instantly share code, notes, and snippets.

@makenowjust
Last active February 20, 2025 05:51
Show Gist options
  • Save makenowjust/bfca1fc504e4780a06f4fb3ab2d710ca to your computer and use it in GitHub Desktop.
Save makenowjust/bfca1fc504e4780a06f4fb3ab2d710ca to your computer and use it in GitHub Desktop.
// This is an implementation of the Λ* algorithm in Scala 3.
//
// The Λ* algorithm is a learning algorithm for symbolic automata, proposed by
// Samuel Drews and Loris D'Antoni (2017), "Learning Symbolic Automata"
// https://doi.org/10.1007/978-3-662-54577-5_10.
import scala.collection.mutable
/** An effective Boolean algebra over the domain `D`.
  *
  * Predicates are values of type `P`; the algebra supplies the two constants, the three
  * connectives, a membership test, and the partitioning function used by Λ*.
  */
trait BoolAlg[D, P]:
  // ---- constants ----------------------------------------------------------

  /** The predicate whose denotation is the whole domain. */
  def `true`: P

  /** The predicate whose denotation is empty. */
  def `false`: P

  // ---- connectives --------------------------------------------------------

  /** Conjunction: `p ∧ q`. */
  def and(p: P, q: P): P

  /** Disjunction: `p ∨ q`. */
  def or(p: P, q: P): P

  /** Negation: `¬p`. */
  def not(p: P): P

  // ---- queries ------------------------------------------------------------

  /** Returns `true` iff `d` lies in the denotation of `p`. */
  def contains(p: P, d: D): Boolean

  /** The partitioning function: returns the separating predicates of `ds`, one per set,
    * in the same order as `ds`.
    */
  def partition(ds: Seq[Set[D]]): Seq[P]
/** A syntax tree of predicates whose leaves are atomic propositions of type `A`.
  *
  * The smart connectives `and`, `or` and `not` fold away the constants `True`/`False`
  * (and double negation) instead of always building a new node.
  */
enum Pred[+A]:
  case True, False
  case Atom(a: A)
  case And(p: Pred[A], q: Pred[A])
  case Or(p: Pred[A], q: Pred[A])
  case Not(p: Pred[A])

  /** Conjunction, simplified when either side is a constant. */
  infix def and[AA >: A](that: Pred[AA]): Pred[AA] = this match
    case True  => that
    case False => False
    case _ =>
      that match
        case True  => this
        case False => False
        case _     => And(this, that)

  /** Disjunction, simplified when either side is a constant. */
  infix def or[AA >: A](that: Pred[AA]): Pred[AA] = this match
    case True  => True
    case False => that
    case _ =>
      that match
        case True  => True
        case False => this
        case _     => Or(this, that)

  /** Negation, folding constants and cancelling a double negation. */
  def not: Pred[A] = this match
    case Not(inner) => inner
    case True       => False
    case False      => True
    case other      => Not(other)

  /** Returns `true` iff `d` is in the denotation of this predicate.
    *
    * @param d     the value to test
    * @param atom  decides whether a single atomic proposition holds for `d`
    */
  def contains[D](d: D)(atom: (A, D) => Boolean): Boolean =
    def holds(pred: Pred[A]): Boolean = pred match
      case True        => true
      case False       => false
      case Atom(a)     => atom(a, d)
      case And(lhs, rhs) => holds(lhs) && holds(rhs)
      case Or(lhs, rhs)  => holds(lhs) || holds(rhs)
      case Not(inner)  => !holds(inner)
    holds(this)
object Pred:
  /** `EqualityAlgebra` is the `BoolAlg` instance whose atoms are equality tests:
    * `Atom(a)` contains exactly the value `a`.
    */
  given EqualityAlgebra[A]: BoolAlg[A, Pred[A]] with
    def `true`: Pred[A] = Pred.True
    def `false`: Pred[A] = Pred.False
    def and(p: Pred[A], q: Pred[A]): Pred[A] = Pred.And(p, q)
    def or(p: Pred[A], q: Pred[A]): Pred[A] = Pred.Or(p, q)
    def not(p: Pred[A]): Pred[A] = Pred.Not(p)
    def contains(p: Pred[A], d: A): Boolean = p.contains(d)(_ == _)

    /** Returns one predicate per set in `dss`, in the same order.
      *
      * Every set except the largest gets the disjunction of its elements. The largest set
      * (the first one on ties) gets the negation of the union of all the others. Assuming
      * the sets are pairwise disjoint, the i-th predicate therefore contains `dss(i)` and no
      * other set, and together the predicates cover the whole domain.
      *
      * Returns an empty sequence when `dss` is empty, instead of failing in `maxBy`.
      */
    def partition(dss: Seq[Set[A]]): Seq[Pred[A]] =
      if dss.isEmpty then Seq.empty
      else
        // Disjunction of equality atoms: contains exactly the given values.
        // The accumulator type is pinned so `Pred.False` is not inferred as `Pred[Nothing]`.
        def anyOf(values: IterableOnce[A]): Pred[A] =
          values.iterator.map(Pred.Atom(_)).foldLeft[Pred[A]](Pred.False)(_ or _)

        val maxIndex = dss.zipWithIndex.maxBy(_._1.size)._2
        val others = anyOf(
          dss.iterator.zipWithIndex.collect { case (ds, i) if i != maxIndex => ds }.flatten
        )
        dss.zipWithIndex.map:
          case (_, i) if i == maxIndex => others.not
          case (ds, _)                 => anyOf(ds)
/** A deterministic finite automaton over concrete characters `A` with states `S`.
  *
  * Used for evidence automata (Def. 3): `transitionFunction(s)(c)` is the successor of
  * state `s` on character `c`.
  */
final case class Dfa[S, A](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[A, S]]
):
  /** Generalizes this evidence automaton into an SFA.
    *
    * For every state, the characters on its outgoing edges are grouped by target state,
    * and `BoolAlg.partition` turns the groups into one guard predicate per target.
    * States, initial state and accepting states are carried over unchanged.
    */
  def toSfa[P](using P: BoolAlg[A, P]): Sfa[S, P] =
    // Replaces the concrete edges of one state by guarded edges.
    def symbolize(edges: Map[A, S]): Map[P, S] =
      val byTarget: Seq[(S, Seq[A])] =
        edges.toSeq.groupBy(_._2).toSeq.map((target, pairs) => target -> pairs.map(_._1))
      val guards: Seq[P] = P.partition(byTarget.map(_._2.toSet))
      byTarget.map(_._1).zip(guards).map((target, guard) => guard -> target).toMap

    val symbolicTransitions: Map[S, Map[P, S]] =
      transitionFunction.map { case (state, edges) => state -> symbolize(edges) }
    Sfa(initialState, acceptStateSet, symbolicTransitions)
/** `Sfa` represents a symbolic finite automaton (Def. 1).
  *
  * Edges are labelled with predicates of type `P` instead of single characters. In this
  * implementation, SFAs are assumed to be deterministic and finite.
  */
final case class Sfa[S, P](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[P, S]]
):
  /** Returns the state reached from `state` by reading `char`.
    *
    * The first edge (in the edge map's iteration order) whose guard contains `char` is
    * taken. Returns `None` when no guard contains `char`, and also when `state` has no
    * outgoing edges, i.e. it is not a key of `transitionFunction`.
    */
  def transition[A](state: S, char: A)(using P: BoolAlg[A, P]): Option[S] =
    transitionFunction.get(state).flatMap: edgeMap =>
      edgeMap.collectFirst { case (guard, next) if P.contains(guard, char) => next }
/** An alphabet of characters of type `A`, known to the learner only through one
  * representative character.
  */
trait Alphabet[A]:
  /** Returns the arbitrary character. */
  def arbChar: A

object Alphabet:
  /** Returns an `Alphabet` whose `arbChar` is always `char`. */
  def apply[A](char: A): Alphabet[A] =
    class Fixed extends Alphabet[A]:
      def arbChar: A = char
    Fixed()
/** `Oracle` represents a teacher that answers membership and equivalence queries about a
  * target language over `A`.
  */
trait Oracle[A]:
  /** Returns `true` iff `word` is in the target language. */
  def membershipQuery(word: Seq[A]): Boolean

  /** Checks whether `sfa` accepts the target language.
    *
    * @return `None` if no difference is found, otherwise `Some(w)` for a counterexample `w`
    *   on which `sfa` and the target language disagree.
    */
  def equivalenceQuery[P](sfa: Sfa[?, P])(using BoolAlg[A, P]): Option[Seq[A]]

object Oracle:
  /** Creates an `Oracle` whose membership query is `f`.
    *
    * The equivalence query is approximated by random testing:
    *
    *   - The empty word is compared against `f` first.
    *   - Then `numWords` random walks are run over `sfa`. Each walk's length is drawn
    *     uniformly from `[minWordLength, maxWordLength]`, and its characters are drawn
    *     uniformly from `finiteAlphabet`.
    *   - Every prefix of each walk is compared against `f`, so disagreements on words shorter
    *     than `minWordLength` are found too. The first disagreeing word is returned.
    *
    * The generator is reseeded with `randomSeed` on every call, so repeated equivalence
    * queries test the same words. `sfa` must have a transition for every sampled character
    * (`transition(...).get` is used).
    *
    * @param finiteAlphabet characters to sample; must be non-empty when walks have positive length
    * @param minWordLength  minimum random-walk length
    * @param maxWordLength  maximum random-walk length (inclusive); must be `>= minWordLength`
    * @param numWords       number of random walks per equivalence query
    * @param randomSeed     seed of the random generator
    * @param f              membership function of the target language
    */
  def fromFunction[A](
      finiteAlphabet: Set[A],
      minWordLength: Int = 10,
      maxWordLength: Int = 100,
      numWords: Int = 100,
      randomSeed: Long = 0L
  )(f: Seq[A] => Boolean): Oracle[A] =
    val alphabetIndexedSeq = finiteAlphabet.toIndexedSeq
    new Oracle[A]:
      def membershipQuery(word: Seq[A]): Boolean = f(word)

      def equivalenceQuery[P](sfa: Sfa[?, P])(using BoolAlg[A, P]): Option[Seq[A]] =
        val rand = util.Random(randomSeed)
        util.boundary:
          // The walks below only compare non-empty prefixes, so the empty word is checked here.
          if sfa.acceptStateSet.contains(sfa.initialState) != f(Vector.empty[A]) then
            util.boundary.break(Some(Vector.empty[A]))
          for _ <- 0 until numWords do
            val size = rand.between(minWordLength, maxWordLength + 1)
            // Vector gives effectively constant-time append, unlike the List behind `Seq.empty`.
            var word = Vector.empty[A]
            var state = sfa.initialState
            for _ <- 0 until size do
              val char = alphabetIndexedSeq(rand.nextInt(alphabetIndexedSeq.size))
              word :+= char
              state = sfa.transition(state, char).get
              if sfa.acceptStateSet.contains(state) != f(word) then
                util.boundary.break(Some(word))
          None
/** A word read from the start; prefixes label the rows of the observation table. */
type Prefix[A] = Seq[A]

/** A word appended after a prefix; suffixes label the columns of the observation table. */
type Suffix[A] = Seq[A]

/** The signature (row) of a prefix `s`.
  *
  * Entry `i` is `true` iff `s ++ suffices(i)` belongs to the target language.
  */
type Sig = Seq[Boolean]
/** `ObservationTable` represents an observation table (Def. 2).
  *
  * An observation table `(S, R, E, f)` is stored in four fields:
  *
  *   - `prefixSet` is `S`,
  *   - `boundarySet` is `R` (every operation below keeps it disjoint from `S`),
  *   - `suffices` is `E`, the suffixes, and
  *   - `rowMap` is `f`, caching the row (signature) of every word in `S ∪ R`.
  *
  * `buildEvidence` reads acceptance from the first column, so `suffices.head` must be the
  * empty word, as `ObservationTable.empty` sets it up.
  */
final case class ObservationTable[A](
    prefixSet: Set[Prefix[A]],
    boundarySet: Set[Prefix[A]],
    suffices: Seq[Suffix[A]],
    rowMap: Map[Prefix[A], Sig]
):
  /** Returns the pre-computed row (signature) of `prefix`, which must be in `S ∪ R`. */
  private def row(prefix: Prefix[A]): Sig = rowMap(prefix)

  /** Computes the signature of `prefix` over the current suffixes via membership queries. */
  private def sig(prefix: Prefix[A])(using O: Oracle[A]): Sig =
    suffices.map(suffix => O.membershipQuery(prefix ++ suffix))

  /** Returns `S ∪ R`. */
  private def prefixAndBoundarySet: Set[Prefix[A]] =
    prefixSet ++ boundarySet

  /** Returns the words of `S ∪ R` that extend `prefix` by exactly one character. */
  private def extensionSet(prefix: Prefix[A]): Set[Prefix[A]] =
    prefixAndBoundarySet.filter(w => w.size == prefix.size + 1 && w.startsWith(prefix))

  /** Adds to `R` every prefix of `word` that is not yet in `S ∪ R`, filling in its row.
    *
    * Words already in `S ∪ R` are skipped. This keeps `S` and `R` disjoint and avoids
    * repeating membership queries for rows that are already known.
    */
  private def withBoundariesFrom(word: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    val known = prefixAndBoundarySet
    val fresh = word.inits.filterNot(known).toSeq
    ObservationTable(
      prefixSet,
      boundarySet ++ fresh,
      suffices,
      rowMap ++ fresh.map(w => w -> sig(w))
    )

  /** Finds and returns the shortest unclosed boundary (one whose row matches no prefix), if any. */
  def findUnclosedBoundary(): Option[Prefix[A]] =
    val prefixRows = prefixSet.map(row)
    val candidates = boundarySet.filterNot(boundary => prefixRows.contains(row(boundary)))
    Option.when(candidates.nonEmpty)(candidates.minBy(_.size))

  /** Does the "close" operation (§3.3): moves `prefix` from `R` to `S` and adds
    * `prefix :+ arbChar` to `R`, unless that word is already in `S`.
    */
  def close(prefix: Prefix[A])(using A: Alphabet[A], O: Oracle[A]): ObservationTable[A] =
    val newPrefixSet = prefixSet + prefix
    val newBoundary = prefix :+ A.arbChar
    // `newBoundary` may already be in S (a longer word can be promoted before its prefix);
    // subtracting S keeps S and R disjoint and also removes `prefix` from R.
    val newBoundarySet = (boundarySet + newBoundary) -- newPrefixSet
    val newRowMap =
      if rowMap.contains(newBoundary) then rowMap
      else rowMap + (newBoundary -> sig(newBoundary))
    ObservationTable(newPrefixSet, newBoundarySet, suffices, newRowMap)

  /** Finds and returns the words `s ++ e` (`s` in `S`, `e` in `E`) that are missing from `S ∪ R`.
    *
    * The table is evidence-closed iff the result is empty. The result can contain the same
    * word more than once when different `(s, e)` pairs produce it.
    */
  def findEvidenceUnclosedWord(): Seq[Seq[A]] =
    val known = prefixAndBoundarySet
    (for
      prefix <- prefixSet.iterator
      suffix <- suffices.iterator
      word = prefix ++ suffix
      if !known.contains(word)
    yield word).toSeq

  /** Does the "evidence-close" operation (§3.3) with the given word.
    *
    * Returns the table unchanged if `word` is already in `S`. Otherwise every prefix of
    * `word` not yet in `S ∪ R` becomes a boundary.
    */
  def evidenceClose(word: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    if prefixSet.contains(word) then this else withBoundariesFrom(word)

  /** Finds and returns a suffix that exposes an inconsistency, if any.
    *
    * An inconsistency is a pair of words in `S ∪ R` with equal rows whose one-character
    * extensions by the same `char` (both in `S ∪ R`) have different rows. The returned
    * suffix is `char` followed by a suffix on which those extensions differ.
    */
  def findInconsistentSuffix(): Option[Suffix[A]] =
    val known = prefixAndBoundarySet
    val witness = (
      for
        prefix1 <- known.iterator
        ext1 <- extensionSet(prefix1).iterator
        char = ext1.last
        prefix2 <- known.iterator
        ext2 = prefix2 :+ char
        if known.contains(ext2)
        if row(prefix1) == row(prefix2)
        if row(ext1) != row(ext2)
      yield (ext1, ext2, char)
    ).nextOption()
    witness.map: (ext1, ext2, char) =>
      // Both rows have one entry per suffix and are unequal, so some index differs.
      val index = row(ext1).zip(row(ext2)).indexWhere((b1, b2) => b1 != b2)
      char +: suffices(index)

  /** Does the "make-consistent" and "distribute" operations (§3.3) with the given suffix.
    *
    * "distribute" adds boundaries so that words separated by `newSuffix` have matching
    * extensions. "make-consistent" then appends `newSuffix` to `E` and extends every row.
    */
  def makeConsistentAndDistribute(newSuffix: Suffix[A])(using O: Oracle[A]): ObservationTable[A] =
    val known = prefixAndBoundarySet
    // Ask each membership query for the new column once per word, not once per pair.
    val column: Map[Prefix[A], Boolean] =
      known.iterator.map(word => word -> O.membershipQuery(word ++ newSuffix)).toMap

    // "distribute": when two words share a row but `newSuffix` separates them, mirror each
    // one's one-character extensions onto the other (unless the mirrored word is in S).
    val boundarySetToAdd: Set[Prefix[A]] =
      for
        prefix1 <- known
        prefix2 <- known
        if row(prefix1) == row(prefix2) && column(prefix1) != column(prefix2)
        (from, to) <- Set(prefix1 -> prefix2, prefix2 -> prefix1)
        char <- extensionSet(from).map(_.last)
        ext = to :+ char
        if !prefixSet.contains(ext)
      yield ext
    val newBoundarySet = boundarySet ++ boundarySetToAdd

    // Rows of the newly added boundaries over the old suffixes; the new column comes next.
    val rowsOverOldSuffices =
      rowMap ++ boundarySetToAdd.iterator.filterNot(rowMap.contains).map(b => b -> sig(b))

    // "make-consistent": extend every row with its entry for `newSuffix`,
    // querying only the words whose entry is not cached in `column`.
    val newRowMap = rowsOverOldSuffices.map:
      case (word, oldSig) =>
        word -> (oldSig :+ column.getOrElse(word, O.membershipQuery(word ++ newSuffix)))

    ObservationTable(prefixSet, newBoundarySet, suffices :+ newSuffix, newRowMap)

  /** Processes the given counterexample: every prefix of `cex` not yet in `S ∪ R` becomes a boundary. */
  def process(cex: Seq[A])(using O: Oracle[A]): ObservationTable[A] =
    withBoundariesFrom(cex)

  /** Builds an evidence automaton from the observation table (§3.1).
    *
    * States are the prefixes in `S`. A word in `S ∪ R` is mapped to the prefix with the same
    * row, and each one-character extension inside `S ∪ R` becomes an edge. The table must be
    * closed (every row is some prefix's row; otherwise the lookup throws). It must also be
    * consistent, so that edges do not conflict.
    */
  def buildEvidence(): Dfa[Prefix[A], A] =
    val rowToPrefix = rowMap.view.filterKeys(prefixSet).map(_.swap).toMap
    val transitionFunction = mutable.Map.empty[Prefix[A], Map[A, Prefix[A]]]
    for
      word <- prefixAndBoundarySet
      ext <- extensionSet(word)
    do
      val state = rowToPrefix(row(word))
      val target = rowToPrefix(row(ext))
      transitionFunction(state) = transitionFunction.getOrElse(state, Map.empty) + (ext.last -> target)
    val initialState = rowToPrefix(row(Seq.empty))
    // The first column is the empty suffix, so `head` tells whether the prefix is accepted.
    val acceptSet = prefixSet.map(row).filter(_.head).map(rowToPrefix)
    Dfa(initialState, acceptSet, transitionFunction.toMap)
object ObservationTable:
  /** Returns the initial observation table.
    *
    * `S = {ε}`, `R = {c}` where `c` is the alphabet's arbitrary character, and `E = {ε}`.
    * Both rows are filled in with membership queries (ε first, then `c`).
    */
  def empty[A](using A: Alphabet[A], O: Oracle[A]): ObservationTable[A] =
    val epsilon = Seq.empty[A]
    val firstBoundary = Seq(A.arbChar)
    ObservationTable(
      prefixSet = Set(epsilon),
      boundarySet = Set(firstBoundary),
      suffices = Seq(epsilon),
      rowMap = Map(
        epsilon -> Seq(O.membershipQuery(epsilon)),
        firstBoundary -> Seq(O.membershipQuery(firstBoundary))
      )
    )
/** `Learner` runs the Λ* algorithm. */
object Learner:
  /** Infers an SFA for the target language of `oracle` using the Λ* algorithm.
    *
    * Each round repairs the observation table in priority order: close, then
    * evidence-close, then make-consistent. Once no repair applies, an SFA is built and
    * submitted to the equivalence query. A counterexample is processed into the table;
    * otherwise that SFA is returned. Every step is printed to standard output.
    */
  def learn[A, P](
      alphabet: Alphabet[A],
      oracle: Oracle[A],
  )(using P: BoolAlg[A, P]): Sfa[Prefix[A], P] =
    given Alphabet[A] = alphabet
    given Oracle[A] = oracle

    // Runs one round on `obs` and recurses until the equivalence query succeeds.
    @annotation.tailrec
    def loop(obs: ObservationTable[A]): Sfa[Prefix[A], P] =
      println(s"obs = $obs")
      obs.findUnclosedBoundary() match
        case Some(boundary) =>
          println(s"close($boundary)")
          loop(obs.close(boundary))
        case None =>
          val words = obs.findEvidenceUnclosedWord()
          if words.nonEmpty then
            val evidenceClosed = words.foldLeft(obs): (table, word) =>
              println(s"evidenceClose($word)")
              table.evidenceClose(word)
            loop(evidenceClosed)
          else
            obs.findInconsistentSuffix() match
              case Some(suffix) =>
                println(s"makeConsistentAndDistribute($suffix)")
                loop(obs.makeConsistentAndDistribute(suffix))
              case None =>
                val dfa = obs.buildEvidence()
                val sfa = dfa.toSfa
                println(s"EQ($sfa)")
                oracle.equivalenceQuery(sfa) match
                  case Some(cex) =>
                    println(s"process($cex)")
                    loop(obs.process(cex))
                  case None => sfa

    loop(ObservationTable.empty)
val alphabet = Alphabet(0)
// Target language: the number of 0s is a multiple of 3 and the number of 1s is even.
val oracle = Oracle.fromFunction(finiteAlphabet = Set(0, 1, 2, 3)): word =>
  val zeros = word.count(_ == 0)
  val ones = word.count(_ == 1)
  zeros % 3 == 0 && ones % 2 == 0
val sfa = Learner.learn(alphabet, oracle)(using Pred.EqualityAlgebra[Int])
println(sfa)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment