-- Source: GitHub gist by @Frityet, fbfaa425e25392bdf8a44a7254d5e450
-- (https://gist.github.com/Frityet/fbfaa425e25392bdf8a44a7254d5e450),
-- last active September 3, 2025 19:52.
-- -*- mode: lua -*-
--- state_machine!(Name, spec [, opts]) -- typed finite-state-machine generator.
--
-- Expands to:
--   * `local interface <Name>State` with a `state: string` field;
--   * one `local type <StateType> = record is <Name>State` per state whose
--     `__is` metamethod compares `state` with the state's name, so that
--     `x is <StateType>` can discriminate the generated types;
--   * a `<StateType>__proto` table and a `<StateType>__mt` metatable
--     (`__index` = proto, `__tostring` returns `state`) per state;
--   * `local <Name> = {}` holding a constructor that returns the initial state;
--   * one transition function per declared event (see `style` below).
--
-- spec (table literal):
--   states  = { StateName = { on = { event = "TargetState", ... } }, ... }
--   initial = "StateName"
-- opts (optional table literal; every value is a string literal):
--   style          "module" (default): functions `<Name>.<event_name_tpl>(s)`;
--                  "methods": `s:event()` through the per-state prototypes.
--   constructor    constructor name on <Name> (default "new").
--   state_name_tpl state type name; %M = machine name, %S = state name
--                  (default "%M_%S").
--   event_name_tpl module-style function name; %M, %S (source state) and
--                  %E (event) are substituted (default "%E_from_%S").
--
-- Every problem with the arguments is reported through error() while the
-- macro expands. Generated declarations follow the order of the spec.
local macro state_machine!(name: Expression, spec: Expression, ...: Expression): Statement
   local err = error
   local blockmod = require("teal.block")
   local BIDX = blockmod.BLOCK_INDEXES

   -- Lua reserved words: they match the identifier pattern but cannot be
   -- used as names in generated code.
   local RESERVED: {string:boolean} = {
      ["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
      ["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
      ["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
      ["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
      ["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
      ["until"] = true, ["while"] = true,
   }

   -- Option keys accepted in the third argument.
   local KNOWN_OPTS: {string:boolean} = {
      style = true, constructor = true, state_name_tpl = true, event_name_tpl = true,
   }

   -- Reject anything but an identifier (or variable) node as the machine name.
   local function expect_ident(n: any): any
      if not n or (n.kind ~= "identifier" and n.kind ~= "variable") then
         err("first argument must be an identifier for the machine name")
      end
      return n
   end

   -- Build a fresh identifier node whose token text is `s`.
   local function ident(s: string): any
      local id = block("identifier")
      id.tk = s
      return id
   end

   -- True when `str` is an ASCII identifier ([A-Za-z_][A-Za-z0-9_]*) that is
   -- not a reserved word, i.e. safe to emit as a name in generated code.
   local function is_ident(str: string): boolean
      return str:match("^[A-Za-z_][A-Za-z0-9_]*$") ~= nil and not RESERVED[str]
   end

   -- Value node of a `key = value` table item (plain or typed value slot).
   local function item_value(item: any): any
      return item[BIDX.LITERAL_TABLE_ITEM.VALUE] or item[BIDX.LITERAL_TABLE_ITEM.TYPED_VALUE]
   end

   -- Unwrap a single-element "statements" block around a value. Table
   -- literals are accepted in that wrapped form as well.
   -- NOTE(review): confirm which argument positions the parser wraps.
   local function unwrap(n: any): any
      if n and n.kind == "statements" and #n == 1 then
         return n[1]
      end
      return n
   end

   -- Strip one pair of matching " or ' quotes from a raw string token.
   -- Escape sequences are not decoded, which is why callers prefer a node's
   -- `conststr` when present.
   local function unquote(s: string): string
      if #s >= 2 then
         local q = s:sub(1, 1)
         if (q == '"' or q == "'") and s:sub(-1) == q then
            return s:sub(2, -2)
         end
      end
      return s
   end

   -- String key of a `key = value` table item (`name = ...` or `["name"] = ...`).
   local function keyname(item: any): string
      local k = item[BIDX.LITERAL_TABLE_ITEM.KEY]
      if k and k.conststr and k.conststr ~= '' then
         return k.conststr
      end
      if k and k.tk then
         return unquote(k.tk)
      end
      if item.tk and item.tk ~= '' then
         return item.tk
      end
      err("invalid table key; expected a string key")
   end

   -- Contents of a string literal node; errors for any other node.
   local function val_string(n: any): string
      if n and n.conststr then return n.conststr end
      if n and n.kind == "string" and n.tk then
         return unquote(n.tk)
      end
      err("expected a string literal value")
   end

   -- Build a double-quoted string literal node. It is only called with
   -- validated identifiers, so no escaping is needed.
   local function str_lit(s: string): any
      local n = block("string")
      n.tk = '"' .. s .. '"'
      n.conststr = s
      n.is_longstring = false
      return n
   end

   name = expect_ident(name)
   local mname = name.tk
   if not spec or spec.kind ~= "literal_table" then
      err("second argument must be a table literal with fields 'states' and 'initial'")
   end

   local states: {string:boolean} = {}                -- declared state names
   local order: {string} = {}                         -- states in declaration order
   local transitions: {string:{string:string}} = {}   -- state -> { event -> target_state }
   local ev_order: {string:{string}} = {}             -- state -> events in declaration order
   local initial: string = ""

   local opts: any = nil
   if select('#', ...) >= 1 then
      local first = select(1, ...)
      opts = unwrap(first)
      if opts and opts.kind ~= "literal_table" then
         err("third argument, if present, must be a table literal of options")
      end
   end
   if opts then
      -- Reject unknown keys so that a misspelt option is not silently ignored.
      for _, it in ipairs(opts) do
         if it and it.kind == "literal_table_item" then
            local k = keyname(it)
            if not KNOWN_OPTS[k] then
               err("unknown option '" .. k .. "' (expected 'style', 'constructor', 'state_name_tpl' or 'event_name_tpl')")
            end
         end
      end
   end

   -- String value of option `key`, or `def` when the option is absent.
   local function get_opt_str(key: string, def: string): string
      if not opts then return def end
      for _, it in ipairs(opts) do
         if it and it.kind == "literal_table_item" and keyname(it) == key then
            return val_string(item_value(it))
         end
      end
      return def
   end

   -- Record the `on = { event = "Target", ... }` map of state `sname`.
   local function parse_on(sname: string, onmap: any)
      onmap = unwrap(onmap)
      if not onmap or onmap.kind ~= "literal_table" then
         err("'on' must be a table mapping events to target states for '" .. sname .. "'")
      end
      for _, ev in ipairs(onmap) do
         if ev.kind ~= "literal_table_item" then
            err("malformed 'on' entry in state '" .. sname .. "'")
         end
         local evname = keyname(ev)
         if not is_ident(evname) then
            err("event name '" .. evname .. "' is not a valid identifier")
         end
         if transitions[sname][evname] then
            err("duplicate event '" .. evname .. "' in state '" .. sname .. "'")
         end
         transitions[sname][evname] = val_string(item_value(ev))
         table.insert(ev_order[sname], evname)
      end
   end

   -- Register one `StateName = { on = { ... } }` entry of the 'states' table.
   local function parse_state(sit: any)
      if not sit or sit.kind ~= "literal_table_item" then
         err("malformed 'states' entry")
      end
      local sname = keyname(sit)
      if not is_ident(sname) then
         err("state name '" .. sname .. "' is not a valid identifier")
      end
      if states[sname] then
         err("duplicate state '" .. sname .. "'")
      end
      states[sname] = true
      table.insert(order, sname)
      transitions[sname] = {}
      ev_order[sname] = {}
      -- Unwrapped like 'states' and 'on'; a non-table body is an error
      -- instead of being silently ignored.
      local sbody = unwrap(item_value(sit))
      if not sbody or sbody.kind ~= "literal_table" then
         err("body of state '" .. sname .. "' must be a table such as { on = { ... } }")
      end
      for _, sb in ipairs(sbody) do
         if sb.kind ~= "literal_table_item" then
            err("malformed state body for '" .. sname .. "'")
         end
         local sk = keyname(sb)
         if sk ~= "on" then
            err("unknown key '" .. sk .. "' in state '" .. sname .. "' (expected only 'on')")
         end
         parse_on(sname, item_value(sb))
      end
   end

   for _, item in ipairs(spec) do
      if not item or item.kind ~= "literal_table_item" then
         err("malformed spec: expected key/value pairs in spec table")
      end
      local k = keyname(item)
      if k == "states" then
         local st = unwrap(item_value(item))
         if not st or st.kind ~= "literal_table" then
            err("'states' must be a table")
         end
         for _, sit in ipairs(st) do
            parse_state(sit)
         end
      elseif k == "initial" then
         initial = val_string(item_value(item))
      else
         err("unknown top-level key '" .. k .. "' (expected 'states' or 'initial')")
      end
   end

   if #order == 0 then
      err("'states' must declare at least one state")
   end
   if initial == "" then
      err("missing required 'initial' state")
   end
   if not states[initial] then
      err("initial state '" .. initial .. "' is not declared in 'states'")
   end
   -- Walk states and events in declaration order (not `pairs` order) so the
   -- first reported problem is deterministic.
   for _, s in ipairs(order) do
      for _, evname in ipairs(ev_order[s]) do
         local target = transitions[s][evname]
         if not states[target] then
            err("transition '" .. s .. "' --" .. evname .. "--> '" .. target .. "' references unknown state")
         end
      end
   end

   local state_name_tpl = get_opt_str("state_name_tpl", "%M_%S")
   local constructor_name = get_opt_str("constructor", "new")
   local style = get_opt_str("style", "module") -- "module" | "methods"
   local event_name_tpl = get_opt_str("event_name_tpl", "%E_from_%S")
   if style ~= "module" and style ~= "methods" then
      err("option 'style' must be 'module' or 'methods', got '" .. style .. "'")
   end
   if not is_ident(constructor_name) then
      err("option 'constructor' ('" .. constructor_name .. "') is not a valid identifier")
   end

   -- Expand the name templates. The substituted values are validated
   -- identifiers, so they never contain a '%' that gsub would interpret.
   -- Assigning to a local keeps only gsub's string result, not its count.
   local function render_state_name(sname: string): string
      local s = state_name_tpl:gsub("%%M", mname):gsub("%%S", sname)
      return s
   end
   local function render_event_name(ev: string, sname: string): string
      local s = event_name_tpl:gsub("%%M", mname):gsub("%%S", sname):gsub("%%E", ev)
      return s
   end

   -- Every top-level name the expansion declares. A template or state name
   -- that would yield an invalid or duplicate name fails here with a clear
   -- message rather than as malformed generated code.
   local taken: {string:boolean} = {}
   local function claim(nm: string, what: string)
      if not is_ident(nm) then
         err("generated " .. what .. " name '" .. nm .. "' is not a valid identifier; check the name templates")
      end
      if taken[nm] then
         err("generated " .. what .. " name '" .. nm .. "' collides with another generated name")
      end
      taken[nm] = true
   end

   local M = ident(mname)
   local iface_name = mname .. "State"
   claim(mname, "machine")
   claim(iface_name, "interface")

   local out = block("statements")
   local iface_ident = ident(iface_name)
   table.insert(out, ```
      local interface $iface_ident
         state: string
      end
   ```)

   -- One record type per state; `__is` lets `x is <StateType>` test `state`.
   for _, sname in ipairs(order) do
      local tname = render_state_name(sname)
      claim(tname, "state type")
      local state_ident = ident(tname)
      local s_lit = str_lit(sname)
      table.insert(out, ```
         local type $state_ident = record is $iface_ident
            state: string
            metamethod __is: function(self: $iface_ident): boolean = macroexp(_self: $iface_ident): boolean
               return _self.state == $s_lit
            end
         end
      ```)
   end

   -- Runtime prototype and metatable per state.
   for _, sname in ipairs(order) do
      local tname = render_state_name(sname)
      claim(tname .. "__proto", "prototype")
      claim(tname .. "__mt", "metatable")
      local proto_ident = ident(tname .. "__proto")
      local mt_ident = ident(tname .. "__mt")
      table.insert(out, ```local $proto_ident = {}```)
      table.insert(out, ```
         local $mt_ident = {
            __index = $proto_ident,
            __tostring = function(self: self): string return self.state end
         }
      ```)
   end

   local init_type = ident(render_state_name(initial))
   table.insert(out, ```local $M = {}```)
   local ctor_ident = ident(constructor_name)
   local init_str = str_lit(initial)
   local init_mt_ident = ident(render_state_name(initial) .. "__mt")
   table.insert(out, ```
      function $M.$ctor_ident(): $init_type
         return setmetatable({ state = $init_str }, $init_mt_ident) as $init_type
      end
   ```)

   if style == "methods" then
      for _, sname in ipairs(order) do
         local state_ident = ident(render_state_name(sname))
         local proto_ident = ident(render_state_name(sname) .. "__proto")
         for _, evname in ipairs(ev_order[sname]) do
            -- Instances carry their own `state` field, which would shadow a
            -- prototype method of the same name at runtime.
            if evname == "state" then
               err("event 'state' in state '" .. sname .. "' clashes with the 'state' field; rename it or use style 'module'")
            end
            local target = transitions[sname][evname]
            local target_ident = ident(render_state_name(target))
            local target_mt_ident = ident(render_state_name(target) .. "__mt")
            local ev_ident = ident(evname)
            local target_str = str_lit(target)
            -- Runtime implementation, reached through the metatable's __index.
            table.insert(out, ```
               function $proto_ident.$ev_ident(_self: $state_ident): $target_ident
                  return setmetatable({ state = $target_str }, $target_mt_ident) as $target_ident
               end
            ```)
            -- Declaration on the record type so that `s:event()` type-checks.
            table.insert(out, ```
               function $state_ident.$ev_ident(_self: $state_ident): $target_ident
                  return setmetatable({ state = $target_str }, $target_mt_ident) as $target_ident
               end
            ```)
         end
      end
   else
      -- Function names already defined on the machine table.
      local mfuncs: {string:boolean} = { [constructor_name] = true }
      for _, sname in ipairs(order) do
         local state_ident = ident(render_state_name(sname))
         for _, evname in ipairs(ev_order[sname]) do
            local fname_str = render_event_name(evname, sname)
            if not is_ident(fname_str) then
               err("generated function name '" .. fname_str .. "' is not a valid identifier; check 'event_name_tpl'")
            end
            if mfuncs[fname_str] then
               err("generated function name '" .. mname .. "." .. fname_str .. "' is defined twice; include %S and %E in 'event_name_tpl'")
            end
            mfuncs[fname_str] = true
            local target = transitions[sname][evname]
            local target_ident = ident(render_state_name(target))
            local target_mt_ident = ident(render_state_name(target) .. "__mt")
            local fname = ident(fname_str)
            local target_str_b = str_lit(target)
            table.insert(out, ```
               function $M.$fname(__s: $state_ident): $target_ident
                  return setmetatable({ state = $target_str_b }, $target_mt_ident) as $target_ident
               end
            ```)
         end
      end
   end
   return out
end
-- Demo: a three-state traffic light. With style "methods" transitions are
-- called as methods (s:timer()), "make" names the constructor, and
-- state_name_tpl "%S" makes the state types Red / Green / Yellow.
state_machine!(TrafficLight, {
   states = {
      Red = { on = { timer = "Green" } },
      Green = { on = { timer = "Yellow" } },
      Yellow = { on = { timer = "Red" } },
   },
   initial = "Red",
}, {
   state_name_tpl = "%S",
   constructor = "make",
   style = "methods",
})

local type Light = Red | Green | Yellow

-- Start in the initial state, then fire the timer a random 1..10 times.
local s0: Light = TrafficLight.make()
math.randomseed(os.time())
local steps = math.random(1, 10)
for _ = 1, steps do
   s0 = s0:timer()
end
print(s0)

-- Choose the advice through `is`, which dispatches on the generated __is.
local advice: string
if s0 is Red then
   advice = "Stop!"
elseif s0 is Green then
   advice = "Go!"
elseif s0 is Yellow then
   advice = "Slow down!"
end
if advice then
   print(advice)
end
-- Comments on this gist are on GitHub (sign in to join the conversation).