Last active
August 28, 2018 10:37
-
-
Save snowleopard/2dd93951cfd42e03aa04a4aa696ca029 to your computer and use it in GitHub Desktop.
Typed constant folding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs, DataKinds, TypeOperators #-} | |
{-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-} | |
-- This is an attempt to find a safer implementation of GHC's constant folding algorithm.
-- See https://ghc.haskell.org/trac/ghc/ticket/15569 | |
-- Shapes of expression trees. The constructors are used (promoted via
-- DataKinds) as the first index of Expr:
--   L        a literal leaf
--   V        a variable leaf
--   s :+: t  an addition node whose operands have shapes s and t
--   s :*: t  a multiplication node whose operands have shapes s and t
data Shape
    = L
    | V
    | Shape :+: Shape
    | Shape :*: Shape
-- Arithmetic expressions annotated with their shape s. Literals carry a
-- value of type a, variables carry a value of type b, and every leaf is
-- tagged with a Polarity.
data Expr s a b where
    -- A literal leaf.
    Lit :: Polarity -> a -> Expr L a b
    -- A variable leaf.
    Var :: Polarity -> b -> Expr V a b
    -- Sum of two sub-expressions; the shape records both operand shapes.
    Add :: Expr x a b -> Expr y a b -> Expr (x :+: y) a b
    -- Product of two sub-expressions; the shape records both operand shapes.
    Mul :: Expr x a b -> Expr y a b -> Expr (x :*: y) a b
-- Apply a transformation to the left operand of a binary node (Add or Mul),
-- keeping the operator and the right operand untouched.
mapLeft :: (Expr x a b -> Expr y a b) -> Expr (op x (z :: Shape)) a b -> Expr (op y z) a b
mapLeft f expr = case expr of
    Add l r -> Add (f l) r
    Mul l r -> Mul (f l) r
-- Apply a transformation to the right operand of a binary node (Add or Mul),
-- keeping the operator and the left operand untouched.
mapRight :: (Expr x a b -> Expr y a b) -> Expr (op (z :: Shape) x) a b -> Expr (op z y) a b
mapRight f expr = case expr of
    Add l r -> Add l (f r)
    Mul l r -> Mul l (f r)
-- Sign attached to every leaf of an expression. A Negative leaf stands for
-- the negated value, which is how subtraction is encoded.
data Polarity
    = Positive
    | Negative
-- Negate an expression without changing its shape.
neg :: Expr s a b -> Expr s a b
neg (Lit sign value) = Lit (flipPolarity sign) value
neg (Var sign name)  = Var (flipPolarity sign) name
-- The negation of a sum is the sum of the negated operands.
neg (Add lhs rhs)    = Add (neg lhs) (neg rhs)
-- Negating a single factor is enough to negate a product.
neg (Mul lhs rhs)    = Mul (neg lhs) rhs

-- Swap Positive and Negative (private helper of neg).
flipPolarity :: Polarity -> Polarity
flipPolarity Positive = Negative
flipPolarity Negative = Positive
-- Extract the numeric value of a literal, applying its polarity: a Negative
-- literal yields the negated payload.
getLit :: Num a => Expr L a b -> a
getLit (Lit sign value) = case sign of
    Positive -> value
    Negative -> negate value
-- A few smart constructors | |
-- Build a positive literal.
lit :: a -> Expr L a b
lit value = Lit Positive value
-- Build a positive variable.
var :: b -> Expr V a b
var name = Var Positive name
-- Sum of two expressions.
add :: Expr x a b -> Expr y a b -> Expr (x :+: y) a b
add lhs rhs = Add lhs rhs
-- Difference of two expressions, encoded as the sum of the first operand and
-- the negated second operand (so the resulting shape is still x :+: y).
sub :: Expr x a b -> Expr y a b -> Expr (x :+: y) a b
sub lhs rhs = lhs `Add` neg rhs
-- Product of two expressions.
mul :: Expr x a b -> Expr y a b -> Expr (x :*: y) a b
mul lhs rhs = Mul lhs rhs
-- Axioms of addition and multiplication | |
-- Commutativity: swap the operands of a binary node.
--   x op y  ==>  y op x
comm :: Expr (op (x :: Shape) y) a b -> Expr (op y x) a b
comm expr = case expr of
    Add l r -> Add r l
    Mul l r -> Mul r l
-- Associativity, re-nesting to the left:
--   x op (y op z)  ==>  (x op y) op z
-- Both nodes must use the same operator.
assoc1 :: Expr (op (x :: Shape) (op y z)) a b -> Expr (op (op x y) z) a b
assoc1 expr = case expr of
    Add l (Add m r) -> Add (Add l m) r
    Mul l (Mul m r) -> Mul (Mul l m) r
-- Associativity, re-nesting to the right:
--   (x op y) op z  ==>  x op (y op z)
-- Both nodes must use the same operator.
assoc2 :: Expr (op (op (x :: Shape) y) z) a b -> Expr (op x (op y z)) a b
assoc2 expr = case expr of
    Add (Add l m) r -> Add l (Add m r)
    Mul (Mul l m) r -> Mul l (Mul m r)
-- Distributivity of multiplication over addition:
--   x * (y + z)  ==>  (x * y) + (x * z)
-- The left factor x is duplicated into both summands.
distr :: Expr (x :*: (y :+: z)) a b -> Expr ((x :*: y) :+: (x :*: z)) a b
distr expr = case expr of
    Mul factor (Add l r) -> (factor `Mul` l) `Add` (factor `Mul` r)
-- The main constant folding step: collapse a binary node whose operands are
-- both literals into a single positive literal holding the computed value
-- (operand polarities are taken into account by getLit).
eval :: Num a => Expr (op L L) a b -> Expr L a b
eval expr = case expr of
    Add l r -> Lit Positive $ getLit l + getLit r
    Mul l r -> Lit Positive $ getLit l * getLit r
-- Constant folding rules that are checked by the compiler.
-- Ideally these rules would live in a separate module that sees
-- `Expr` as an abstract data type, together with transformations
-- such as `eval`, `mapLeft` and `comm`, whose correctness is verified
-- manually. That way, I think, it should be impossible to write an
-- incorrect constant folding rule.
-- c1 op (c2 op x)  ==>  c op x,  where c is c1 op c2 folded.
-- Re-associate to the left so both literals meet, then fold them.
r1 :: Num a => Expr (op L (op L x)) a b -> Expr (op L x) a b
r1 expr = mapLeft eval (assoc1 expr)
-- c1 op (x op c2)  ==>  c op x
-- Swap the inner operands to bring the literal forward, then reuse r1.
r2 :: Num a => Expr (op L (op x L)) a b -> Expr (op L x) a b
r2 expr = r1 (mapRight comm expr)
-- (c1 op x) op c2  ==>  c op x
-- Swap the outer operands to reach the input form of r1.
r3 :: Num a => Expr (op (op L x) L) a b -> Expr (op L x) a b
r3 expr = r1 (comm expr)
-- (x op c1) op c2  ==>  c op x
-- Swap the inner operands to reach the input form of r3.
r4 :: Num a => Expr (op (op x L) L) a b -> Expr (op L x) a b
r4 expr = r3 (mapLeft comm expr)
-- (c1 op x) op (c2 op y)  ==>  c op (x op y)
-- Re-associate to ((c1 op x) op c2) op y, fold the two literals with r3,
-- then re-associate back to the right.
r5 :: Num a => Expr (op (op L x) (op L y)) a b -> Expr (op L (op x y)) a b
r5 expr = assoc2 $ mapLeft r3 $ assoc1 expr
-- (c1 op x) op (y op c2)  ==>  c op (x op y)
-- Swap the operands of the right node, then reuse r5.
r6 :: Num a => Expr (op (op L x) (op y L)) a b -> Expr (op L (op x y)) a b
r6 expr = r5 (mapRight comm expr)
-- (x op c1) op (c2 op y)  ==>  c op (x op y)
-- Swap the operands of the left node, then reuse r5.
r7 :: Num a => Expr (op (op x L) (op L y)) a b -> Expr (op L (op x y)) a b
r7 expr = r5 (mapLeft comm expr)
-- (x op c1) op (y op c2)  ==>  c op (x op y)
-- Swap the operands of the left node, then reuse r6.
r8 :: Num a => Expr (op (op x L) (op y L)) a b -> Expr (op L (op x y)) a b
r8 expr = r6 (mapLeft comm expr)
-- c1 * (c2 + x)  ==>  c + (c1 * x),  where c is c1 * c2 folded.
-- Distribute the multiplication, then fold the literal product.
r9 :: Num a => Expr (L :*: (L :+: x)) a b -> Expr (L :+: (L :*: x)) a b
r9 expr = mapLeft eval (distr expr)
-- c1 * (x + c2)  ==>  c + (c1 * x)
-- Swap the summands to reach the input form of r9.
r10 :: Num a => Expr (L :*: (x :+: L)) a b -> Expr (L :+: (L :*: x)) a b
r10 expr = r9 (mapRight comm expr)
-- x op (c op y)  ==>  c op (x op y)
-- Float the literal to the front without evaluating anything:
-- re-associate left, swap x and c, re-associate right.
r11 :: Expr (op x (op L y)) a b -> Expr (op L (op x y)) a b
r11 expr = assoc2 $ mapLeft comm $ assoc1 expr
-- x op (y op c)  ==>  c op (x op y)
-- Swap the inner operands to reach the input form of r11.
r12 :: Expr (op x (op y L)) a b -> Expr (op L (op x y)) a b
r12 expr = r11 (mapRight comm expr)
-- (c op x) op y  ==>  c op (x op y)
-- This is plain re-association to the right: assoc2 builds the very same
-- tree as commuting the node, applying r11 and commuting the result back.
r13 :: Expr (op (op L x) y) a b -> Expr (op L (op x y)) a b
r13 expr = assoc2 expr
-- (x op c) op y  ==>  c op (x op y)
-- Swap the inner operands to bring the literal forward, then re-associate to
-- the right; this yields the same tree as commuting the node, applying r12
-- and commuting the result back.
r14 :: Expr (op (op x L) y) a b -> Expr (op L (op x y)) a b
r14 expr = assoc2 (mapLeft comm expr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment