Skip to content

Instantly share code, notes, and snippets.

@oantolin
Created January 17, 2021 17:46
Show Gist options
  • Save oantolin/610d0f158f2f8d48eb78e2e1589ea916 to your computer and use it in GitHub Desktop.
Save oantolin/610d0f158f2f8d48eb78e2e1589ea916 to your computer and use it in GitHub Desktop.
Lisp dialect that compiles to Lua: no documentation, no error reporting
-- Global namespace for the reader and compiler. It is global (not local)
-- because code produced by `load` -- compiled macros and quasiquote output
-- such as `syntax.vector(...)` -- refers to it by name.
syntax = {}

-- Register a bracket kind. `syntax.<kind>(x)` tags table `x` with a private,
-- empty marker metatable (identity only, no behaviour) and returns it;
-- `syntax.<kind>_p(x)` reports whether `x` carries that tag.
do
  local function define_kind(kind)
    local marker = {}
    syntax[kind] = function (x)
      return setmetatable(x, marker)
    end
    syntax[kind .. "_p"] = function (x)
      return type(x) == "table" and getmetatable(x) == marker
    end
  end

  define_kind("vector")
  define_kind("table")
end
--- A form: a plain table, i.e. neither a vector nor a table literal.
function syntax.form_p(x)
  if type(x) ~= "table" then return false end
  return not (syntax.vector_p(x) or syntax.table_p(x))
end

--- An atom: any non-table value (the reader produces numbers and strings).
function syntax.atom_p(x) return not (type(x) == "table") end

--- A negative number.
function syntax.neg_p(x)
  if type(x) ~= "number" then return false end
  return x < 0
end

--- A label `::name::`; returns the matched string, or false/nil otherwise.
function syntax.label_p(x)
  if type(x) ~= "string" then return false end
  return x:match("^::.*::$")
end

-- Build a predicate for forms whose head is a key of `syntax[table_name]`.
-- The table is looked up on every call, so it may be defined later in the
-- file (as `syntax.operators` and the special-form tables are).
local function head_in(table_name)
  return function (x)
    if not syntax.form_p(x) then return false end
    return syntax[table_name][x[1]]
  end
end

syntax.operexpr_p = head_in "operators"
syntax.special_expr_p = head_in "special_expr"
syntax.special_stat_p = head_in "special_stat"
syntax.special_last_p = head_in "special_last"
syntax.macro_p = head_in "macros"
-- parsing
--- Read source text `src` into a sequence (plain table) of top-level forms.
--
-- Reader syntax implemented below:
--   * `;` starts a comment running to the end of the line;
--   * numbers become Lua numbers (via `tonumber`);
--   * identifiers become strings. Unless the identifier names an operator or
--     a special form, each of `! ? * / < > + - =` is rewritten into a
--     Lua-safe suffix by `clean` (e.g. `set-at!` -> `set_at_b`);
--   * string literals stay Lua source text, surrounding quotes included;
--   * `(...)` reads as a plain table (a form), `[...]` as a `syntax.vector`,
--     and `{k v ...}` as a `syntax.table` mapping each key to the next item;
--   * `` `x `` reads as `(quasiquote x)` and `,x` as `(unquote x)`.
--
-- Raises "missing '<closer>'" when the input ends inside an open bracket or
-- string literal (the REPL treats a parse error as "need more input").
-- Returns nil, without raising, when some item cannot be read at all.
function syntax.parse(src)
  local pos = 1
  local last = src:len()+1

  -- Advance `pos` past whitespace and `;` comments.
  local function skipws()
    while true do
      pos = src:find("[^%s]", pos) or last
      if src:match("^;", pos) then
        pos = src:find("\n", pos) or last
      else
        break
      end
    end
  end

  -- If pattern `pat` matches at `pos`, consume and return the match.
  local function at(pat)
    local m = src:match("^" .. pat, pos)
    if m then
      pos = pos + m:len()
      return m
    end
  end

  -- Read `open item* close` into a list; nil if `open` does not match here
  -- or an item cannot be read.
  local function seq(open, item, close)
    if at(open) then
      local l = {}
      while true do
        skipws()
        if at(close) then break end
        if at("$") then error(string.format("missing '%s'", close)) end
        local i = item()
        if i == nil then return end
        l[#l+1] = i
      end
      return l
    end
  end

  -- Identifier characters that are not valid in Lua names, and their
  -- replacements.
  local subs = {
    ["!"] = "_b", ["?"] = "_p", ["*"] = "_s", ["/"] = "_d",
    ["<"] = "_l", [">"] = "_r", ["+"] = "_a", ["-"] = "_",
    ["="] = "_e"
  }
  -- Operator and special-form names are kept verbatim; any other identifier
  -- has every character listed in `subs` replaced. Each key is a single
  -- character, and a one-character pattern matches that character literally.
  local function clean(id)
    local s = syntax
    if s.operators[id] or s.special_stat[id] or s.special_expr[id] then
      return id
    end
    for k, v in pairs(subs) do
      id = id:gsub(k, v)
    end
    return id
  end

  local function maybe(f, x) if x ~= nil then return f(x) end end

  local function ident() return maybe(clean, at("[^%s%(%)%[%]%{%}%\"]+")) end

  local function number()
    local start = pos
    local text = at("%-?%d+%.?%d*[eE]?%-?%d*")
    if text == nil then return end
    local n = tonumber(text)
    -- The pattern is permissive (it accepts e.g. "1-2"). If the text is not
    -- a real number, rewind so the other token parsers see the whole token
    -- instead of whatever follows it.
    if n == nil then pos = start end
    return n
  end

  -- Read a string literal, returned as Lua source text including quotes. A
  -- quote preceded by an odd number of backslashes is escaped.
  local function str()
    if at('"') then
      local p = pos
      while true do
        local q = src:find('"', p)
        -- Unterminated literal: report it like an unclosed bracket instead of
        -- returning nil with the opening quote already consumed.
        if not q then error("missing '\"'") end
        local i = 1
        while src:sub(q-i, q-i) == "\\" do i = i + 1 end
        if i % 2 == 0 then
          p = q + 1
        else
          local s = src:sub(pos-1, q)
          pos = q + 1
          return s
        end
      end
    end
  end

  -- Try each parser in order (after skipping whitespace); first hit wins.
  local function oneof(cases)
    return function ()
      skipws()
      for _, parser in ipairs(cases) do
        local m = parser()
        if m then return m end
      end
    end
  end

  local expr

  -- Reader shorthands: `x -> (quasiquote x), ,x -> (unquote x).
  local function quote()
    if at("`") then
      return {"quasiquote", expr()}
    end
    if at(",") then
      return {"unquote", expr()}
    end
  end

  -- Turn the item list `k1 v1 k2 v2 ...` into a tagged key->value table.
  local function tabulate(x)
    local r = syntax.table {}
    for i = 1, #x, 2 do
      r[x[i]] = x[i+1]
    end
    return r
  end

  expr = oneof {
    number, quote, ident, str,
    function() return seq("%(", expr, "%)") end,
    function() return maybe(syntax.vector, seq("%[", expr, "%]")) end,
    function() return maybe(tabulate, seq("%{", expr, "%}")) end
  }

  return seq("", expr, "$")
end
-- utilities
-- Lua 5.1 compatibility: 5.2+ already provide both functions.
if not table.pack then
  table.pack = function (...) return {...} end
end
if not table.unpack then
  table.unpack = unpack
end

--- The elements of sequence `t` after index `d` (default 1), as multiple
-- values.
local function rest_unpacked(t, d)
  local first = 2
  if d then first = d + 1 end
  return select(first, table.unpack(t))
end

--- Like `rest_unpacked`, but packed into a new table.
local function rest(t, d)
  return table.pack(rest_unpacked(t, d))
end

--- Split the arguments into (table of all but the last, last argument).
local function seplast(...)
  local init = table.pack(...)
  local n = #init
  local final = init[n]
  init[n] = nil
  return init, final
end
-- macros
-- `global_macros` holds the `defmacro` definitions; `macros` is the currently
-- active table, which starts out as the global one.
syntax.global_macros = {}
syntax.macros = syntax.global_macros

--- Recursively expand macros in `form`.
-- If the head of a form names an active macro, the macro is called with the
-- form's unevaluated arguments and its result is expanded again. Otherwise
-- a new form is built from the expanded sub-forms. Atoms, vectors and
-- tables are returned unchanged; this function does not look inside them.
function syntax.macroexpand(form)
  if not syntax.form_p(form) then
    return form
  end
  local head = form[1]
  if type(head) == "string" then
    local macro = syntax.macros[head]
    if macro then
      return syntax.macroexpand(macro(rest_unpacked(form)))
    end
  end
  local expanded = {}
  for i = 1, #form do
    expanded[i] = syntax.macroexpand(form[i])
  end
  return expanded
end
--- Built-in `quasiquote` macro for `(quasiquote x)`, read from `` `x ``.
-- Returns *code* which, when run, builds the datum `x`. Every
-- `(unquote e)` (read from `,e`) at nesting level 1 is replaced by `e`
-- itself, so its value is spliced in.
-- How each datum becomes construction code (see `qq`):
--   * string atoms (identifiers, and string-literal tokens, which keep their
--     quotes) become Lua string literals via `%q`; other atoms are kept;
--   * a form `(a b ...)` becomes a vector of the quoted items, which the
--     compiler emits as a Lua table constructor, i.e. a plain table (a form)
--     at run time;
--   * a vector becomes `(syntax.vector [...])` and a table
--     `(syntax.table {...})`, so the run-time value keeps its tag.
-- A nested quasiquote raises `level`. An unquote splices only at level 1;
-- otherwise it is rebuilt as an unquote form around its argument, quoted
-- at `level - 1`.
function syntax.macros.quasiquote(form)
local function qq(form, level)
if syntax.form_p(form) then
if form[1]=="unquote" then
if level == 1 then
-- Splice: the unquoted expression itself is the code.
return form[2]
else
return {"unquote", qq(form[2],level-1)}
end
elseif form[1]=="quasiquote" then
return {"quasiquote", qq(form[2],level+1)}
end
end
if syntax.atom_p(form) then
return type(form)=="string" and string.format("%q", form) or form
end
if syntax.table_p(form) then
-- Quote keys and values; the result compiles to a `syntax.table(...)` call.
local r = syntax.table {}
for k, v in pairs(form) do
r[qq(k,level)] = qq(v,level)
end
return {"syntax.table", r}
end
-- Form or vector: quote each element into a vector. A vector (compile-time)
-- compiles to a table constructor, so a quoted form yields a plain table at
-- run time; a quoted vector is additionally wrapped in `syntax.vector(...)`.
local r = syntax.vector {}
for i=1,#form do
r[i] = qq(form[i],level)
end
return syntax.vector_p(form) and {"syntax.vector", r} or r
end
return qq(form,1)
end
-- Number of identifiers handed out so far by `gensym`.
local gensymcnt = 0

--- Return a fresh identifier `<x>__gensym__<n>` (`x` defaults to "").
-- Global rather than local so that macro code, compiled and loaded as a
-- separate chunk, can call it.
function gensym(x)
  gensymcnt = gensymcnt + 1
  local prefix = x or ""
  return prefix .. "__gensym__" .. gensymcnt
end
-- compiling
local compile = {}

-- Output buffer. Generated Lua is collected as fragments in `pieces` and
-- joined by `code()`; `clear()` starts a fresh buffer.
local pieces = {}
local function emit(s)
  local n = #pieces
  pieces[n + 1] = s
end
local function clear() pieces = {} end
local function code() return table.concat(pieces) end

-- A compiler that ignores its argument and always emits the constant `s`.
local function emitter(s)
  return function (_) emit(s) end
end

--- Build a compiler for a list: `compile_each` on every element but the
-- last, `separator` between elements, and `compile_last` (default
-- `compile_each`) on the final element. Empty lists emit nothing.
local function compiler_seq(compile_each, separator, compile_last)
  local finish = compile_last or compile_each
  return function (seq)
    local n = #seq
    for i = 1, n do
      if i < n then
        compile_each(seq[i])
        emit(separator)
      else
        finish(seq[i])
      end
    end
  end
end

--- Build a dispatching compiler. `alts` maps predicate names (looked up in
-- `syntax` at call time) to compilers, plus a `default` compiler. The first
-- matching predicate in `pairs` order wins, so the predicates are expected
-- to be disjoint.
local function compiler_alt(alts)
  return function (ast)
    for predicate, handler in pairs(alts) do
      if predicate ~= "default" and syntax[predicate](ast) then
        handler(ast)
        return
      end
    end
    alts.default(ast)
  end
end
--- Vector literal `[a b c]` -> Lua table constructor `{a, b, c}`.
function compile.vector(ast)
  emit("{")
  compile.exprlist(ast)
  emit("}")
end

--- Table literal `{k1 v1 ...}` -> `{[k1] = v1, ...}`, in `pairs` order.
function compile.table(ast)
  local first = true
  emit("{")
  for key, value in pairs(ast) do
    emit(first and "[" or ", [")
    first = false
    compile.expr(key)
    emit("] = ")
    compile.expr(value)
  end
  emit("}")
end
--- Emit `expr` so that it can be used as a call/index prefix.
-- Names are emitted bare; `(at ...)` forms already produce a valid prefix;
-- anything else (string literals, numbers, other forms) is parenthesised.
function compile.paren(expr)
  local is_name = type(expr) == "string" and expr:sub(1, 1) ~= '"'
  local is_index = syntax.form_p(expr) and expr[1] == "at"
  if is_name then
    emit(expr)
    return
  end
  if is_index then
    compile.expr(expr)
    return
  end
  emit("(")
  compile.expr(expr)
  emit(")")
end
--- Compile a call form.
-- `(f a b)`         -> `f(a, b)`
-- `(.field obj a)`  -> `obj.field(a)`
-- `(:method obj a)` -> `obj:method(a)`
function compile.call(ast)
  local head = ast[1]
  local is_member = type(head) == "string" and head:match("^[.:]")
  if is_member then
    compile.paren(ast[2])
    emit(head .. "(")
    compile.exprlist(rest(ast, 2))
  else
    compile.paren(head)
    emit("(")
    compile.exprlist(rest(ast))
  end
  emit(")")
end
--- Compiler for special forms in position `kind` ("expr", "stat" or "last").
-- The handler is looked up in `syntax["special_" .. kind]` and called with
-- the form's arguments; a missing handler raises an error.
function compile.special(kind)
  local table_name = "special_" .. kind
  return function (ast)
    local handler = syntax[table_name][ast[1]]
    if not handler then
      error("special " .. ast[1] .. " not allowed as " .. kind)
    end
    handler(rest_unpacked(ast))
  end
end

--- Compiler for macro calls in position `kind`: expand, then compile the
-- expansion with `compile[kind]`.
function compile.macro(kind)
  return function (ast)
    local expanded = syntax.macroexpand(ast)
    compile[kind](expanded)
  end
end
-- Operator code generators. Each receives the whole form `(op arg1 ...)`.

--- Identity: `(op x)` compiles to just `x`.
local function id(ast) compile.expr(ast[2]) end

-- Build a unary generator: emit `prefix`, the single operand, then ")".
local function wrap_unary(prefix)
  return function (ast)
    emit(prefix)
    compile.expr(ast[2])
    emit(")")
  end
end

local neg = wrap_unary("(-")          -- (- x) -> (-x)
local recip = wrap_unary("(1/")       -- (/ x) -> (1/x)
local expo = wrap_unary("math.exp(")  -- (^ x) -> math.exp(x)

-- Lua text for the operator of `ast`: its `emit` override, else its name.
local function operator_text(ast)
  return syntax.operators[ast[1]].emit or ast[1]
end

--- Infix chain: `(op a b c)` -> `a op b op c`.
local function multi(ast)
  compiler_seq(compile.expr, operator_text(ast))(rest(ast))
end

-- Emit one parenthesised comparison `(lhs op rhs)`, followed by ")" when it
-- is the final one, else by ") and ".
local function comparison(ast, lhs, rhs, is_final)
  emit("(")
  compile.expr(lhs)
  emit(operator_text(ast))
  compile.expr(rhs)
  emit(is_final and ")" or ") and ")
end

--- Chained comparison: `(< a b c)` -> `(a<b) and (b<c)`.
local function consec_pairs(ast)
  local last_left = #ast - 1
  for i = 2, last_left do
    comparison(ast, ast[i], ast[i + 1], i == last_left)
  end
end

--- Every pair compared: `(~= a b c)` -> `(a~=b) and (a~=c) and (b~=c)`.
local function all_pairs(ast)
  local n = #ast
  for i = 2, n - 1 do
    for j = i + 1, n do
      comparison(ast, ast[i], ast[j], i == n - 1 and j == n)
    end
  end
end
-- Constant emitters used for operator forms with no arguments.
local zero, one, tru, ee = emitter(0), emitter(1), emitter("true"), emitter("math.exp(1)")

--- Operator table. `syntax.operators[op][k]` compiles a form with `k - 1`
-- arguments; `k` is capped at 4, so index 4 handles three or more
-- arguments. Index 3 (the binary case) is filled with `multi` by the loop
-- below. The optional `emit` field overrides the Lua operator text.
syntax.operators = {
  ["+"] = {zero, id, multi},
  ["-"] = {zero, neg, multi},
  ["*"] = {one, id, multi},
  ["/"] = {one, recip, multi},
  ["^"] = {ee, expo, multi},
  ["<"] = {tru, tru, consec_pairs},
  [">"] = {tru, tru, consec_pairs},
  ["~="] = {tru, tru, all_pairs},
  [">="] = {tru, tru, consec_pairs},
  ["<="] = {tru, tru, consec_pairs},
  ["="] = {tru, tru, consec_pairs, emit = "=="},
  [".."] = {emitter("\"\""), id, multi, emit = " .."},
  mod = {emit = "%"},
  ["and"] = {tru, id, multi, emit = " and "},
  ["or"] = {emitter("false"), id, multi, emit = " or "}
}

-- Every operator compiles its binary case with `multi`. Entries listing
-- fewer than two arities (`mod`) get slot 3 assigned directly. Inserting past
-- the end of a sequence is an error ("position out of bounds") in Lua 5.2+,
-- whereas Lua 5.1/LuaJIT silently grew the table.
for _, op in pairs(syntax.operators) do
  if #op >= 2 then
    table.insert(op, 3, multi)
  else
    op[3] = multi
  end
end
--- Compile an operator form by arity (see `syntax.operators`). Forms with
-- two or more arguments are wrapped in parentheses.
function compile.operexpr(ast)
  local slot = math.min(#ast, 4)
  local handler = syntax.operators[ast[1]][slot]
  if not handler then
    error(string.format("cannot use %s as %d-ary", ast[1], #ast-1))
  end
  local wrap = #ast >= 3
  if wrap then emit("(") end
  handler(ast)
  if wrap then emit(")") end
end
--- Build a handler that rejects `src` in statement position, where its value
-- would be thrown away. `thing` describes the construct: either a string, or
-- a function computing the description from `src`.
local function cannot_discard(thing)
  return function (src)
    -- Compute the description per call rather than overwriting the upvalue,
    -- so a function-valued `thing` still applies to every later call.
    local what = thing
    if type(what) == "function" then what = what(src) end
    -- `src` is often a table (vector/table constructors). `%s` rejects
    -- tables under Lua 5.1, so convert explicitly.
    error(string.format("cannot discard value of %s: %s",
      tostring(what), tostring(src)))
  end
end
-- Statement compiler. Values that would just be discarded (vector/table
-- constructors, non-label atoms, operator expressions) are rejected; labels
-- (`::name::`) are emitted verbatim; anything else falls through to a call.
-- NOTE(review): compiler_alt tries predicates in `pairs` order. If a user
-- macro has the same name as a special form, both predicates match and the
-- winner is unspecified.
compile.stat = compiler_alt {
vector_p = cannot_discard("vector constructor"),
table_p = cannot_discard("table constructor"),
atom_p = compiler_alt { label_p = emit, default = cannot_discard("atom")},
operexpr_p = cannot_discard("expression"),
special_stat_p = compile.special "stat",
macro_p = compile.macro "stat",
default = compile.call
}
-- Expression compiler. Atoms are emitted as-is, except negative numbers,
-- which are parenthesised.
compile.expr = compiler_alt {
vector_p = compile.vector,
table_p = compile.table,
atom_p = compiler_alt {
neg_p = function (ast) emit("(" .. ast .. ")") end,
default = emit },
operexpr_p = compile.operexpr,
special_expr_p = compile.special "expr",
macro_p = compile.macro "expr",
default = compile.call
}
-- A statement block: statements separated by newlines.
compile.block = compiler_seq(compile.stat, "\n")
-- An expression list: expressions separated by ", ".
compile.exprlist = compiler_seq(compile.expr, ", ")

--- Compile the final form of a body so that its value is returned.
-- The form is macro-expanded first. Tail-aware special forms handle the
-- return themselves, pure statements are compiled as statements, and any
-- other expression becomes `return <expr>`.
function compile.last(last)
  local form = syntax.macroexpand(last)
  if syntax.special_last_p(form) then
    compile.special("last")(form)
    return
  end
  if syntax.special_stat_p(form) then
    compile.stat(form)
    return
  end
  emit("return ")
  compile.expr(form)
end

--- Compile a body: every form as a statement, except that the value of the
-- last form is returned. An empty body emits nothing.
function compile.retlast(seq)
  local body, tail = seplast(table.unpack(seq))
  if not tail then return end
  compile.block(body)
  if #body > 0 then emit("\n") end
  compile.last(tail)
end
--- Flatten a binding target list that may contain (nested) vectors.
-- Returns `names`: `lhs` with each vector replaced by a fresh temporary.
-- Also returns `pending`: a list of `{inner_names, temporary}` pairs, in the
-- order they must be unpacked (outer temporaries before the inner ones
-- they contain).
local function destruct(lhs)
  local names, pending = {}, {}
  for _, target in ipairs(lhs) do
    if not syntax.vector_p(target) then
      names[#names + 1] = target
    else
      local tmp = gensym("__des")
      names[#names + 1] = tmp
      local inner_names, inner_pending = destruct(target)
      pending[#pending + 1] = {inner_names, tmp}
      for _, binding in ipairs(inner_pending) do
        pending[#pending + 1] = binding
      end
    end
  end
  return names, pending
end
--- Emit a Lua function definition.
-- opts.name    -- function name; anonymous when absent
-- opts.local_p -- when truthy, prefix the definition with `local`
-- argspec      -- parameter list; vector parameters are destructured
-- ...          -- body forms; the value of the last one is returned
function compile.fn(opts, argspec, ...)
  local body = table.pack(...)
  local params, unpacks = destruct(argspec)
  local header = string.format("%sfunction %s(",
    opts.local_p and "local " or "",
    opts.name or "")
  emit(header)
  compile.exprlist(params)
  emit(")")
  compile.unpackings(unpacks)
  emit("\n")
  compile.retlast(body)
  emit("\nend")
end
--- For each `{names, source}` pair, emit `names = table.unpack(source)` on a
-- new line, declared `local` unless `global` is truthy.
function compile.unpackings(binds, global)
  local lead = global and "\n" or "\nlocal "
  for _, bind in ipairs(binds) do
    local names, source = bind[1], bind[2]
    emit(lead)
    compile.exprlist(names)
    emit(" = ")
    compile.expr({"table.unpack", source})
  end
end
--- Compile a binding list `name1 value1 name2 value2 ...` into one multiple
-- assignment. If the first argument is `true` (as `def` passes), the names
-- are declared `local`; otherwise they are plain assignments (`set!`, `def!`).
-- An odd-length list leaves the final name without a value.
-- If the last name is a vector `[a b ...]`, its elements become extra
-- names. When that vector is paired with a vector value, the values are
-- spliced in too; otherwise the value is spread with `table.unpack`.
-- Vectors nested deeper are destructured via `destruct` and
-- `compile.unpackings`.
function compile.bindings(fst, ...)
local global = fst ~= true
local spec = global and table.pack(fst, ...) or table.pack(...)
if #spec == 0 then return end
-- Split alternating items into names (odd positions) and values (even).
local lhs, rhs = {}, {}
for i=1,#spec,2 do
lhs[#lhs+1] = spec[i]
rhs[#rhs+1] = spec[i+1]
end
local v_p = syntax.vector_p
-- Trailing vector target: splice its names into the name list.
if v_p(lhs[#lhs]) then
local l = lhs[#lhs]
lhs[#lhs] = nil
for _, v in ipairs(l) do lhs[#lhs+1] = v end
-- Only when that target actually has a value (even-length list).
if #spec%2 == 0 then
if v_p(rhs[#rhs]) then
local r = rhs[#rhs]
rhs[#rhs] = nil
for _, v in ipairs(r) do rhs[#rhs+1] = v end
else
rhs[#rhs] = {"table.unpack", rhs[#rhs]}
end
end
end
local top, unpks = destruct(lhs)
if not global then emit("local ") end
compile.exprlist(top)
if #rhs>0 then
emit(" = ")
compile.exprlist(rhs)
end
-- NOTE(review): in the global case the destructuring temporaries created
-- by `destruct` are assigned as globals as well.
compile.unpackings(unpks, global)
end
--- Build the `if` special form for statement ("stat") or tail ("last")
-- position. `(if c1 b1 c2 b2 ... [else])` becomes
-- `if c1 then b1 elseif c2 then b2 ... else <else> end`, with each branch
-- compiled by `compile[kind]`.
function compile.if_(kind)
  return function(...)
    local exprs = table.pack(...)
    -- With fewer than two arguments there is no test: `(if x)` is just its
    -- else branch `x`, and `(if)` emits nothing. The general path below
    -- would emit a bare "else ... end", which is not valid Lua.
    if #exprs <= 1 then
      if #exprs == 1 then compile[kind](exprs[1]) end
      return
    end
    for i = 1, #exprs-1, 2 do
      emit("if ")
      compile.expr(exprs[i])
      emit(" then\n")
      compile[kind](exprs[i+1])
      emit("\n")
      -- Glued to the next clause's "if" this forms "elseif".
      if i < #exprs-2 then emit("else") end
    end
    if #exprs % 2 == 1 then
      emit("else\n")
      compile[kind](exprs[#exprs])
      emit("\n")
    end
    emit("end")
  end
end
-- special forms
-- Emit the contents of a string-literal token (its quotes stripped) verbatim
-- as Lua code; this backs the `lua` special form.
local function emit_string(s) emit(s:sub(2,#s-1)) end

--- Special forms valid in expression position. Each handler receives the
-- form's unevaluated arguments and emits Lua code; return values are ignored.
syntax.special_expr = {
  -- (lua "code"): splice raw Lua.
  ["lua"] = emit_string,
  -- (fn (args...) body...): anonymous function.
  ["fn"] = function(...) compile.fn({}, ...) end,
  -- (at t k1 k2 ...): t[k1][k2]...
  ["at"] = function(x, ...)
    compile.paren(x)
    emit("[")
    compiler_seq(compile.expr,"][")(table.pack(...))
    emit("]")
  end,
  -- (do body...): immediately-invoked function returning the last value.
  ["do"] = function(...)
    emit("(function ()\n")
    compile.retlast(table.pack(...))
    emit("\nend)()")
  end,
  -- (if c1 v1 c2 v2 ... [else]) compiled as
  -- table.unpack((c1 and {v1}) or ((c2 and {v2}) or {else})).
  -- Wrapping each value in a one-element table keeps false/nil values from
  -- breaking the and/or selection. The else branch defaults to nil.
  ["if"] = function(...)
    local exprs = table.pack(...)
    if #exprs == 1 then
      -- `(if x)` is only an else branch, i.e. just `x`. It must be compiled
      -- here: returning it would emit nothing, since handler results are
      -- discarded.
      compile.expr(exprs[1])
      return
    end
    if #exprs % 2 == 0 then exprs[#exprs + 1] = "nil" end
    local ast = syntax.vector {exprs[#exprs]}
    for i = #exprs-2, 1, -2 do
      ast = {"or", {"and", exprs[i], syntax.vector {exprs[i+1]}}, ast}
    end
    compile.expr({"table.unpack", ast})
  end
}
--- Compile `(fn argspec body...)` from the given arguments, load the
-- generated Lua and return the resulting function (used for macros).
-- Compilation happens in a scratch buffer, so the caller's partially built
-- output is preserved.
local function eval_fn_def(...)
  local outer = pieces
  clear()
  -- `return` is not an expression special form, so this compiles as the
  -- call `return(function ... end)`, which Lua accepts as a return statement.
  compile.expr({"return", {"fn", ...}})
  local chunk_source = code()
  pieces = outer
  local chunk = assert(load(chunk_source))
  return chunk()
end

--- The first key/value pair of `d` in `pairs` order (nothing if empty).
local function undict(d)
  for key, value in pairs(d) do
    return key, value
  end
end
-- Compile `compile_body` with the local macros in `specs` in scope, then
-- restore the previous macro table. Each spec is `(name argspec body...)`;
-- it is compiled and evaluated immediately, and the resulting function is
-- installed as the macro. The previous table is restored even when
-- compilation fails, so a failed `macrolet` (e.g. in the REPL) does not leave
-- its macros visible afterwards.
local function with_local_macros(specs, compile_body)
  local outer = syntax.macros
  syntax.macros = setmetatable({}, {__index = outer})
  local ok, err = pcall(function ()
    for _, spec in ipairs(specs) do
      syntax.macros[spec[1]] = eval_fn_def(rest_unpacked(spec))
    end
    compile_body()
  end)
  syntax.macros = outer
  if not ok then error(err, 0) end
end

--- Special forms valid in statement position. Each handler receives the
-- form's unevaluated arguments and emits Lua code.
syntax.special_stat = {
  -- (lua "code"): splice raw Lua.
  ["lua"] = emit_string,
  -- (for spec body...). spec is one of:
  --   [i a b step?] [i n] [n] []  numeric loop (variable defaults to `_`,
  --                               start to 1; [] runs from 1 to 1/0)
  --   (v1 v2 ... iter-expr)       for v1, v2 in iter-expr
  --   {x seq} / {(i x) seq}       for _, x / i, x in ipairs(seq)
  --   {{k v} tbl}                 for k, v in pairs(tbl)
  -- Vector loop variables are destructured.
  ["for"] = function (spec, ...)
    local vars, kind, range
    local unpks = {}
    if syntax.vector_p(spec) then
      kind = " = "
      if #spec == 0 then spec = {{"/", 0}} end
      if #spec == 1 then table.insert(spec, 1, "_") end
      if #spec == 2 then table.insert(spec, 2, 1) end
      vars, range = {spec[1]}, rest(spec)
    elseif syntax.form_p(spec) then
      kind = " in "
      vars, range = seplast(table.unpack(spec))
      range = {range}
    elseif syntax.table_p(spec) then
      kind = " in "
      vars, range = undict(spec)
      if syntax.table_p(vars) then
        vars = table.pack(undict(vars))
        range = {{"pairs", range}}
      else
        if not syntax.form_p(vars) then vars = {"_", vars} end
        range = {{"ipairs", range}}
      end
    else
      error("for: unsupported iteration specification")
    end
    vars, unpks = destruct(vars)
    emit("for ")
    compile.exprlist(vars)
    emit(kind)
    compile.exprlist(range)
    emit(" do")
    compile.unpackings(unpks)
    emit("\n")
    compile.block(table.pack(...))
    emit("\nend")
  end,
  -- (while cond body...)
  ["while"] = function (cond, ...)
    emit("while ")
    compile.expr(cond)
    emit(" do\n")
    compile.block(table.pack(...))
    emit("\nend")
  end,
  -- (repeat-until cond body...): the condition comes first in the source.
  ["repeat-until"] = function (cond, ...)
    emit("repeat\n")
    compile.block(table.pack(...))
    emit("\nuntil ")
    compile.expr(cond)
  end,
  -- (do body...): a plain statement sequence (no `do ... end` is emitted).
  ["do"] = function(...) compile.block(table.pack(...)) end,
  ["goto"] = function(lbl) emit("goto " .. lbl) end,
  ["return"] = function (...)
    emit("return ")
    compile.exprlist(table.pack(...))
  end,
  -- Assignments without `local`.
  ["set!"] = compile.bindings,
  ["def!"] = compile.bindings,
  ["break"] = function () emit("break") end,
  -- Local definitions.
  ["def"] = function (...) compile.bindings(true, ...) end,
  -- (defmacro name argspec body...): compiled and installed immediately.
  ["defmacro"] = function (name, ...)
    syntax.global_macros[name] = eval_fn_def(...)
  end,
  -- (macrolet [(name argspec body...) ...] stat...): statements compiled
  -- with local macros. In statement position the body is an ordinary block.
  -- Using compile.retlast here would emit "return <last>", which ends the
  -- enclosing function or chunk early, and is a syntax error when more
  -- statements follow.
  ["macrolet"] = function (specs, ...)
    local body = table.pack(...)
    with_local_macros(specs, function () compile.block(body) end)
  end,
  ["if"] = compile.if_ "stat",
  -- (defn! name args body...): global function; (defn ...): local function.
  ["defn!"] = function (id,...) compile.fn({name = id}, ...) end,
  ["defn"] = function (id,...)
    compile.fn({name = id, local_p = true}, ...)
  end
}

--- Special forms that handle tail position (returning the value) themselves.
syntax.special_last = {
  ["lua"] = emit_string,
  ["if"] = compile.if_ "last",
  ["do"] = function(...) compile.retlast(table.pack(...)) end,
  -- In tail position the value of the last body form is returned.
  ["macrolet"] = function (specs, ...)
    local body = table.pack(...)
    with_local_macros(specs, function () compile.retlast(body) end)
  end,
}
-- cli, repl
--- Append the default ".ci" extension to `path` unless its final component
-- already contains a dot. Only the last path component (after the final "/"
-- or "\") is inspected, so paths such as "./prog" or "../lib/util" still
-- receive the extension.
local function addext(path)
  local basename = path:match("[^/\\]*$")
  if basename:find(".", 1, true) then
    return path
  end
  return path .. ".ci"
end
--- Compile `ast` to a Lua source string, using `compile[kind]` (default
-- "expr"). Raises if `ast` is nil (e.g. the parser could not read an item).
function syntax.tolua(ast, kind)
  kind = kind or "expr"
  clear()
  compile[kind](assert(ast))
  return code()
end

--- Read, parse and compile the file at `path` (".ci" added when it has no
-- extension). Returns the Lua source as a string.
function compile.file(path)
  -- Fail with io.open's message (e.g. "No such file or directory") instead
  -- of an index-nil error, and close the handle once read.
  local file = assert(io.open(addext(path), "r"))
  local src = file:read("*a")
  file:close()
  return syntax.tolua(syntax.parse(src), "block")
end

--- Compile and run the file at `path`. Global, so it can be called from
-- loaded code.
function runfile(path)
  local chunk = assert(load(compile.file(path)))
  chunk()
end
-- Load the standard library written in this dialect. The path is a
-- placeholder; adjust it to the local installation.
runfile(os.getenv("HOME") .. "/path/to/stdlib.ci")

-- Command line:
--   -c FILE   compile FILE (".ci" added when it has no extension); the Lua
--             code is written to FILE with a trailing ".ci" replaced by
--             ".ci.lua"
--   -d FILE   print the compiled Lua code for FILE
--   FILE ...  run FILE, which is removed from `arg` first
--   (none)    start the REPL
if #arg>0 then
  if arg[1]=="-c" then
    local path = addext(assert(arg[2], "usage: -c <file>"))
    -- Compile before opening the output, so a failed compile does not leave
    -- a truncated file behind. Only a trailing ".ci" is stripped, so this
    -- works for any extension length.
    local lua_code = compile.file(path)
    local out = assert(io.open((path:gsub("%.ci$", "")) .. ".ci.lua", "w"))
    out:write(lua_code)
    out:close()
  elseif arg[1]=="-d" then
    print(compile.file(addext(assert(arg[2], "usage: -d <file>"))))
  else
    local path = addext(arg[1])
    table.remove(arg, 1)
    runfile(path)
  end
else -- repl
  local function prompt() io.write("~> ") end
  runfile(os.getenv("HOME") .. "/path/to/pretty.ci")
  prompt()
  local source = ""
  for line in io.lines() do
    if line == "." then
      -- A lone "." discards any pending, incomplete input.
      source = ""
      prompt()
    else
      source = source .. "\n" .. line
      -- A parse error (e.g. "missing ')'") is taken to mean the input is
      -- incomplete: keep accumulating lines until it parses.
      local parsed, ast = pcall(syntax.parse, source)
      if parsed then
        source = ""
        local ok, err = pcall(function()
          local r = table.pack(assert(load(syntax.tolua(ast, "retlast")))())
          _ = r[1]  -- the global `_` holds the last result
          -- NOTE(review): table.map is not standard Lua; it is expected to
          -- come from the user's own library.
          print(table.unpack(table.map(syntax.pretty, r)))
        end)
        if not ok then print(err) end
        prompt()
      end
    end
  end
  print()
end
;; pretty printer
;; TODO: add line breaks
;;
;; (pretty x seen) renders a value as text in this dialect's reader syntax:
;; - strings via %q, with escaped newlines shown as \n and tabs as \t;
;; - numbers via tostring;
;; - tables as "[...]" for the sequence part (ipairs order) and
;;   "{key value ...}" for the remaining keys. When a table has both, they
;;   are written as "(& [...] {...})". String keys that look like identifiers
;;   are printed bare.
;; - anything else as "<" .. tostring(x) .. ">".
;; `seen` records every table already printed during this call. Meeting one
;; again (a cycle, or simply a shared table) prints "<table: ...>" instead
;; of recursing.
(defn pretty (x seen)
(case (type x)
`string (gsub (gsub (format "%q" x) "\\\n" "\\n") "\\9" "\\t")
`table (if (at seen x)
(.. "<" (tostring x) ">")
(do
(set-at! seen x true)
(def* vec (list (for {v x}) (pretty v seen))
vec-str (.. "[" (concat vec " ") "]")
dic (list (for {{k v} x}) (unless (at vec k))
(..
(if (and (= (type k) `string)
(match k "^[^%s%(%)%[%]%{%}%\"]+$"))
k
(pretty k seen))
" "
(pretty v seen)))
dic-str (.. "{" (concat dic " ") "}"))
(if (= (# dic) 0) vec-str
(= (# vec) 0) dic-str
(format "(& %s %s)" vec-str dic-str))))
`number (tostring x)
(.. "<" (tostring x) ">")))
;; Public entry point (the REPL prints results with syntax.pretty).
(defn! syntax.pretty (x) (pretty x {}))
;;; Macros:
;; basic convenience
;; (inc! x [y more...]) -> (set! x (+ x y more...)); y defaults to 1.
;; dec! is the same with -, mul! with * (no default for y).
(defmacro inc! (x y ...) `(set! ,x (+ ,x ,(or y 1) ,...)))
(defmacro dec! (x y ...) `(set! ,x (- ,x ,(or y 1) ,...)))
(defmacro mul! (x y ...) `(set! ,x (* ,x ,y ,...)))
;; (add! l a b ...) appends each value to the sequence l:
;; (do (set! (at l (+ (# l) 1)) a) ...). Note that `l` is re-evaluated
;; for every value.
(defmacro add! (l ...)
(def form [...])
(for [i (# form)]
(set! (at form i) `(set! (at ,l (+ (# ,l) 1)) ,(at form i))))
(table.insert form 1 `do)
form)
;; (swap! a b c d ...) swaps each consecutive pair in one multiple
;; assignment: (set! a b b a c d d c ...).
(defmacro swap! (...)
(def vars [...] form `(set!))
(for [i 1 (# vars) 2]
(for {j [i (+ i 1) (+ i 1) i]}
(add! form (at vars j))))
form)
(defmacro when (cond ...) `(if ,cond (do ,...)))
(defmacro unless (cond ...) `(when (not ,cond) ,...))
;; (set-at! t k1 ... kn v) sets t[k1]...[kn] = v and creates any missing
;; intermediate tables. t and each key are evaluated once, into gensyms.
(defmacro set-at! (t ...)
(def ix [...] v (gensym))
(def exp `(do (def ,v ,t)))
(for [i (- (# ix) 2)]
(def j (gensym))
(add! exp `(def ,j ,(at ix i)))
(add! exp `(when (= (at ,v ,j) nil) (set! (at ,v ,j) {})))
(add! exp `(set! ,v (at ,v ,j))))
(add! exp `(set! (at ,v ,(at ix (- (# ix) 1))) ,(at ix (# ix))))
exp)
;; (def* a 1 b 2 ...) -> (do (def a 1) (def b 2) ...): sequential local
;; definitions, so later values can refer to earlier names.
(defmacro def* (...)
(def binds [...] exp `(do))
(for [i 1 (# binds) 2]
(add! exp `(def ,(at binds i) ,(at binds (+ i 1)))))
exp)
;; (let [a 1 b 2] body...) -> (do (def a 1 b 2) body...); let* uses def*.
;; NOTE: in statement position `do` is a plain statement sequence (no Lua
;; do ... end), so the bindings stay visible after the let.
(defmacro let (vars ...) `(do (def ,(table.unpack vars)) ,...))
(defmacro let* (vars ...) `(do (def* ,(table.unpack vars)) ,...))
;; (if-let bind1 then1 bind2 then2 ... [else]): each clause binds its vector
;; with let and takes `then` when the bound name is truthy. For a
;; destructuring target, the clause is taken when any of its names is truthy.
;; Clauses nest right to left, so later bindings are only evaluated if
;; earlier tests fail.
(defmacro if-let (...)
(def forms [...])
(def r (mod (# forms) 2))
(def exp (when (= r 1) (at forms (# forms))))
(for [i (- (# forms) r 1) 1 -2]
(set! exp `(let ,(at forms i)
(if ,(let [v (at forms i 1)]
(if (syntax.vector? v)
`(or ,(table.unpack v))
v))
,(at forms (+ i 1))
,exp))))
exp)
(defmacro when-let (bind ...) `(if-let ,bind (do ,...)))
;; Loop while the binding's test holds, running the body each time.
(defmacro while-let (bind ...) `(while true (if-let ,bind (do ,...) (break))))
;; (case expr v1 r1 v2 r2 ... [default]): evaluates expr once and compares
;; it with each v using `=`.
(defmacro case (expr ...)
(def x (gensym) clauses [...])
(for [i 1 (- (# clauses) 1) 2]
(set! (at clauses i) `(= ,x ,(at clauses i))))
`(let [,x ,expr] (if ,(table.unpack clauses))))
;; threading
;; (~> a f (g b) ...) thread-last: each step gets the previous result as its
;; last argument, giving (g b (f a)). A bare atom step `f` becomes (f a).
;; Step forms are modified in place.
(defmacro ~> (a ...)
(for {x [...]}
(if (syntax.atom? x)
(set! a [x a])
(do (add! x a) (set! a x))))
a)
;; (defmacro <~ (...)
;; (def args [...] rev [])
;; (for [i (# args) 1 -1] (add! rev (at args i)))
;; `(~> ,(table.unpack rev)))
;; (:> a f (g b) ...) thread-first: the previous result is inserted as the
;; first argument, giving (g (f a) b).
(defmacro :> (a ...)
(for {x [...]}
(if (syntax.atom? x)
(set! a [x a])
(do (table.insert x 2 a) (set! a x))))
a)
;; <~ and <: thread in reverse order: (<~ f g a) is (~> a g f).
;; NOTE(review): table.reverse is not standard Lua; it is expected to come
;; from the user's own library.
(defmacro <~ (...) `(~> ,(table.unpack (table.reverse [...]))))
(defmacro <: (...) `(:> ,(table.unpack (table.reverse [...]))))
;; comprehensions
;; (loop update init clause... body): general accumulator behind the macros
;; below.
;; - init gives the accumulators' starting values: a form (v1 v2 ...) gives
;;   one accumulator per element, anything else a single accumulator. Each
;;   accumulator is a fresh gensym.
;; - The clauses (e.g. (for ...), (unless ...)) are nested with the first one
;;   outermost, via <~. The innermost form is (update acc... body).
;; - update is either a name used as the head (add!, inc!, return, yield,
;;   ...) or a local macro spec ((params...) body...), which is installed
;;   with macrolet and receives `body` unevaluated.
;; - Finally the accumulators are returned.
;; NOTE: the expansion ends in (return acc...), so it is meant for
;; expression position, where `do` becomes an immediately-invoked function.
(defmacro loop (update init ...)
(def args [...] vars [] bindings [])
(for {(i v) (if (syntax.form? init) init [init])}
(set! (at vars i) (gensym))
(add! bindings (at vars i) v))
(def last [update (table.unpack vars)])
(add! last (at args (# args)))
(when (syntax.form? update)
(def upd (gensym))
(set! (at last 1) upd)
(set! last `(macrolet [(,upd ,(table.unpack update))] ,last)))
(set! (at args (# args)) last)
`(let ,bindings (<~ ,(table.unpack args)) (return ,(table.unpack vars))))
;; (fold op init clause... body): acc = (op acc body) for each iteration.
(defmacro fold (op init ...)
`(loop ((x y) `(set! ,x (,,op ,x ,y))) ,init ,...))
;; (dict clause... [k v]): table with t[k] = v. The body must be written as a
;; literal two-element vector, since the update indexes the unevaluated form.
(defmacro dict (...)
`(loop ((t x) `(set! (at ,t ,(at x 1)) ,(at x 2))) {} ,...))
;; (group clause... [k v]): table mapping each k to the list of its v's.
(defmacro group (...)
`(loop ((t x)
(def k (gensym))
`(let [,k ,(at x 1)]
(when (= (at ,t ,k) nil) (set! (at ,t ,k) []))
(add! (at ,t ,k) ,(at x 2))))
{} ,...))
;; (freq clause... body): table mapping each body value to its count.
(defmacro freq (...)
`(loop ((t x)
(def k (gensym))
`(let [,k ,x]
(if (= (at ,t ,k) nil)
(set! (at ,t ,k) 1)
(inc! (at ,t ,k)))))
{} ,...))
;; list/sum/prod collect, add or multiply the body values.
(defmacro list (...) `(loop add! [] ,...))
(defmacro sum (...) `(loop inc! 0 ,...))
(defmacro prod (...) `(loop mul! 1 ,...))
;; (first clause... body): body value of the first iteration (nothing if
;; there are no iterations).
(defmacro first (...) `(loop return () ,...))
;; any?: true at the first truthy body value, else false.
;; all?: false at the first falsy body value, else true.
(defmacro any? (...) `(loop ((_ x) `(when ,x (return true))) false ,...))
(defmacro all? (...) `(loop ((_ x) `(unless ,x (return false))) true ,...))
;; (count clause...): number of iterations (sums 1; takes no body).
(defmacro count (...) (def form `(sum ,...)) (add! form 1) form)
;; (choose clause... body): one body value picked uniformly at random
;; (reservoir sampling of size 1), plus the number of values seen.
(defmacro choose (...)
`(loop ((t i x)
`(do (inc! ,i)
(when (= 1 (math.random ,i))
(set! ,t ,x))))
(nil 0) ,...))
;; (best-by cmp init clause... body): keeps the entry whose value v
;; satisfies (cmp v best) against the best so far (starting from init).
;; The body is either [key value] or a value that serves as its own key.
;; Returns the best key and its value.
(defmacro best-by (cmp ini ...)
`(loop ((w b x)
(def y (gensym) [k v])
(if (syntax.vector? x)
(set! k (at x 1) v (at x 2))
(set! k y v x))
`(let [,y ,v] (when (,,cmp ,y ,b) (set! ,w ,k ,b ,y))))
(nil ,ini)
,...))
;; min-by / max-by: best-by with < from +inf, or > from -inf ((/ 0) is 1/0).
(defmacro min-by (...) `(best-by `< (/ 0) ,...))
(defmacro max-by (...) `(best-by `> (- (/ 0)) ,...))
;; generators
;; (generator body...): coroutine.wrap around an argument-less function.
(defmacro generator (...) `(coroutine.wrap (fn () ,...)))
;; (stream clause... body): generator yielding each body value.
(defmacro stream (...) `(generator (loop yield () ,...)))
;; timing
;; (time body...): runs the body, writes the elapsed os.clock (CPU) time as
;; "%.4f seconds", and returns the body's values.
(defmacro time (...)
(def t (gensym) r (gensym))
`(do
(def ,t (os.clock))
(def ,r [(do ,...)])
(io.write (string.format "%.4f seconds\n" (- (os.clock) ,t)))
(table.unpack ,r)))
;;; utilities
;; (io.slurp path [fmt] [mode]): read a file with f:read(fmt), where fmt
;; defaults to "*a" (whole file). mode "b" opens it in binary ("rb"). If the
;; second argument is "b", it is taken as the mode (fmt and mode swap).
;; NOTE(review): relies on a global `open` (presumably io.open made global at
;; the end of this library); its result is not checked, so a missing file
;; fails at f:read.
(defn! io.slurp (p a m)
(when (= a "b") (swap! a m))
(def* f (open p (if (= m "b") "rb" "r")) r (f:read (or a "*a")))
(f:close)
r)
;; (io.spit path value [mode]): write a string as-is, a table joined with
;; "\n", or anything else via tostring. mode "b" writes in binary ("wb").
(defn! io.spit (p s b)
(def f (open p (if (= b "b") "wb" "w")))
(f:write (case (type s)
`string s
`table (table.concat s "\n")
(tostring s)))
(f:close))
;; (int n): emit n as a Lua literal with the "LL" suffix (LuaJIT's 64-bit
;; integer literal syntax; not valid in PUC Lua).
(defmacro int (n) `(lua ,(format "\"%dLL\"" n)))
;;; Adjust *my* standard library, open most modules
(set! fn nil) ; not needed in cicio
;; 1-based modulo: maps k onto 1..n.
(defn! mod1 (k n) (+ (mod (- k 1) n) 1))
;; reverse a string or a table (table.reverse comes from the user's library).
(defn! reverse (x)
(case (type x)
`string (string.reverse x)
`table (table.reverse x)))
;; Copy the functions of table, io, math, coroutine and string into globals
;; (so `format`, `gsub`, `concat`, `open`, ... can be called directly),
;; except the names listed for each module (e.g. io.type would otherwise
;; replace the global `type`).
;; NOTE(review): modules are visited in `pairs` order, so when two modules
;; export the same remaining name, which one ends up global is unspecified.
(for {{module skips}
`{table [unpack reverse] io [type remove]
math [] coroutine [] string [reverse]}}
(def skip? (dict (for {skip skips}) [skip true]))
(for {{k v} (at _G module)}
(unless (at skip? k)
(set! (at _G k) v))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment