Last active: September 10, 2018 15:15
-
-
Save edsko/53c17347f4bac5828d4be0a099773f77 to your computer and use it in GitHub Desktop.
Applicative-only, spaceleak-free version of 'WriterT'
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE NoMonomorphismRestriction #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Main where | |
import Data.Functor.Identity | |
import Data.Monoid | |
import Data.Traversable | |
import Control.Arrow (first, second) | |
import Control.Monad.Writer | |
{-------------------------------------------------------------------------------
  Applicative-only, spaceleak-free version of 'WriterT'
-------------------------------------------------------------------------------}
-- | A writer-like wrapper around an @f@ computation whose result is paired
-- with an accumulated log of type @w@.
--
-- Only 'Functor' and 'Applicative' instances are provided; there is no
-- 'Monad' instance.
newtype Collect w f a = Collect
  { runCollect :: f (a, w)
    -- ^ The underlying computation, yielding the result and its log.
  }

-- | Shows the wrapped computation whenever that computation is showable.
deriving instance Show (f (a, w)) => Show (Collect w f a)
-- | Maps over the result component only; the log is passed through as-is.
instance Functor f => Functor (Collect w f) where
  fmap g = Collect . fmap (first g) . runCollect
-- | Sequences two 'Collect' computations, combining their logs with '<>'
-- (the left-hand log goes first).
--
-- Whenever a combined @(result, log)@ pair is forced, both input logs and
-- their combination are forced to WHNF as well. For logs whose WHNF is
-- fully evaluated (such as @'Sum' 'Int'@), unevaluated appends therefore
-- cannot pile up.
instance (Applicative f, Monoid w) => Applicative (Collect w f) where
  pure x = Collect (pure (x, mempty))

  Collect fcs <*> Collect bcs = Collect (aux <$> fcs <*> bcs)
    where
      -- The bang patterns force both incoming logs when the pairs are
      -- matched. The strict let forces their combination as well, so
      -- forcing the result pair leaves no pending append behind.
      aux :: (a -> b, w) -> (a, w) -> (b, w)
      aux (f, !w) (a, !w') = let !w'' = w <> w' in (f a, w'')
-- | Walk over a traversable data structure, collecting additional results.
--
-- Each call to @f@ produces a replacement element and one extra result.
-- The replacement elements are rebuilt into the original shape. The extra
-- results are concatenated into a list in the order the traversal runs the
-- effects.
traverseCollect :: forall t f a b c. (Traversable t, Applicative f)
                => (a -> f (b, c)) -> t a -> f (t b, [c])
traverseCollect f xs = runCollect (traverse step xs)
  where
    -- Wrap each extra result in a singleton list so it can serve as the log.
    step :: a -> Collect [c] f b
    step x = Collect (second (: []) <$> f x)
-- | Add the entry @w@ to a computation's log, leaving its result unchanged.
--
-- The entry is combined on the /left/ of the computation's own log.
collect :: (Applicative f, Monoid w) => w -> Collect w f a -> Collect w f a
collect w = (Collect (pure (id, w)) <*>)
{-------------------------------------------------------------------------------
  Usage example
-------------------------------------------------------------------------------}
-- | 'traverse' under another name, highlighting its resemblance to
-- @modifyMVar_@:
--
-- > modifyMVar_   :: MVar a -> (a -> IO a) -> IO ()
-- > flip traverse :: t a    -> (a -> f b)  -> f (t b)
modify_ :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
modify_ f xs = traverse f xs
-- | The puzzle: can we define the 'Traversable' equivalent of @modifyMVar@?
--
-- > modifyMVar :: MVar a -> (a -> IO (a, b)) -> IO b
-- > flip ???   :: t a    -> (a -> f (b, c)) -> f (t b, [c])
--
-- The answer is exactly 'traverseCollect'. The previous body was a
-- verbatim copy of it, so this delegates instead, keeping a single
-- implementation.
modify :: forall t f a b c. (Traversable t, Applicative f)
       => (a -> f (b, c)) -> t a -> f (t b, [c])
modify = traverseCollect
{-------------------------------------------------------------------------------
  Test
-------------------------------------------------------------------------------}
-- | Log @Sum 1@ ten million times using 'Collect'.
--
-- Reported max residency: 31 kB.
testCollect :: Collect (Sum Int) Identity ()
testCollect = nTimes iterations (collect (Sum 1)) (pure ())
  where
    iterations :: Int
    iterations = 10000000
-- | The same ten-million-step @Sum 1@ workload, using the standard 'WriterT'.
--
-- Reported max residency: 1.3 GB.
testWriterT :: WriterT (Sum Int) Identity ()
testWriterT = nTimesM iterations step ()
  where
    iterations :: Int
    iterations = 10000000

    -- Matching on () keeps the same strictness as the original lambda.
    step :: () -> WriterT (Sum Int) Identity ()
    step () = tell (Sum 1)
-- | Run the 'WriterT' variant of the benchmark. To measure the 'Collect'
-- variant instead, show 'testCollect' here.
main :: IO ()
main = putStrLn (show testWriterT)
{-------------------------------------------------------------------------------
  Auxiliary
-------------------------------------------------------------------------------}
-- | Apply @f@ to the starting value @n@ times.
--
-- The accumulator is forced to WHNF before every step, so no chain of
-- unapplied @f@ thunks builds up. A non-positive @n@ is treated as zero and
-- returns the (forced) starting value; a negative count therefore cannot
-- recurse forever.
nTimes :: Int -> (a -> a) -> (a -> a)
nTimes n f !a
  | n <= 0    = a
  | otherwise = nTimes (n - 1) f (f a)
-- | Run the monadic step @f@ @n@ times, feeding each result into the next
-- step.
--
-- The incoming value is forced to WHNF before each step. A non-positive
-- @n@ is treated as zero and just returns the (forced) starting value; a
-- negative count therefore cannot recurse forever.
nTimesM :: Monad m => Int -> (a -> m a) -> (a -> m a)
nTimesM n f !a
  | n <= 0    = return a
  | otherwise = f a >>= nTimesM (n - 1) f
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment