|
// lisp.go – a tiny Lisp interpreter: tokenizer, parser, evaluator, numeric built-ins and an interactive REPL.
|
package main |
|
|
|
import ( |
|
"bufio" |
|
"fmt" |
|
"os" |
|
"strconv" |
|
"strings" |
|
) |
|
|
|
// ---------------------------------------------------------- |
|
// Types & environment |
|
// ---------------------------------------------------------- |
|
|
|
// any aliases the empty interface; it is the type used throughout this file
// for AST nodes and runtime Lisp values. Go 1.18 and later predeclare `any`
// with the same meaning, so the alias is only required on older toolchains.
type any = interface{}
|
|
|
// Env is one lexical scope: the symbols bound directly in it plus a link to
// the enclosing scope (nil for the outermost, global scope).
type Env struct {
	vars  map[string]any // bindings introduced in this scope
	outer *Env           // enclosing scope; nil at the top level
}

// find walks from e outward through the enclosing scopes and returns the
// innermost Env in which key is bound (a binding to a nil value still counts).
// It reports false, with a nil *Env, when no scope in the chain binds key;
// calling it on a nil *Env is allowed and simply reports false.
func (e *Env) find(key string) (*Env, bool) {
	if e == nil {
		return nil, false
	}
	if _, bound := e.vars[key]; bound {
		return e, true
	}
	return e.outer.find(key)
}
|
|
|
// ---------------------------------------------------------- |
|
// Lexer – turn source text into tokens |
|
// ---------------------------------------------------------- |
|
|
|
// token is one lexical unit. typ is "(" or ")" for a parenthesis, "string"
// for a double-quoted literal, or "atom" for anything else (numbers and
// symbols); val is the token's text, with escapes decoded for strings.
type token struct{ typ, val string }

// tokenize splits src into tokens.
//
// Spaces, tabs, CR and LF separate tokens and are otherwise discarded. Inside
// a string literal the escapes \n, \t, \" and \\ are decoded; any other
// backslash sequence is kept verbatim. A string literal that is missing its
// closing quote runs to the end of src. An atom ends at whitespace, a
// parenthesis or a double quote, so x"a" yields the atom x followed by the
// string a.
func tokenize(src string) []token {
	// Bytes that terminate an atom. '"' is included so that a string literal
	// written directly after an atom starts a new token instead of being
	// absorbed into the atom's text.
	const atomDelims = " \t\n\r()\""

	var out []token
	i := 0
	for i < len(src) {
		c := src[i]
		switch {
		case c == ' ' || c == '\t' || c == '\n' || c == '\r':
			i++
		case c == '(' || c == ')':
			out = append(out, token{typ: string(c), val: string(c)})
			i++
		case c == '"':
			var sb strings.Builder
			j := i + 1
			for j < len(src) && src[j] != '"' {
				if src[j] == '\\' && j+1 < len(src) {
					switch src[j+1] {
					case 'n':
						sb.WriteByte('\n')
					case 't':
						sb.WriteByte('\t')
					case '"', '\\':
						sb.WriteByte(src[j+1])
					default:
						// Unknown escape: keep it as written.
						sb.WriteByte('\\')
						sb.WriteByte(src[j+1])
					}
					j += 2
					continue
				}
				sb.WriteByte(src[j])
				j++
			}
			out = append(out, token{typ: "string", val: sb.String()})
			if j < len(src) {
				j++ // step over the closing quote
			}
			i = j
		default:
			// Number or symbol. c itself is not a delimiter, so this always
			// advances at least one byte.
			start := i
			for i < len(src) && strings.IndexByte(atomDelims, src[i]) < 0 {
				i++
			}
			out = append(out, token{typ: "atom", val: src[start:i]})
		}
	}
	return out
}
|
|
|
// ---------------------------------------------------------- |
|
// AST node for true string literals |
|
// ---------------------------------------------------------- |
|
|
|
// lispString is how the parser represents a double-quoted string literal in
// the AST, so that the evaluator can tell it apart from a symbol (which the
// parser represents as a bare Go string). eval turns it into a plain Go
// string value.
type lispString struct {
	value string // literal text; escape sequences were already decoded by tokenize
}
|
|
|
// ---------------------------------------------------------- |
|
// Parser – build a nested []any AST |
|
// ---------------------------------------------------------- |
|
|
|
type parser struct{ toks []token } |
|
|
|
func (p *parser) peek() *token { |
|
if len(p.toks) == 0 { |
|
return nil |
|
} |
|
return &p.toks[0] |
|
} |
|
func (p *parser) consume() token { |
|
t := p.toks[0] |
|
p.toks = p.toks[1:] |
|
return t |
|
} |
|
|
|
// parse returns a single expression (list or atom) |
|
func (p *parser) parse() (any, error) { |
|
if p.peek() == nil { |
|
return nil, fmt.Errorf("unexpected EOF") |
|
} |
|
t := p.consume() |
|
if t.typ == "(" { |
|
var list []any |
|
for p.peek() != nil && p.peek().typ != ")" { |
|
elem, err := p.parse() |
|
if err != nil { |
|
return nil, err |
|
} |
|
list = append(list, elem) |
|
} |
|
if p.peek() == nil { |
|
return nil, fmt.Errorf("missing ')'") |
|
} |
|
p.consume() // discard ')' |
|
return list, nil |
|
} |
|
if t.typ == ")" { |
|
return nil, fmt.Errorf("unexpected ')'") |
|
} |
|
// atom → int | float | string (literal) | symbol |
|
if i, err := strconv.Atoi(t.val); err == nil { |
|
return i, nil |
|
} |
|
if f, err := strconv.ParseFloat(t.val, 64); err == nil { |
|
return f, nil |
|
} |
|
if t.typ == "string" { |
|
// true string literal |
|
return lispString{value: t.val}, nil |
|
} |
|
// otherwise it is a symbol |
|
return t.val, nil |
|
} |
|
|
|
// ---------------------------------------------------------- |
|
// Lambda (anonymous function) value |
|
// ---------------------------------------------------------- |
|
|
|
// lambda is the runtime value of a (fn (params...) body) form: a user-defined
// function that captures the scope in which it was created.
type lambda struct {
	// params are the parameter names; a call binds them positionally to its
	// arguments and must supply exactly len(params) arguments.
	params []string

	// body is the single, unevaluated body expression (an AST node).
	body any

	// closure is the scope the fn form was evaluated in. Each call evaluates
	// body in a fresh child scope of closure, which gives lexical scoping.
	closure *Env
}
|
|
|
// ---------------------------------------------------------- |
|
// Helpers |
|
// ---------------------------------------------------------- |
|
|
|
// isTruthy decides whether v counts as true in a conditional. A nil
// interface, false, numeric zero (int or float64), the empty string and an
// empty []any are false; every other value, of any other type, is true.
func isTruthy(v any) bool {
	if v == nil {
		return false
	}
	switch x := v.(type) {
	case bool:
		return x
	case int:
		return x != 0
	case float64:
		return x != 0
	case string:
		return len(x) > 0
	case []any:
		return len(x) > 0
	}
	return true
}
|
|
|
// ---------------------------------------------------------- |
|
// Evaluation |
|
// ---------------------------------------------------------- |
|
|
|
func eval(expr any, env *Env) (any, error) { |
|
switch v := expr.(type) { |
|
case int, float64: |
|
// numeric literals |
|
return v, nil |
|
|
|
case lispString: |
|
// true string literals – return the raw Go string |
|
return v.value, nil |
|
|
|
case string: |
|
// plain symbol – look it up in the environment |
|
if val, ok := env.vars[v]; ok { |
|
return val, nil |
|
} |
|
if outer, ok := env.find(v); ok { |
|
return outer.vars[v], nil |
|
} |
|
return nil, fmt.Errorf("unbound symbol: %s", v) |
|
|
|
case []any: // list → special form or ordinary call |
|
if len(v) == 0 { |
|
return nil, fmt.Errorf("empty list") |
|
} |
|
// first element may be a special form |
|
if sym, ok := v[0].(string); ok { |
|
switch sym { |
|
case "def": |
|
// (def name expr) |
|
if len(v) != 3 { |
|
return nil, fmt.Errorf("def needs exactly 2 arguments") |
|
} |
|
name, ok := v[1].(string) |
|
if !ok { |
|
return nil, fmt.Errorf("def: first argument must be a symbol") |
|
} |
|
val, err := eval(v[2], env) |
|
if err != nil { |
|
return nil, err |
|
} |
|
env.vars[name] = val |
|
return val, nil |
|
|
|
case "if": |
|
// (if test then else) |
|
if len(v) != 4 { |
|
return nil, fmt.Errorf("if needs exactly 3 arguments") |
|
} |
|
cond, err := eval(v[1], env) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if isTruthy(cond) { |
|
return eval(v[2], env) |
|
} |
|
return eval(v[3], env) |
|
|
|
case "fn": |
|
// (fn (arg1 arg2 ...) body) |
|
if len(v) != 3 { |
|
return nil, fmt.Errorf("fn needs exactly 2 arguments") |
|
} |
|
paramList, ok := v[1].([]any) |
|
if !ok { |
|
return nil, fmt.Errorf("fn: first argument must be a list of parameters") |
|
} |
|
params := make([]string, len(paramList)) |
|
for i, p := range paramList { |
|
sym, ok := p.(string) |
|
if !ok { |
|
return nil, fmt.Errorf("fn: parameters must be symbols") |
|
} |
|
params[i] = sym |
|
} |
|
return &lambda{params: params, body: v[2], closure: env}, nil |
|
|
|
case "let": |
|
// (let ((a 1) (b 2) ...) body...) |
|
if len(v) < 3 { |
|
return nil, fmt.Errorf("let needs bindings and at least one body expression") |
|
} |
|
bindList, ok := v[1].([]any) |
|
if !ok { |
|
return nil, fmt.Errorf("let: first argument must be a list of bindings") |
|
} |
|
child := &Env{vars: map[string]any{}, outer: env} |
|
for _, b := range bindList { |
|
pair, ok := b.([]any) |
|
if !ok || len(pair) != 2 { |
|
return nil, fmt.Errorf("let: each binding must be a (name value) pair") |
|
} |
|
name, ok := pair[0].(string) |
|
if !ok { |
|
return nil, fmt.Errorf("let: binding name must be a symbol") |
|
} |
|
val, err := eval(pair[1], env) |
|
if err != nil { |
|
return nil, err |
|
} |
|
child.vars[name] = val |
|
} |
|
var result any |
|
for _, bodyExpr := range v[2:] { |
|
var err error |
|
result, err = eval(bodyExpr, child) |
|
if err != nil { |
|
return nil, err |
|
} |
|
} |
|
return result, nil |
|
} |
|
// not a special form → ordinary function call |
|
} |
|
|
|
// ----- ordinary function call ----- |
|
// evaluate operator |
|
op, err := eval(v[0], env) |
|
if err != nil { |
|
return nil, err |
|
} |
|
// evaluate arguments |
|
args := make([]any, len(v)-1) |
|
for i := 1; i < len(v); i++ { |
|
a, err := eval(v[i], env) |
|
if err != nil { |
|
return nil, err |
|
} |
|
args[i-1] = a |
|
} |
|
// apply |
|
switch fn := op.(type) { |
|
case func([]any) (any, error): |
|
return fn(args) |
|
case *lambda: |
|
if len(args) != len(fn.params) { |
|
return nil, fmt.Errorf("expected %d args, got %d", len(fn.params), len(args)) |
|
} |
|
callEnv := &Env{vars: map[string]any{}, outer: fn.closure} |
|
for i, name := range fn.params { |
|
callEnv.vars[name] = args[i] |
|
} |
|
return eval(fn.body, callEnv) |
|
default: |
|
return nil, fmt.Errorf("not a function: %v", op) |
|
} |
|
} |
|
return nil, fmt.Errorf("unexpected expression type %T", expr) |
|
} |
|
|
|
// ---------------------------------------------------------- |
|
// Built‑ins |
|
// ---------------------------------------------------------- |
|
|
|
// builtin returns the Go implementation of the built-in procedure called
// name, or an error for an unknown name. Every built-in has the signature
// func([]any) (any, error), which is what eval invokes for ordinary calls.
//
// Arithmetic is carried out in float64. Results that are whole numbers within
// the int64 range are returned as int, everything else as float64, for every
// arity including unary (- x) and (/ x). Integers beyond ±2^53 therefore lose
// precision. Division by zero is an error rather than ±Inf/NaN.
func builtin(name string) (any, error) {
	switch name {
	case "+":
		return func(args []any) (any, error) {
			sum := 0.0
			for _, a := range args {
				n, ok := numberValue(a)
				if !ok {
					return nil, fmt.Errorf("+ works only on numbers")
				}
				sum += n
			}
			return normalizeNumber(sum), nil
		}, nil

	case "-":
		return func(args []any) (any, error) {
			if len(args) == 0 {
				return nil, fmt.Errorf("- requires at least one argument")
			}
			acc, ok := numberValue(args[0])
			if !ok {
				return nil, fmt.Errorf("- works only on numbers")
			}
			if len(args) == 1 {
				// Unary minus negates. It is normalized like every other result,
				// so (- 5) is the int -5 and (- 0) is 0 rather than float64 -0.
				return normalizeNumber(-acc), nil
			}
			for _, a := range args[1:] {
				n, ok := numberValue(a)
				if !ok {
					return nil, fmt.Errorf("- works only on numbers")
				}
				acc -= n
			}
			return normalizeNumber(acc), nil
		}, nil

	case "*":
		return func(args []any) (any, error) {
			prod := 1.0
			for _, a := range args {
				n, ok := numberValue(a)
				if !ok {
					return nil, fmt.Errorf("* works only on numbers")
				}
				prod *= n
			}
			return normalizeNumber(prod), nil
		}, nil

	case "/":
		return func(args []any) (any, error) {
			if len(args) == 0 {
				return nil, fmt.Errorf("/ requires at least one argument")
			}
			cur, ok := numberValue(args[0])
			if !ok {
				return nil, fmt.Errorf("/ works only on numbers")
			}
			if len(args) == 1 {
				// Unary division is the reciprocal: (/ x) = 1/x.
				if cur == 0 {
					return nil, fmt.Errorf("division by zero")
				}
				return normalizeNumber(1 / cur), nil
			}
			for _, a := range args[1:] {
				n, ok := numberValue(a)
				if !ok {
					return nil, fmt.Errorf("/ works only on numbers")
				}
				if n == 0 {
					return nil, fmt.Errorf("division by zero")
				}
				cur /= n
			}
			return normalizeNumber(cur), nil
		}, nil

	case "=":
		return func(args []any) (any, error) {
			if len(args) < 2 {
				return nil, fmt.Errorf("= needs at least two arguments")
			}
			for _, a := range args[1:] {
				if !valuesEqual(args[0], a) {
					return false, nil
				}
			}
			return true, nil
		}, nil

	case "<", "<=", ">", ">=":
		return comparisonBuiltin(name)

	default:
		return nil, fmt.Errorf("unknown builtin %s", name)
	}
}

// numberValue reports whether v is a Lisp number (int or float64) and, if
// so, returns it as a float64.
func numberValue(v any) (float64, bool) {
	switch n := v.(type) {
	case int:
		return float64(n), true
	case float64:
		return n, true
	}
	return 0, false
}

// normalizeNumber converts an arithmetic result back into a Lisp number.
// Whole values that fit in int64 become int. Everything else stays float64:
// fractions, ±Inf, NaN and out-of-range magnitudes. The range check comes
// first because converting an out-of-range float64 to an integer type is
// implementation-defined in Go.
func normalizeNumber(f float64) any {
	if f >= -(1<<63) && f < 1<<63 && f == float64(int64(f)) {
		return int(f)
	}
	return f
}

// valuesEqual implements `=` for one pair of values.
//
// Numbers compare numerically regardless of representation (1 equals 1.0),
// with int/int compared exactly, and a number never equals a non-number.
// Strings and booleans compare by value, and only with their own type. Any
// other value (built-in or lambda functions) falls back to comparing type and
// printed form. Using == on two interface values that hold function values
// panics at run time.
func valuesEqual(a, b any) bool {
	if x, ok := a.(int); ok {
		if y, ok := b.(int); ok {
			return x == y
		}
	}
	xf, aNum := numberValue(a)
	yf, bNum := numberValue(b)
	if aNum || bNum {
		return aNum && bNum && xf == yf
	}
	switch x := a.(type) {
	case string:
		y, ok := b.(string)
		return ok && x == y
	case bool:
		y, ok := b.(bool)
		return ok && x == y
	}
	return fmt.Sprintf("%T %v", a, a) == fmt.Sprintf("%T %v", b, b)
}

// cmpFunc is the two-argument predicate underlying an ordering built-in.
type cmpFunc func(a, b float64) bool

// comparisonBuiltin creates the built-in for one of the ordering operators
// <, <=, > and >=. The returned procedure takes exactly two numbers and
// returns a bool. An unknown op is reported as an error here instead of
// producing a procedure that would panic on a nil predicate when called.
func comparisonBuiltin(op string) (any, error) {
	var fn cmpFunc
	switch op {
	case "<":
		fn = func(a, b float64) bool { return a < b }
	case "<=":
		fn = func(a, b float64) bool { return a <= b }
	case ">":
		fn = func(a, b float64) bool { return a > b }
	case ">=":
		fn = func(a, b float64) bool { return a >= b }
	default:
		return nil, fmt.Errorf("unknown comparison operator %s", op)
	}
	return func(args []any) (any, error) {
		if len(args) != 2 {
			return nil, fmt.Errorf("%s expects exactly 2 arguments", op)
		}
		a, ok := numberValue(args[0])
		if !ok {
			return nil, fmt.Errorf("%s works only on numbers", op)
		}
		b, ok := numberValue(args[1])
		if !ok {
			return nil, fmt.Errorf("%s works only on numbers", op)
		}
		return fn(a, b), nil
	}, nil
}
|
|
|
// ---------------------------------------------------------- |
|
// REPL |
|
// ---------------------------------------------------------- |
|
|
|
func main() { |
|
// Global environment pre‑populated with built‑ins |
|
global := &Env{vars: map[string]any{}, outer: nil} |
|
for _, name := range []string{"+", "-", "*", "/", "=", "<", "<=", ">", ">="} { |
|
b, err := builtin(name) |
|
if err != nil { |
|
panic(err) |
|
} |
|
global.vars[name] = b |
|
} |
|
|
|
fmt.Println("tiny Lisp REPL – Ctrl‑D or \"quit\" to exit") |
|
scanner := bufio.NewScanner(os.Stdin) |
|
for { |
|
fmt.Print("> ") |
|
if !scanner.Scan() { |
|
// EOF (Ctrl‑D) |
|
fmt.Println() |
|
break |
|
} |
|
line := strings.TrimSpace(scanner.Text()) |
|
if line == "" || line == "quit" || line == "(quit)" { |
|
break |
|
} |
|
parsed, err := (&parser{toks: tokenize(line)}).parse() |
|
if err != nil { |
|
fmt.Println("parse error:", err) |
|
continue |
|
} |
|
val, err := eval(parsed, global) |
|
if err != nil { |
|
fmt.Println("error:", err) |
|
continue |
|
} |
|
fmt.Printf("%#v\n", val) |
|
} |
|
} |