Skip to content

Instantly share code, notes, and snippets.

@L-as
Last active July 7, 2025 20:11
Show Gist options
  • Save L-as/23a71779ffd7486187903c75e73b9873 to your computer and use it in GitHub Desktop.
Save L-as/23a71779ffd7486187903c75e73b9873 to your computer and use it in GitHub Desktop.
Haskell version of Oleg Kiselyov's sound_lazy.ml (from https://okmij.org/ftp/ML/generalization.html), rewritten mechanically. The big differences are that much of the code here is lazy, that we use STRef, and that the global variables are put into a record.
{-# LANGUAGE GHC2021, NamedFieldPuns, LambdaCase, RebindableSyntax, RecordWildCards, BlockArguments, MagicHash #-}
{-# OPTIONS_GHC -Wall -Wno-name-shadowing -Wno-missing-signatures #-}
module Sound_lazy where
import Data.STRef
import Control.Monad.ST
import Prelude hiding (negate)
import Data.Char
import Control.Monad
import Data.Maybe
import Control.Exception
import GHC.Exts (reallyUnsafePtrEquality, isTrue#)
-- | Target of @if … then … else@ desugaring: with @RebindableSyntax@ enabled,
-- GHC rewrites every @if c then a else b@ into @ifThenElse c a b@ using
-- whatever 'ifThenElse' is in scope, so the module has to supply one.
ifThenElse :: Bool -> a -> a -> a
ifThenElse cond onTrue onFalse = case cond of
  True -> onTrue
  False -> onFalse
-- | Name of a term-level variable.
type Varname = String
-- | Abstract syntax of the expressions that 'typeof' type checks.
data Exp
  -- | Reference to a variable; 'typeof' looks it up in the environment.
  = Var Varname
  -- | Application: the first expression is the function, the second its argument.
  | App Exp Exp
  -- | Lambda abstraction binding the named variable in the body.
  | Lam Varname Exp
  -- | @Let x e body@: 'typeof' generalizes the type of @e@ before binding it
  -- to @x@ for @body@.
  | Let Varname Exp Exp
-- | String alias; no definition visible in this file refers to it.
type QName = String
-- | Nesting depth used for generalization: 'enter_level' increments the
-- current level and 'leave_level' decrements it.
type Level = Int
-- | Level given by 'gen' to variables and arrows that have been generalized;
-- 'inst' replaces anything at this level with fresh types.
generic_level = 100000000 :: Level
-- | Sentinel written into @level_new@ while an arrow is being traversed, so
-- that meeting it again signals a cycle. Spelled @0 - 1@ because under
-- @RebindableSyntax@ unary minus desugars to this module's 'negate', which
-- is 'readSTRef'.
marked_level = 0 - 1 :: Level
-- | Types as manipulated during inference. They live in 'ST' state thread @s@
-- and are updated in place by unification and level adjustment.
data Typ s
  -- | Type variable: a mutable cell holding either an unbound variable or a
  -- link to the type it was unified with.
  = TVar (STRef s (Tv s))
  -- | Function type: argument type, result type, and mutable level bookkeeping.
  | TArrow (Typ s) (Typ s) (STRef s Levels)
-- | Contents of a type-variable cell.
data Tv s
  -- | Variable not bound to any type yet: its name and its level.
  = Unbound String Level
  -- | Variable bound to the given type; 'repr' follows and shortens such links.
  | Link (Typ s)
-- | Level bookkeeping attached to every arrow type.
--
-- @level_new@ is the level the arrow currently has (or 'marked_level' while
-- it is being traversed). @level_old@ is the level as of the last time the
-- components were brought in line: 'update_level' queues an arrow when it
-- lowers @level_new@ while both fields still agree, and
-- 'force_delayed_adjustments' sets @level_old@ to @level_new@ after pushing
-- the level down to the components.
data Levels
  = Levels { level_old :: Level, level_new :: Level }
  deriving Show
-- | Overwrite the @level_new@ field of the 'Levels' record stored in the
-- reference, leaving @level_old@ untouched.
--
-- Uses the strict 'modifySTRef'' so the record update is performed right
-- away; the lazy 'modifySTRef' would leave a chain of pending updates inside
-- the reference whenever a level is set several times between reads.
set_level_new :: STRef s Levels -> Level -> ST s ()
set_level_new ls level = modifySTRef' ls (\old -> old{level_new = level})

-- | Overwrite the @level_old@ field of the 'Levels' record stored in the
-- reference, leaving @level_new@ untouched. Strict for the same reason as
-- 'set_level_new'.
set_level_old :: STRef s Levels -> Level -> ST s ()
set_level_old ls level = modifySTRef' ls (\old -> old{level_old = level})
-- OCaml-flavoured reference operators.
--
-- With RebindableSyntax the prefix minus in @-ref@ desugars to whatever
-- 'negate' is in scope, so defining 'negate' as a dereference makes @-ref@
-- play the role of OCaml's @!ref@.
negate ref = readSTRef ref

-- Assignment, standing in for OCaml's @ref := value@.
ref =: value = writeSTRef ref value
-- | Chase 'Link's from a type variable to the type it ultimately stands for.
--
-- Every variable cell visited on the way is rewritten to point directly at
-- the final type. Unbound variables and arrows are returned as they are.
repr :: Typ s -> ST s (Typ s)
repr ty = case ty of
  TVar ref -> do
    contents <- -ref
    case contents of
      Link target -> do
        root <- repr target
        -- shorten the chain: this cell now links straight to the root
        ref =: Link root
        pure root
      Unbound{} -> pure ty
  TArrow{} -> pure ty
-- | Level of a type: the level stored in an unbound variable, or the
-- @level_new@ of an arrow.
--
-- Calls 'error' with @"impossible"@ when given a variable whose cell holds a
-- 'Link'; callers pass types that have already been through 'repr'.
get_level :: Typ s -> ST s Level
get_level = \case
  TVar ref -> do
    contents <- -ref
    case contents of
      Unbound _ lvl -> pure lvl
      Link _ -> error "impossible"
  TArrow _ _ levels -> fmap level_new (-levels)
-- | Mutable state shared by the whole inference run, gathered into one
-- record and passed around explicitly.
data GlobalState s = S
  { gensym_counter :: STRef s Int
    -- ^ number of names handed out so far by 'gensym'
  , current_level :: STRef s Level
    -- ^ current level; changed by 'enter_level' and 'leave_level'
  , to_be_level_adjusted :: STRef s [Typ s]
    -- ^ arrows whose level was lowered by 'update_level' but whose
    -- components have not been adjusted yet; drained by
    -- 'force_delayed_adjustments'
  }
-- | Add one to the number stored in the reference.
--
-- The strict 'modifySTRef'' evaluates the new value before storing it, so
-- repeated increments never pile up as nested @(+ 1)@ thunks in the cell.
incr :: Num a => STRef s a -> ST s ()
incr ref = modifySTRef' ref (+ 1)

-- | Subtract one from the number stored in the reference, strictly, for the
-- same reason as 'incr'.
decr :: Num a => STRef s a -> ST s ()
decr ref = modifySTRef' ref (subtract 1)
-- | Produce the next type-variable name and advance the counter.
--
-- The first 26 names are the single letters @a@ … @z@; after that the name is
-- @t@ followed by the counter value.
gensym :: GlobalState s -> ST s String
gensym S{gensym_counter} = do
  index <- -gensym_counter
  incr gensym_counter
  pure (name_for index)
  where
    name_for i
      | i < 26 = [chr (ord 'a' + i)]
      | otherwise = 't' : show i
-- | Allocate a fresh inference state: level 1, name counter 0, and an empty
-- queue of pending level adjustments.
reset_type_variables = do
  level_ref <- newSTRef 1
  counter_ref <- newSTRef 0
  pending_ref <- newSTRef []
  pure S
    { gensym_counter = counter_ref
    , current_level = level_ref
    , to_be_level_adjusted = pending_ref
    }
-- | Increment the current level; 'typeof' calls this before typing the bound
-- expression of a @Let@.
enter_level S{current_level = level_ref} = incr level_ref

-- | Decrement the current level; 'typeof' calls this after typing the bound
-- expression of a @Let@.
leave_level S{current_level = level_ref} = decr level_ref
-- | Create a fresh unbound type variable with a newly generated name, at the
-- current level.
newvar :: GlobalState s -> ST s (Typ s)
newvar st@S{current_level} = do
  fresh <- gensym st
  lvl <- -current_level
  cell <- newSTRef (Unbound fresh lvl)
  pure (TVar cell)
-- | Build an arrow type from an argument type and a result type. Both
-- @level_old@ and @level_new@ start out at the current level.
new_arrow :: GlobalState s -> Typ s -> Typ s -> ST s (Typ s)
new_arrow S{current_level} dom cod = do
  lvl <- -current_level
  TArrow dom cod <$> newSTRef (Levels lvl lvl)
-- | Check that a type contains no cycle through arrow nodes.
--
-- Each arrow is stamped with 'marked_level' while its components are being
-- walked and gets its level back afterwards; running into a stamped arrow
-- means the walk has come back to a node it is still inside, which is
-- reported with @error "occurs check"@.
cycle_free :: GlobalState s -> Typ s -> ST s ()
cycle_free S{} = walk
  where
    walk = \case
      TVar ref -> do
        contents <- -ref
        case contents of
          Unbound{} -> pure ()
          Link target -> walk target
      TArrow dom cod levels -> do
        saved <- level_new <$> -levels
        when (saved == marked_level) $ error "occurs check"
        set_level_new levels marked_level
        walk dom
        walk cod
        -- restore the level that the mark temporarily replaced
        set_level_new levels saved
-- | Make sure a type's level is no higher than the given level.
--
-- For an unbound variable the level is lowered on the spot. For an arrow only
-- @level_new@ is lowered; if the arrow's two level fields agreed before the
-- change, the arrow is put on @to_be_level_adjusted@ so that its components
-- are adjusted later. A 'Link' variable or a type at 'generic_level' is
-- reported as @"impossible"@, and an arrow carrying 'marked_level' as
-- @"occurs check"@.
update_level :: GlobalState s -> Level -> Typ s -> ST s ()
update_level S{to_be_level_adjusted} bound ty = case ty of
  TVar ref -> do
    contents <- -ref
    case contents of
      Unbound name current -> do
        when (current == generic_level) $ error "impossible"
        when (bound < current) $ ref =: Unbound name bound
      Link _ -> error "impossible"
  TArrow _ _ levels -> do
    current <- level_new <$> -levels
    when (current == generic_level) $ error "impossible"
    when (current == marked_level) $ error "occurs check"
    when (bound < current) do
      previous <- level_old <$> -levels
      -- only queue the arrow the first time its level drops
      when (current == previous) $
        modifySTRef to_be_level_adjusted (ty :)
      set_level_new levels bound
-- | Physical (pointer) equality of two values, used by 'unify' as the
-- counterpart of OCaml's @==@.
--
-- Both arguments are evaluated to weak head normal form before the raw
-- pointer comparison is made.
ptrEqual :: a -> a -> Bool
ptrEqual lhs rhs =
  lhs `seq` rhs `seq` isTrue# (reallyUnsafePtrEquality lhs rhs)
-- | Unify two types in place.
--
-- * Physically identical types need no work.
-- * Two distinct unbound variables: the one with the higher level is linked
--   to the other.
-- * A variable against anything else: the other type's level is capped at
--   the variable's level via 'update_level', then the variable is linked to
--   it.
-- * Two arrows: both are stamped with 'marked_level' while their components
--   are unified at the smaller of the two levels, and both end up at that
--   smaller level. Finding a stamp already in place is reported as
--   @"cycle: occurs check"@.
unify :: GlobalState s -> Typ s -> Typ s -> ST s ()
unify st@S{} lhs rhs
  | ptrEqual lhs rhs = pure ()
  | otherwise = do
      lhs' <- repr lhs
      rhs' <- repr rhs
      case (lhs', rhs') of
        (TVar ref1, TVar ref2) -> do
          contents1 <- -ref1
          contents2 <- -ref2
          case (contents1, contents2) of
            (Unbound _ lvl1, Unbound _ lvl2)
              | ref1 == ref2 -> pure ()
              | lvl1 > lvl2 -> ref1 =: Link rhs'
              | otherwise -> ref2 =: Link lhs'
            _ -> error "impossible"
        (TVar ref, other) -> bind_var ref other
        (other, TVar ref) -> bind_var ref other
        (TArrow dom1 cod1 levels1, TArrow dom2 cod2 levels2) -> do
          new1 <- level_new <$> -levels1
          new2 <- level_new <$> -levels2
          when (new1 == marked_level || new2 == marked_level) $
            error "cycle: occurs check"
          let lowest = min new1 new2
          set_level_new levels1 marked_level
          set_level_new levels2 marked_level
          unify_lev st lowest dom1 dom2
          unify_lev st lowest cod1 cod2
          set_level_new levels1 lowest
          set_level_new levels2 lowest
  where
    -- Link an unbound variable to a non-variable type, after capping that
    -- type's level at the variable's level.
    bind_var ref other = do
      contents <- -ref
      case contents of
        Unbound _ lvl -> do
          update_level st lvl other
          ref =: Link other
        Link _ -> error "impossible"
-- | Unify two component types of arrows being unified at the given level:
-- resolve the left one with 'repr', cap its level with 'update_level', then
-- hand both to 'unify'.
unify_lev st@S{} bound lhs rhs =
  repr lhs >>= \lhs' ->
    update_level st bound lhs' >> unify st lhs' rhs
-- | Typing environment: an association list from variable names to their
-- types. 'typeof' extends it by consing, so 'lookup' finds the innermost
-- binding first.
type Env s = [(Varname, Typ s)]
-- | Work through the queue of arrows whose level was lowered by
-- 'update_level' without touching their components.
--
-- Every queued arrow goes through @adjust_one@; the arrows it decides to
-- postpone become the new contents of @to_be_level_adjusted@.
--
-- The queue is folded with 'foldM', which threads the accumulator through
-- the 'ST' actions directly instead of folding over actions and joining
-- them by hand.
force_delayed_adjustments :: GlobalState s -> ST s ()
force_delayed_adjustments S{..} = do
  pending <- -to_be_level_adjusted
  postponed <- foldM adjust_one [] pending
  to_be_level_adjusted =: postponed
  where
    -- Cap the level of one component at @level@. Unbound variables are
    -- lowered directly; arrows are lowered and then adjusted recursively.
    loop acc level ty =
      repr ty >>= \case
        TVar tvr -> -tvr >>= \case
          Unbound name l | l > level ->
            tvr =: Unbound name level >> pure acc
          _ -> pure acc
        arrow@(TArrow _ _ ls) -> level_new <$> -ls >>= \case
          -- a marked arrow is one we are still inside of: a cycle
          l | l == marked_level -> error "occurs check"
          l -> do
            when (l > level) $ set_level_new ls level
            adjust_one acc arrow
    -- Only deals with arrows; anything else is a programming error.
    adjust_one acc ty@(TArrow ty1 ty2 ls) =
      (,) <$> -ls <*> -current_level >>= \case
        -- the old level is not above the current level: keep the arrow
        -- queued
        (Levels{..}, current_level) | level_old <= current_level -> pure (ty : acc)
        -- both level fields agree: nothing is left to propagate
        (Levels{..}, _) | level_old == level_new -> pure acc
        -- push the new level down to both components, marking the arrow
        -- meanwhile so that a cycle is detected, then record that the
        -- arrow is up to date
        (Levels{..}, _) -> do
          let level = level_new
          set_level_new ls marked_level
          acc <- loop acc level ty1
          acc <- loop acc level ty2
          set_level_new ls level
          set_level_old ls level
          pure acc
    adjust_one _ _ = error "impossible"
-- | Generalize a type in place.
--
-- Pending level adjustments are forced first. After that, every unbound
-- variable whose level is above the current level is moved to
-- 'generic_level'. Every arrow whose @level_new@ is above the current level
-- has its components generalized and both of its level fields set to the
-- larger of the two component levels.
gen :: GlobalState s -> Typ s -> ST s ()
gen st@S{current_level} root = do
  force_delayed_adjustments st
  generalize root
  where
    generalize ty = do
      threshold <- -current_level
      canonical <- repr ty
      case canonical of
        TVar ref -> do
          contents <- -ref
          case contents of
            Unbound name lvl
              | lvl > threshold -> ref =: Unbound name generic_level
            _ -> pure ()
        TArrow dom cod levels -> do
          lvl <- level_new <$> -levels
          when (lvl > threshold) do
            dom' <- repr dom
            cod' <- repr cod
            generalize dom'
            generalize cod'
            -- the arrow's level becomes the maximum of its components' levels
            upper <- max <$> get_level dom' <*> get_level cod'
            set_level_old levels upper
            set_level_new levels upper
-- | Instantiate a type: every variable at 'generic_level' is replaced by a
-- fresh variable (the same fresh variable for the same name), and every arrow
-- at 'generic_level' is rebuilt from its instantiated components. Parts of
-- the type that are not generic are returned as they are.
inst :: GlobalState s -> Typ s -> ST s (Typ s)
inst st@S{} = fmap fst . go []
  where
    -- @go@ threads the substitution from generic variable names to the fresh
    -- variables that replace them.
    go subst ty = case ty of
      TVar ref -> do
        contents <- -ref
        case contents of
          Unbound name lvl
            | lvl == generic_level ->
                case lookup name subst of
                  Just known -> pure (known, subst)
                  Nothing -> do
                    fresh <- newvar st
                    pure (fresh, (name, fresh) : subst)
          Link target -> go subst target
          _ -> pure (ty, subst)
      TArrow dom cod levels -> do
        lvl <- level_new <$> -levels
        case lvl == generic_level of
          True -> do
            (dom', subst1) <- go subst dom
            (cod', subst2) <- go subst1 cod
            arrow <- new_arrow st dom' cod'
            pure (arrow, subst2)
          False -> pure (ty, subst)
-- | Infer the type of an expression in the given environment.
--
-- * @Var@: the variable's type is looked up and instantiated. A variable that
--   is not in the environment is reported with an 'error' naming it, instead
--   of the uninformative failure of a partial lookup.
-- * @Lam@: the parameter gets a fresh variable, and the result is an arrow
--   from it to the type of the body.
-- * @App@: the function's type is unified with an arrow from the argument's
--   type to a fresh result variable, which is returned.
-- * @Let@: the bound expression is typed one level deeper, generalized, and
--   bound for the body.
typeof :: GlobalState s -> Env s -> Exp -> ST s (Typ s)
typeof st@S{} env = \case
  Var x -> case lookup x env of
    Just ty -> inst st ty
    Nothing -> error ("unbound variable: " ++ x)
  Lam x body -> do
    ty_x <- newvar st
    ty_body <- typeof st ((x, ty_x) : env) body
    new_arrow st ty_x ty_body
  App fun arg -> do
    ty_fun <- typeof st env fun
    ty_arg <- typeof st env arg
    ty_res <- newvar st
    new_arrow st ty_arg ty_res >>= unify st ty_fun
    pure ty_res
  Let x bound body -> do
    enter_level st
    ty_bound <- typeof st env bound
    leave_level st
    gen st ty_bound
    typeof st ((x, ty_bound) : env) body
-- | Immutable snapshot of a 'Typ', produced by 'purify_typ'. All fields are
-- strict, so evaluating a value forces the whole tree.
data PureTyp
  -- | An unbound type variable: its name and level.
  = PureTVar !String !Level
  -- | An arrow: argument type, result type, and the arrow's level record.
  | PureTArrow !PureTyp !PureTyp !Levels
  deriving Show
-- Not present in the OCaml version, but necessary for the Show instance;
-- it also forces the type, so that all exceptions get thrown.

-- | Read a mutable type out into an immutable 'PureTyp', following links.
purify_typ :: Typ s -> ST s PureTyp
purify_typ = \case
  TVar ref -> -ref >>= purify_tv
  TArrow dom cod levels -> do
    dom' <- purify_typ dom
    cod' <- purify_typ cod
    PureTArrow dom' cod' <$> -levels

-- | Snapshot the contents of a variable cell: an unbound variable becomes a
-- 'PureTVar', a link is replaced by the snapshot of its target.
purify_tv :: Tv s -> ST s PureTyp
purify_tv = \case
  Unbound name level -> pure (PureTVar name level)
  Link target -> purify_typ target
-- | Type check a closed expression from scratch: start from a fresh state and
-- an empty environment, infer the type, check it for cycles, and return an
-- immutable snapshot of it.
top_type_check :: Exp -> PureTyp
top_type_check expr = runST do
  st <- reset_type_variables
  inferred <- typeof st [] expr
  cycle_free st inferred
  purify_typ inferred
-- | Evaluate the argument to weak head normal form and then return it in
-- the applicative. 'main' wraps this in 'try' to catch type-checking errors.
force x = pure $! x
-- | Run the type checker over a series of sample expressions and print each
-- result. Expressions that are expected to be rejected are run under 'try',
-- so that the exception is printed instead of ending the program.
main :: IO ()
main = do
  check identity
  check apply
  check $ Let "x" apply (Var "x")
  check $ Let "y" (Lam "z" $ Var "z") (Var "y")
  check $ Lam "x" $ Let "y" (Lam "z" $ Var "z") (Var "y")
  check $ Lam "x" $ Let "y" (Lam "z" $ Var "z") $ App (Var "y") (Var "x")
  attempt $ Lam "x" $ App (Var "x") (Var "x")
  print "hello"
  attempt $ Let "x" (Var "x") (Var "x")
  attempt $ Lam "y" $ App (Var "y") (Lam "z" $ App (Var "y") (Var "z"))
  attempt $
    Lam "x" $ Lam "y" $ Lam "k" $
      App
        (App (App (Var "k") (Var "x")) (Var "y"))
        (App (App (Var "k") (Var "y")) (Var "x"))
  check $ Let "id" identity (App (Var "id") (Var "id"))
  check $
    Let "x" apply
      (Let "y"
        (Let "z" (App (Var "x") identity) (Var "z"))
        (Var "y"))
  check $
    Lam "x" (Lam "y"
      (Let "x" (App (Var "x") (Var "y"))
        (Lam "x" (App (Var "y") (Var "x")))))
  check $ Lam "x" (Let "y" (Var "x") (Var "y"))
  check $ Lam "x" (Let "y" (Lam "z" (Var "x")) (Var "y"))
  check $ Lam "x" (Let "y" (Lam "z" (App (Var "x") (Var "z"))) (Var "y"))
  check $
    Lam "x" (Lam "y"
      (Let "x" (App (Var "x") (Var "y"))
        (App (Var "x") (Var "y"))))
  attempt $ Lam "x" (Let "y" (Var "x") (App (Var "y") (Var "y")))
  check $
    Lam "x" (Let "y"
      (Let "z" (App (Var "x") identity) (Var "z"))
      (Var "y"))
  where
    -- \x -> x
    identity = Lam "x" (Var "x")
    -- \x -> \y -> x y
    apply = Lam "x" $ Lam "y" $ App (Var "x") (Var "y")
    -- print the inferred type; an exception here ends the program
    check = print . top_type_check
    -- print either the inferred type or the exception raised while
    -- computing it
    attempt expr = try @SomeException (force (top_type_check expr)) >>= print
PureTArrow (PureTVar "a" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "b" 1) (PureTVar "b" 1) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "a" 1) (PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})
Left occurs check
"hello"
Left Maybe.fromJust: Nothing
Left occurs check
Left occurs check
PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "i" 1) (PureTVar "i" 1) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "a" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTVar "a" 1) (PureTArrow (PureTVar "c" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTArrow (PureTVar "b" 1) (PureTVar "d" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "d" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})
Left occurs check
PureTArrow (PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "b" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment