Last active
September 3, 2025 19:52
-
-
Save Frityet/fbfaa425e25392bdf8a44a7254d5e450 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- -*- mode: lua -*- | |
-- Macro: expand a declarative state-machine spec into typed Teal code.
--
-- Generated declarations (names shown for the default templates):
--   * `local interface <Name>State` with a `state: string` field;
--   * one `local type <state type> = record is <Name>State` per state, whose
--     `__is` metamethod compares the runtime `state` field with the state name;
--   * per state, a prototype table `<state type>__proto` and a metatable
--     `<state type>__mt` (`__index` = prototype, `__tostring` = state field);
--   * `local <Name> = {}` with a constructor returning the initial state;
--   * one transition function per (state, event) pair:
--       style "module"  -> `<Name>.<event_name_tpl>(s)`;
--       style "methods" -> `<event>` on both the state's prototype (reached at
--                          runtime through `__index`) and its record type.
--
-- Arguments:
--   name  identifier naming the generated module table.
--   spec  table literal `{ states = { S = { on = { ev = "T", ... } }, ... },
--         initial = "S" }`.
--   ...   optional table literal of string options:
--           style          "module" (default) or "methods"
--           constructor    constructor field name (default "new")
--           state_name_tpl state type name; %M = machine, %S = state
--                          (default "%M_%S")
--           event_name_tpl module-style function name; %M, %S, and %E = event
--                          (default "%E_from_%S")
--
-- Any malformed input is reported through `error` while the macro runs.
local macro state_machine!(name: Expression, spec: Expression, ...: Expression): Statement
   local err = error
   local blockmod = require("teal.block")
   local BIDX = blockmod.BLOCK_INDEXES

   -- Lua reserved words parse as keywords, so they can never be emitted as
   -- declared names or as `a.b` field names.
   local RESERVED: {string:boolean} = {
      ["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
      ["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
      ["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
      ["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
      ["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
      ["until"] = true, ["while"] = true,
   }

   -- Option keys accepted in the third argument.
   local KNOWN_OPTS: {string:boolean} = {
      style = true, constructor = true, state_name_tpl = true, event_name_tpl = true,
   }

   -- Some nodes arrive wrapped in a one-element "statements" block; peel that
   -- wrapper off so every check below sees the underlying expression node.
   local function unwrap(n: any): any
      if n and n.kind == "statements" and #n == 1 then
         return n[1]
      end
      return n
   end

   local function expect_ident(n: any): any
      if not n or (n.kind ~= "identifier" and n.kind ~= "variable") then
         err("first argument must be an identifier for the machine name")
      end
      return n
   end

   -- Fresh identifier node with token `s`.
   local function ident(s: string): any
      local id = block("identifier")
      id.tk = s
      return id
   end

   -- True when `str` has identifier shape ([A-Za-z_][A-Za-z0-9_]*). Uses byte
   -- ranges in a pattern rather than `<=`/`>=` on strings, which go through
   -- the C locale's strcoll and may admit non-ASCII letters.
   local function is_ident(str: string): boolean
      return str:match("^[A-Za-z_][A-Za-z0-9_]*$") ~= nil
   end

   -- Remove one pair of matching '"' or "'" delimiters from a raw token.
   local function strip_quotes(s: string): string
      local q = s:sub(1, 1)
      if #s >= 2 and (q == '"' or q == "'") and s:sub(-1) == q then
         return s:sub(2, -2)
      end
      return s
   end

   -- String key of a `key = value` table item.
   local function keyname(item: any): string
      local k = item[BIDX.LITERAL_TABLE_ITEM.KEY]
      if k and k.conststr and k.conststr ~= '' then
         return k.conststr
      end
      if k and k.tk then
         return strip_quotes(k.tk)
      end
      if item.tk and item.tk ~= '' then
         return item.tk
      end
      err("invalid table key; expected a string key")
   end

   -- Value node of a table item (untyped or typed form), unwrapped.
   local function item_value(item: any): any
      return unwrap(item[BIDX.LITERAL_TABLE_ITEM.VALUE] or item[BIDX.LITERAL_TABLE_ITEM.TYPED_VALUE])
   end

   -- Contents of a string literal node; errors on anything else.
   local function val_string(n: any): string
      if n and n.conststr then return n.conststr end
      if n and n.kind == "string" and n.tk then
         return strip_quotes(n.tk)
      end
      err("expected a string literal value")
   end

   -- Double-quoted string literal node for `s` (callers pass identifiers, so
   -- no escaping is required).
   local function str_lit(s: string): any
      local n = block("string")
      n.tk = '"' .. s .. '"'
      n.conststr = s
      n.is_longstring = false
      return n
   end

   ---------------------------------------------------------------- arguments
   name = expect_ident(unwrap(name))
   local mname: string = name.tk
   spec = unwrap(spec)
   if not spec or spec.kind ~= "literal_table" then
      err("second argument must be a table literal with fields 'states' and 'initial'")
   end

   local opts: any = nil
   if select('#', ...) >= 1 then
      local first: any = select(1, ...)
      opts = unwrap(first)
      if opts and opts.kind ~= "literal_table" then
         err("third argument, if present, must be a table literal of options")
      end
   end

   -- Parse options once, rejecting unknown or repeated keys so typos surface.
   local opt_values: {string:string} = {}
   if opts then
      for _, it in ipairs(opts) do
         if not it or it.kind ~= "literal_table_item" then
            err("malformed options: expected key/value pairs in options table")
         end
         local k = keyname(it)
         if not KNOWN_OPTS[k] then
            err("unknown option '" .. k .. "' (expected 'style', 'constructor', 'state_name_tpl' or 'event_name_tpl')")
         end
         if opt_values[k] ~= nil then
            err("duplicate option '" .. k .. "'")
         end
         opt_values[k] = val_string(item_value(it))
      end
   end

   local function get_opt_str(key: string, def: string): string
      local v = opt_values[key]
      if v == nil then return def end
      return v
   end

   --------------------------------------------------------------------- spec
   local states: {string:boolean} = {}
   local order: {string} = {}
   local transitions: {string:{string:string}} = {} -- state -> { event -> target_state }
   -- Events of each state in declaration order. Everything below iterates
   -- this list instead of pairs(), whose order over string keys is
   -- unspecified, so errors and generated code are reproducible.
   local event_order: {string:{string}} = {}
   local initial: string = ""

   -- Record the `on = { event = "Target", ... }` map of state `sname`.
   local function parse_on(sname: string, onmap: any)
      if not onmap or onmap.kind ~= "literal_table" then
         err("'on' must be a table mapping events to target states for '" .. sname .. "'")
      end
      for _, ev in ipairs(onmap) do
         if ev.kind ~= "literal_table_item" then
            err("malformed 'on' entry in state '" .. sname .. "'")
         end
         local evname = keyname(ev)
         if not is_ident(evname) then
            err("event name '" .. evname .. "' is not a valid identifier")
         end
         if transitions[sname][evname] then
            err("duplicate event '" .. evname .. "' in state '" .. sname .. "'")
         end
         transitions[sname][evname] = val_string(item_value(ev))
         table.insert(event_order[sname], evname)
      end
   end

   -- Register every entry of the `states = { ... }` table.
   local function parse_states(st: any)
      if not st or st.kind ~= "literal_table" then
         err("'states' must be a table")
      end
      for _, sit in ipairs(st) do
         if not sit or sit.kind ~= "literal_table_item" then
            err("malformed 'states' entry")
         end
         local sname = keyname(sit)
         if not is_ident(sname) then
            err("state name '" .. sname .. "' is not a valid identifier")
         end
         if states[sname] then
            err("duplicate state '" .. sname .. "'")
         end
         states[sname] = true
         table.insert(order, sname)
         transitions[sname] = {}
         event_order[sname] = {}
         -- A body that is not a table literal declares a state with no
         -- outgoing transitions.
         local sbody = item_value(sit)
         if sbody and sbody.kind == "literal_table" then
            for _, sb in ipairs(sbody) do
               if sb.kind ~= "literal_table_item" then
                  err("malformed state body for '" .. sname .. "'")
               end
               local sk = keyname(sb)
               if sk ~= "on" then
                  err("unknown key '" .. sk .. "' in state '" .. sname .. "' (expected only 'on')")
               end
               parse_on(sname, item_value(sb))
            end
         end
      end
   end

   for _, item in ipairs(spec) do
      if not item or item.kind ~= "literal_table_item" then
         err("malformed spec: expected key/value pairs in spec table")
      end
      local k = keyname(item)
      if k == "states" then
         parse_states(item_value(item))
      elseif k == "initial" then
         initial = val_string(item_value(item))
      else
         err("unknown top-level key '" .. k .. "' (expected 'states' or 'initial')")
      end
   end

   if initial == "" then
      err("missing required 'initial' state")
   end
   if not states[initial] then
      err("initial state '" .. initial .. "' is not declared in 'states'")
   end
   for _, s in ipairs(order) do
      for _, evname in ipairs(event_order[s]) do
         local target = transitions[s][evname]
         if not states[target] then
            err("transition '" .. s .. "' --" .. evname .. "--> '" .. target .. "' references unknown state")
         end
      end
   end

   ------------------------------------------------------------------ options
   local state_name_tpl = get_opt_str("state_name_tpl", "%M_%S")
   local constructor_name = get_opt_str("constructor", "new")
   local style = get_opt_str("style", "module") -- "module" | "methods"
   local event_name_tpl = get_opt_str("event_name_tpl", "%E_from_%S")
   if style ~= "module" and style ~= "methods" then
      err("option 'style' must be \"module\" or \"methods\", got '" .. style .. "'")
   end

   -- The outer parentheses keep only gsub's string result, dropping the match
   -- count that would otherwise become a second return value.
   local function render_state_name(sname: string): string
      return (state_name_tpl:gsub("%%M", mname):gsub("%%S", sname))
   end

   local function render_event_name(ev: string, sname: string): string
      return (event_name_tpl:gsub("%%M", mname):gsub("%%S", sname):gsub("%%E", ev))
   end

   -- Reserve `id` in namespace `seen`. Errors if it is not an emittable
   -- identifier or was already reserved (which would otherwise shadow or
   -- redefine an earlier declaration in the expansion). Returns `id`.
   local function claim(seen: {string:string}, id: string, what: string): string
      if not is_ident(id) or RESERVED[id] then
         err("generated " .. what .. " name '" .. id .. "' is not a valid identifier")
      end
      local prev = seen[id]
      if prev then
         err("generated " .. what .. " name '" .. id .. "' collides with " .. prev)
      end
      seen[id] = what
      return id
   end

   -- Names declared as locals at the call site.
   local top_names: {string:string} = {}
   claim(top_names, mname, "machine")
   local iface_name = claim(top_names, mname .. "State", "state interface")
   local type_names: {string:string} = {}
   for _, sname in ipairs(order) do
      local tname = claim(top_names, render_state_name(sname), "state type")
      claim(top_names, tname .. "__proto", "state prototype")
      claim(top_names, tname .. "__mt", "state metatable")
      type_names[sname] = tname
   end

   -- Fields assigned on the module table.
   local module_fields: {string:string} = {}
   local ctor_name = claim(module_fields, constructor_name, "constructor")
   local fn_names: {string:{string:string}} = {}
   if style == "methods" then
      -- Events become `proto.<event>` / `Type.<event>`; `state` would clash
      -- with the record's `state: string` field.
      for _, sname in ipairs(order) do
         for _, evname in ipairs(event_order[sname]) do
            if RESERVED[evname] or evname == "state" then
               err("event '" .. evname .. "' in state '" .. sname .. "' cannot be used as a method name")
            end
         end
      end
   else
      for _, sname in ipairs(order) do
         fn_names[sname] = {}
         for _, evname in ipairs(event_order[sname]) do
            fn_names[sname][evname] = claim(module_fields, render_event_name(evname, sname), "transition function")
         end
      end
   end

   ---------------------------------------------------------------- emission
   local out = block("statements")
   local M = ident(mname)
   local iface_ident = ident(iface_name)

   table.insert(out, ```
      local interface $iface_ident
         state: string
      end
   ```)

   for _, sname in ipairs(order) do
      local state_ident = ident(type_names[sname])
      local s_lit = str_lit(sname)
      table.insert(out, ```
         local type $state_ident = record is $iface_ident
            state: string
            metamethod __is: function(self: $iface_ident): boolean = macroexp(_self: $iface_ident): boolean
               return _self.state == $s_lit
            end
         end
      ```)
   end

   for _, sname in ipairs(order) do
      local proto_ident = ident(type_names[sname] .. "__proto")
      local mt_ident = ident(type_names[sname] .. "__mt")
      table.insert(out, ```local $proto_ident = {}```)
      table.insert(out, ```
         local $mt_ident = {
            __index = $proto_ident,
            __tostring = function(self: self): string return self.state end
         }
      ```)
   end

   local init_type = ident(type_names[initial])
   local init_mt_ident = ident(type_names[initial] .. "__mt")
   local ctor_ident = ident(ctor_name)
   local init_str = str_lit(initial)
   table.insert(out, ```local $M = {}```)
   table.insert(out, ```
      function $M.$ctor_ident(): $init_type
         return setmetatable({ state = $init_str }, $init_mt_ident) as $init_type
      end
   ```)

   for _, sname in ipairs(order) do
      local state_ident = ident(type_names[sname])
      local proto_ident = ident(type_names[sname] .. "__proto")
      for _, evname in ipairs(event_order[sname]) do
         local target = transitions[sname][evname]
         local target_ident = ident(type_names[target])
         local target_mt_ident = ident(type_names[target] .. "__mt")
         local target_str = str_lit(target)
         if style == "methods" then
            local ev_ident = ident(evname)
            table.insert(out, ```
               function $proto_ident.$ev_ident(_self: $state_ident): $target_ident
                  return setmetatable({ state = $target_str }, $target_mt_ident) as $target_ident
               end
            ```)
            table.insert(out, ```
               function $state_ident.$ev_ident(_self: $state_ident): $target_ident
                  return setmetatable({ state = $target_str }, $target_mt_ident) as $target_ident
               end
            ```)
         else
            local fname = ident(fn_names[sname][evname])
            table.insert(out, ```
               function $M.$fname(__s: $state_ident): $target_ident
                  return setmetatable({ state = $target_str }, $target_mt_ident) as $target_ident
               end
            ```)
         end
      end
   end

   return out
end
-- Example: a three-state traffic light with method-style transitions.
-- `state_name_tpl = "%S"` names the generated state types Red/Green/Yellow,
-- and `constructor = "make"` exposes the constructor as TrafficLight.make().
state_machine!(TrafficLight, {
   states = {
      Red    = { on = { timer = "Green" } },
      Green  = { on = { timer = "Yellow" } },
      Yellow = { on = { timer = "Red" } },
   },
   initial = "Red",
}, {
   style = "methods",
   constructor = "make",
   state_name_tpl = "%S",
})

-- Union of every generated state type.
local type Light = Red | Green | Yellow

local s0: Light = TrafficLight.make()

-- Advance the light by a random number (1..10) of timer ticks.
math.randomseed(os.time())
do
   local ticks = math.random(1, 10)
   for _ = 1, ticks do
      s0 = s0:timer()
   end
end

print(s0)

-- Choose the advice for the current state with `is`, then print it.
local advice: string
if s0 is Red then
   advice = "Stop!"
elseif s0 is Green then
   advice = "Go!"
elseif s0 is Yellow then
   advice = "Slow down!"
end
if advice then
   print(advice)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment