Last active
July 7, 2025 20:11
-
-
Save L-as/23a71779ffd7486187903c75e73b9873 to your computer and use it in GitHub Desktop.
A mechanical Haskell translation of Oleg Kiselyov's sound_lazy.ml (from https://okmij.org/ftp/ML/generalization.html). The main differences from the original are that much of the code here is lazy, mutable state is held in STRefs, and the global variables are collected into a record.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GHC2021, NamedFieldPuns, LambdaCase, RebindableSyntax, RecordWildCards, BlockArguments, MagicHash #-} | |
{-# OPTIONS_GHC -Wall -Wno-name-shadowing -Wno-missing-signatures #-} | |
module Sound_lazy where | |
import Data.STRef | |
import Control.Monad.ST | |
import Prelude hiding (negate) | |
import Data.Char | |
import Control.Monad | |
import Data.Maybe | |
import Control.Exception | |
import GHC.Exts (reallyUnsafePtrEquality, isTrue#) | |
-- | Target of the @if … then … else …@ desugaring.  The module enables
-- @RebindableSyntax@, so conditionals resolve to whatever 'ifThenElse' is
-- in scope.  Yields the first alternative for 'True', the second for 'False'.
--
-- Written with @case@ on purpose: an @if@ in here would call itself.
ifThenElse :: Bool -> a -> a -> a
ifThenElse cond onTrue onFalse = case cond of
  True -> onTrue
  False -> onFalse
-- | Name of a term-level variable.
type Varname = String

-- | Abstract syntax of the expressions handled by 'typeof'.
data Exp
  = -- | Reference to a variable; looked up in the environment.
    Var Varname
  | -- | Application of the first expression to the second.
    App Exp Exp
  | -- | Lambda abstraction: bound variable and body.
    Lam Varname Exp
  | -- | Let binding: name, bound expression, and body.
    Let Varname Exp Exp
-- | String alias; no use of it is visible in this file.
type QName = String

-- | Level attached to type variables and arrow nodes.
type Level = Int

-- | Level that 'gen' assigns to type variables it generalizes; 'inst'
-- replaces variables and arrows at this level with fresh ones.
generic_level :: Level
generic_level = 100000000

-- | Sentinel stored in an arrow's 'level_new' while that arrow is being
-- traversed; meeting it again is reported as an occurs-check failure.
--
-- Spelled @0 - 1@ because prefix minus is rebound in this module (it
-- desugars to the local 'negate', i.e. 'readSTRef'), so a @-1@ literal
-- would not type check.
marked_level :: Level
marked_level = 0 - 1
-- | Types manipulated by the checker, living in the 'ST' region @s@.
data Typ s
  = -- | Type variable, represented by a mutable cell.
    TVar (STRef s (Tv s))
  | -- | Arrow between two types, together with its mutable level record.
    TArrow (Typ s) (Typ s) (STRef s Levels)

-- | Contents of a type-variable cell.
data Tv s
  = -- | Variable not bound to anything yet: its name and its level.
    Unbound String Level
  | -- | Variable bound to another type (followed by 'repr').
    Link (Typ s)

-- | The pair of levels carried by every arrow node.  'update_level'
-- only lowers 'level_new' and queues the arrow; 'force_delayed_adjustments'
-- later pushes that level into the components and syncs 'level_old'.
data Levels = Levels
  { level_old :: Level
  , level_new :: Level
  }
  deriving Show
-- | Replace the 'level_new' field of the 'Levels' record stored in the
-- reference, leaving 'level_old' untouched.
--
-- Uses the strict 'modifySTRef'' so the record update is performed when
-- the action runs instead of piling up as a chain of unevaluated updates
-- inside the cell (arrows get marked and unmarked on every traversal).
-- Only the record itself is forced; the level value stays lazy.
set_level_new :: STRef s Levels -> Level -> ST s ()
set_level_new ref level = modifySTRef' ref (\lv -> lv{level_new = level})

-- | Replace the 'level_old' field of the 'Levels' record stored in the
-- reference, leaving 'level_new' untouched.  Strict for the same reason
-- as 'set_level_new'.
set_level_old :: STRef s Levels -> Level -> ST s ()
set_level_old ref level = modifySTRef' ref (\lv -> lv{level_old = level})
-- OCaml-flavoured reference operators.
--
-- With RebindableSyntax, prefix minus desugars to whichever 'negate' is in
-- scope; defining it as a reference read makes @-ref@ play the role of
-- OCaml's @!ref@ throughout this module.
negate :: STRef s a -> ST s a
negate ref = readSTRef ref

-- | Store a value in a reference, in the spirit of OCaml's @ref := v@.
(=:) :: STRef s a -> a -> ST s ()
ref =: value = writeSTRef ref value
-- | Follow 'Link's starting from a type until reaching something that is
-- not a bound variable, and return that representative.  Every bound
-- variable visited on the way is re-pointed directly at the representative
-- (path compression).  Unbound variables and arrows are returned as is.
repr :: Typ s -> ST s (Typ s)
repr ty = case ty of
  TVar cell -> do
    contents <- -cell
    case contents of
      Link target -> do
        root <- repr target
        -- shorten the chain for future lookups
        cell =: Link root
        pure root
      Unbound{} -> pure ty
  TArrow{} -> pure ty
-- | Level of a type: the level recorded in an unbound variable, or the
-- 'level_new' of an arrow.  Calling it on a variable that is a 'Link'
-- is treated as a programming error ("impossible").
get_level :: Typ s -> ST s Level
get_level ty = case ty of
  TArrow _ _ levels -> fmap level_new (-levels)
  TVar cell -> do
    contents <- -cell
    case contents of
      Unbound _ lvl -> pure lvl
      Link _ -> error "impossible"
-- | Mutable state shared by all checker functions, bundled in one record
-- and passed explicitly.
data GlobalState s = S
  { -- | Counter read and bumped by 'gensym' to produce variable names.
    gensym_counter :: STRef s Int
  , -- | Level stamped on new variables and arrows; changed by
    -- 'enter_level' and 'leave_level'.
    current_level :: STRef s Level
  , -- | Arrows queued by 'update_level' whose components still need their
    -- levels lowered; drained by 'force_delayed_adjustments'.
    to_be_level_adjusted :: STRef s [Typ s]
  }
-- | Add one to the number stored in a reference.
incr :: Num a => STRef s a -> ST s ()
incr ref = modifySTRef ref (+ 1)

-- | Subtract one from the number stored in a reference.
decr :: Num a => STRef s a -> ST s ()
decr ref = modifySTRef ref (subtract 1)
-- | Produce the next variable name and advance the counter.  The first 26
-- names are the single letters @a@ … @z@; after that the name is @t@
-- followed by the counter value (@t26@, @t27@, …).
gensym :: GlobalState s -> ST s String
gensym S{gensym_counter} = do
  n <- -gensym_counter
  incr gensym_counter
  pure (name_for n)
  where
    name_for k
      | k < 26 = [chr (ord 'a' + k)]
      | otherwise = 't' : show k
-- | Allocate a fresh checker state: name counter at 0, current level at 1,
-- and an empty queue of pending level adjustments.
reset_type_variables :: ST s (GlobalState s)
reset_type_variables = do
  level <- newSTRef 1
  counter <- newSTRef 0
  pending <- newSTRef []
  pure $
    S
      { gensym_counter = counter
      , current_level = level
      , to_be_level_adjusted = pending
      }
-- | Increase the current level by one ('typeof' does this before checking
-- the bound expression of a @Let@).
enter_level :: GlobalState s -> ST s ()
enter_level S{current_level} = incr current_level

-- | Decrease the current level by one, undoing 'enter_level'.
leave_level :: GlobalState s -> ST s ()
leave_level S{current_level} = decr current_level
-- | Create a fresh unbound type variable, named by 'gensym' and stamped
-- with the current level.
newvar :: GlobalState s -> ST s (Typ s)
newvar st@S{current_level} = do
  name <- gensym st
  level <- -current_level
  cell <- newSTRef (Unbound name level)
  pure (TVar cell)
-- | Build an arrow type from the two given types.  Both of its levels
-- ('level_old' and 'level_new') start out at the current level.
new_arrow :: GlobalState s -> Typ s -> Typ s -> ST s (Typ s)
new_arrow S{current_level} ty1 ty2 = do
  level <- -current_level
  TArrow ty1 ty2 <$> newSTRef Levels{level_old = level, level_new = level}
-- | Walk a type and fail with @"occurs check"@ if an arrow is reachable
-- from itself.  Each arrow is temporarily tagged with 'marked_level'
-- while its components are visited; its previous 'level_new' is put back
-- afterwards, so a successful run leaves the type unchanged.
cycle_free :: GlobalState s -> Typ s -> ST s ()
cycle_free S{} = visit
  where
    visit (TVar cell) = do
      contents <- -cell
      case contents of
        Unbound{} -> pure ()
        Link target -> visit target
    visit (TArrow dom cod levels) = do
      saved <- level_new <$> -levels
      -- seeing our own mark means we came back to an arrow still in progress
      when (saved == marked_level) $ error "occurs check"
      set_level_new levels marked_level
      visit dom
      visit cod
      set_level_new levels saved
-- | Make sure the level of a type does not exceed the given bound.
--
-- * Unbound variable: its level is lowered to the bound if it was higher.
-- * Arrow: only 'level_new' is lowered.  If the arrow's two levels still
--   agreed beforehand, the arrow is also pushed onto
--   'to_be_level_adjusted' so its components get fixed up later.
--
-- Fails with @"occurs check"@ on an arrow that is currently marked, and
-- with @"impossible"@ on generic-level nodes or bound ('Link') variables.
update_level :: GlobalState s -> Level -> Typ s -> ST s ()
update_level S{to_be_level_adjusted} bound ty = case ty of
  TVar cell -> do
    contents <- -cell
    case contents of
      Unbound name current -> do
        when (current == generic_level) $ error "impossible"
        when (bound < current) $ cell =: Unbound name bound
      Link _ -> error "impossible"
  TArrow _ _ levels -> do
    Levels{level_old = old, level_new = current} <- -levels
    when (current == generic_level) $ error "impossible"
    when (current == marked_level) $ error "occurs check"
    when (bound < current) $ do
      -- queue the arrow only the first time its levels start to differ
      when (current == old) $ modifySTRef to_be_level_adjusted (ty :)
      set_level_new levels bound
-- | Stand-in for OCaml's physical equality (@==@): evaluate both arguments
-- to weak head normal form, then compare them with
-- 'reallyUnsafePtrEquality'.  NOTE(review): presumably only a 'True'
-- answer is meaningful (a 'False' need not mean the values differ) —
-- 'unify' uses it purely as a shortcut.
ptrEqual :: a -> a -> Bool
ptrEqual x y = x `seq` y `seq` isTrue# (reallyUnsafePtrEquality x y)
-- | Unify two types in place.
--
-- After the pointer-equality shortcut, both sides are replaced by their
-- representatives ('repr') and then:
--
-- * two unbound variables: nothing to do if they share a cell; otherwise
--   the one with the strictly higher level is linked to the other
--   (on a tie the second is linked to the first);
-- * one unbound variable and another type: the other type's level is
--   capped at the variable's level ('update_level'), then the variable is
--   linked to it;
-- * two arrows: both are marked, their components are unified pairwise
--   at the smaller of the two 'level_new' values, and both arrows end up
--   with that level.  Finding either arrow already marked raises
--   @"cycle: occurs check"@.
unify :: GlobalState s -> Typ s -> Typ s -> ST s ()
unify st@S{} lhs rhs
  | ptrEqual lhs rhs = pure ()
  | otherwise = do
      lhs' <- repr lhs
      rhs' <- repr rhs
      case (lhs', rhs') of
        (TVar cell1, TVar cell2) -> do
          contents1 <- -cell1
          contents2 <- -cell2
          case (contents1, contents2) of
            (Unbound _ l1, Unbound _ l2)
              | cell1 == cell2 -> pure ()
              | l1 > l2 -> cell1 =: Link rhs'
              | otherwise -> cell2 =: Link lhs'
            _ -> error "impossible"
        (TVar cell, other) -> bind_var cell other
        (other, TVar cell) -> bind_var cell other
        (TArrow dom1 cod1 levels1, TArrow dom2 cod2 levels2) -> do
          new1 <- level_new <$> -levels1
          new2 <- level_new <$> -levels2
          when (new1 == marked_level || new2 == marked_level) $
            error "cycle: occurs check"
          let lowest = min new1 new2
          set_level_new levels1 marked_level
          set_level_new levels2 marked_level
          unify_lev st lowest dom1 dom2
          unify_lev st lowest cod1 cod2
          set_level_new levels1 lowest
          set_level_new levels2 lowest
  where
    -- Link an unbound variable to a non-variable type, capping the
    -- type's level at the variable's level first.
    bind_var cell other = do
      contents <- -cell
      case contents of
        Unbound _ lvl -> do
          update_level st lvl other
          cell =: Link other
        _ -> error "impossible"

-- | Unify two types after capping the level of the first one's
-- representative at the given level.
unify_lev :: GlobalState s -> Level -> Typ s -> Typ s -> ST s ()
unify_lev st@S{} lvl ty1 ty2 = do
  root <- repr ty1
  update_level st lvl root
  unify st root ty2
-- | Typing environment: an association list from variable names to their
-- types.  'typeof' conses new bindings onto the front, so the most
-- recently added binding for a name is the one 'lookup' finds.
type Env s = [(Varname, Typ s)]
-- | Process every arrow queued in 'to_be_level_adjusted', pushing its
-- lowered 'level_new' down into its components.  Arrows that cannot be
-- handled yet (see 'adjust_one') are collected and become the new queue.
--
-- The queue is threaded through with 'foldM' from "Control.Monad" (already
-- imported) rather than a hand-rolled @foldl'@ over actions: same
-- left-to-right effects, and no reliance on @foldl'@ being exported by the
-- Prelude, which is only the case from base-4.20 on.
force_delayed_adjustments :: GlobalState s -> ST s ()
force_delayed_adjustments S{current_level, to_be_level_adjusted} = do
  pending <- -to_be_level_adjusted
  still_pending <- foldM adjust_one [] pending
  to_be_level_adjusted =: still_pending
  where
    -- Cap one component of an arrow at @level@.  An unbound variable is
    -- simply lowered; an arrow has its 'level_new' lowered and is then
    -- adjusted recursively.  A marked arrow means a cycle.
    loop acc level ty = do
      ty' <- repr ty
      case ty' of
        TVar tvr -> do
          tv <- -tvr
          case tv of
            Unbound name l | l > level -> do
              tvr =: Unbound name level
              pure acc
            _ -> pure acc
        TArrow _ _ ls -> do
          l <- level_new <$> -ls
          when (l == marked_level) $ error "occurs check"
          when (l > level) $ set_level_new ls level
          adjust_one acc ty'

    -- Handle a single arrow (anything else is "impossible"):
    --   * 'level_old' not above the current level: keep it queued;
    --   * both levels equal: nothing left to do;
    --   * otherwise: mark it, cap both components at 'level_new', then
    --     restore 'level_new' and bring 'level_old' in line with it.
    adjust_one acc ty@(TArrow ty1 ty2 ls) = do
      Levels{level_old = old, level_new = new} <- -ls
      cur <- -current_level
      if old <= cur
        then pure (ty : acc)
        else
          if old == new
            then pure acc
            else do
              set_level_new ls marked_level
              acc1 <- loop acc new ty1
              acc2 <- loop acc1 new ty2
              set_level_new ls new
              set_level_old ls new
              pure acc2
    adjust_one _ _ = error "impossible"
-- | Generalize a type in place.  Pending level adjustments are forced
-- first; then every unbound variable whose level is above the current
-- level is moved to 'generic_level'.  An arrow whose 'level_new' is above
-- the current level has both components generalized, after which both of
-- its levels are set to the larger of its components' levels.  Anything
-- at or below the current level is left alone.
gen :: GlobalState s -> Typ s -> ST s ()
gen st@S{current_level} root = do
  force_delayed_adjustments st
  -- nothing below changes the current level, so one read is enough
  cur <- -current_level
  let go ty = do
        ty' <- repr ty
        case ty' of
          TVar cell -> do
            contents <- -cell
            case contents of
              Unbound name lvl
                | lvl > cur -> cell =: Unbound name generic_level
              _ -> pure ()
          TArrow dom cod levels -> do
            lvl <- level_new <$> -levels
            when (lvl > cur) $ do
              dom' <- repr dom
              cod' <- repr cod
              go dom'
              go cod'
              upper <- max <$> get_level dom' <*> get_level cod'
              set_level_old levels upper
              set_level_new levels upper
  go root
-- | Instantiate a type: each generic-level variable is replaced by a
-- fresh variable (the same fresh variable for every occurrence of the
-- same name), and generic-level arrows are rebuilt around their
-- instantiated components.  Non-generic parts are shared, not copied.
inst :: GlobalState s -> Typ s -> ST s (Typ s)
inst st@S{} root = fst <$> go [] root
  where
    -- @subst@ maps names of generic variables seen so far to their
    -- replacements; it is threaded through and returned with the result.
    go subst ty = case ty of
      TVar cell -> do
        contents <- -cell
        case contents of
          Unbound name lvl
            | lvl == generic_level ->
                case lookup name subst of
                  Just fresh -> pure (fresh, subst)
                  Nothing -> do
                    fresh <- newvar st
                    pure (fresh, (name, fresh) : subst)
          Link target -> go subst target
          _ -> pure (ty, subst)
      TArrow dom cod levels -> do
        lvl <- level_new <$> -levels
        if lvl == generic_level
          then do
            (dom', subst1) <- go subst dom
            (cod', subst2) <- go subst1 cod
            arrow <- new_arrow st dom' cod'
            pure (arrow, subst2)
          else pure (ty, subst)
-- | Infer the type of an expression in the given environment.
--
-- * @Var@: instantiate the type bound to the name.  An unbound name
--   surfaces as the 'fromJust' failure.
-- * @Lam@: fresh variable for the parameter, arrow to the body's type.
-- * @App@: unify the function's type with an arrow from the argument's
--   type to a fresh result variable, and return that variable.
-- * @Let@: infer the bound expression one level deeper, generalize it,
--   then infer the body with the binding added.
typeof :: GlobalState s -> Env s -> Exp -> ST s (Typ s)
typeof st@S{} env expr = case expr of
  Var x -> inst st (fromJust (lookup x env))
  Lam x body -> do
    ty_x <- newvar st
    ty_body <- typeof st ((x, ty_x) : env) body
    new_arrow st ty_x ty_body
  App fun arg -> do
    ty_fun <- typeof st env fun
    ty_arg <- typeof st env arg
    ty_res <- newvar st
    expected <- new_arrow st ty_arg ty_res
    unify st ty_fun expected
    pure ty_res
  Let x bound body -> do
    enter_level st
    ty_bound <- typeof st env bound
    leave_level st
    gen st ty_bound
    typeof st ((x, ty_bound) : env) body
-- | Immutable snapshot of a 'Typ', with links resolved and references
-- read out.  All fields are strict.
data PureTyp
  = -- | Unbound variable: name and level.
    PureTVar !String !Level
  | -- | Arrow: both component types and the arrow's level record.
    PureTArrow !PureTyp !PureTyp !Levels
  deriving Show
-- Not part of the OCaml original.  Needed to have something with a Show
-- instance, and — because 'PureTyp' is strict in every field — evaluating
-- the result also brings out any exception hiding in the type.

-- | Convert a mutable type into its immutable snapshot.
purify_typ :: Typ s -> ST s PureTyp
purify_typ ty = case ty of
  TVar cell -> -cell >>= purify_tv
  TArrow dom cod levels -> do
    dom' <- purify_typ dom
    cod' <- purify_typ cod
    lv <- -levels
    pure (PureTArrow dom' cod' lv)

-- | Convert the contents of a variable cell: an unbound variable becomes
-- a 'PureTVar'; a link is followed.
purify_tv :: Tv s -> ST s PureTyp
purify_tv tv = case tv of
  Unbound name level -> pure (PureTVar name level)
  Link target -> purify_typ target
-- | Type check a closed expression from scratch: start with a fresh
-- state and an empty environment, infer the type, run the final cycle
-- check, and return the result as an immutable 'PureTyp'.
top_type_check :: Exp -> PureTyp
top_type_check expr = runST $ do
  st <- reset_type_variables
  ty <- typeof st [] expr
  cycle_free st ty
  purify_typ ty
-- | Evaluate the argument to weak head normal form and hand it back via
-- 'pure'.  'main' wraps this in 'try' so that errors raised while
-- computing the value can be caught and printed.
force :: Applicative f => a -> f a
force x = x `seq` pure x
-- | Run the checker on a series of sample expressions and print each
-- result.  Expressions expected to be rejected go through 'attempt', which
-- prints the caught exception instead of aborting the program.
main :: IO ()
main = do
  check identity
  check c1
  check (Let "x" c1 (Var "x"))
  check (Let "y" (Lam "z" (Var "z")) (Var "y"))
  check (Lam "x" (Let "y" (Lam "z" (Var "z")) (Var "y")))
  check (Lam "x" (Let "y" (Lam "z" (Var "z")) (App (Var "y") (Var "x"))))
  attempt (Lam "x" (App (Var "x") (Var "x")))
  print "hello"
  attempt (Let "x" (Var "x") (Var "x"))
  attempt (Lam "y" (App (Var "y") (Lam "z" (App (Var "y") (Var "z")))))
  attempt
    ( Lam "x" (Lam "y" (Lam "k"
        (App
          (App (App (Var "k") (Var "x")) (Var "y"))
          (App (App (Var "k") (Var "y")) (Var "x")))))
    )
  check (Let "id" identity (App (Var "id") (Var "id")))
  check
    ( Let "x" c1
        (Let "y" (Let "z" (App (Var "x") identity) (Var "z")) (Var "y"))
    )
  check
    ( Lam "x" (Lam "y"
        (Let "x" (App (Var "x") (Var "y"))
          (Lam "x" (App (Var "y") (Var "x")))))
    )
  check (Lam "x" (Let "y" (Var "x") (Var "y")))
  check (Lam "x" (Let "y" (Lam "z" (Var "x")) (Var "y")))
  check (Lam "x" (Let "y" (Lam "z" (App (Var "x") (Var "z"))) (Var "y")))
  check
    ( Lam "x" (Lam "y"
        (Let "x" (App (Var "x") (Var "y"))
          (App (Var "x") (Var "y"))))
    )
  attempt (Lam "x" (Let "y" (Var "x") (App (Var "y") (Var "y"))))
  check
    ( Lam "x"
        (Let "y" (Let "z" (App (Var "x") identity) (Var "z")) (Var "y"))
    )
  where
    -- \x -> x
    identity = Lam "x" (Var "x")
    -- \x -> \y -> x y
    c1 = Lam "x" (Lam "y" (App (Var "x") (Var "y")))
    -- Print the inferred type; any error propagates.
    check = print . top_type_check
    -- Print either the inferred type or the exception raised while
    -- computing it.
    attempt e = try @SomeException (force (top_type_check e)) >>= print
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PureTArrow (PureTVar "a" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "b" 1) (PureTVar "b" 1) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "a" 1) (PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1}) | |
Left occurs check | |
"hello" | |
Left Maybe.fromJust: Nothing | |
Left occurs check | |
Left occurs check | |
PureTArrow (PureTVar "c" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "i" 1) (PureTVar "i" 1) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "d" 1) (PureTVar "e" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "a" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTVar "a" 1) (PureTArrow (PureTVar "c" 1) (PureTVar "a" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
PureTArrow (PureTArrow (PureTVar "b" 1) (PureTArrow (PureTVar "b" 1) (PureTVar "d" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1})) (PureTArrow (PureTVar "b" 1) (PureTVar "d" 1) (Levels {level_old = 1, level_new = 1})) (Levels {level_old = 1, level_new = 1}) | |
Left occurs check | |
PureTArrow (PureTArrow (PureTArrow (PureTVar "b" 1) (PureTVar "b" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1})) (PureTVar "c" 1) (Levels {level_old = 1, level_new = 1}) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment