Last active
March 14, 2025 14:54
-
-
Save gatlin/e81b8c572b2f284f1423 to your computer and use it in GitHub Desktop.
Simple neural network with backpropagation in Haskell, using Repa. Inspired by: http://iamtrask.github.io/2015/07/12/basic-python-network/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{- cabal:
build-depends: base, repa, repa-algorithms
ghc-options: -rtsopts
-}
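{- To run:
1. Install the Haskell toolchain (GHC and cabal) with GHCup: https://www.haskell.org/ghcup/
2. cabal run Main.hs -- +RTS -s
   (+RTS -s asks the GHC runtime to print a summary of time and memory statistics.) -}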
module Main where | |
import Prelude hiding (map, zipWith) | |
import Control.Monad (forM_) | |
import Data.Array.Repa hiding ((++)) | |
import Data.Array.Repa.Algorithms.Matrix (mmultS, transpose2S) | |
import Data.Array.Repa.Algorithms.Randomish (randomishDoubleArray) | |
import Data.IORef (newIORef, readIORef, modifyIORef') | |
-- | A two-dimensional Repa array; @r@ is the array representation (e.g. 'U'),
-- and the element type is supplied as the next argument.
type Matrix r = Array r DIM2

-- | A homogeneous pair. Keeps type signatures shorter.
type Two t = (t, t)
-- | Build a @rows × cols@ matrix of pseudo-random weights rescaled to [-1, 1].
--
-- Values are drawn with 'randomishDoubleArray' in the range [0, 1] using the
-- fixed seed 100 and then mapped through @2v - 1@. Because the seed never
-- changes, calls with the same dimensions always return the same matrix.
randomArray
  :: Int -- ^ Rows
  -> Int -- ^ Columns
  -> Matrix U Double
randomArray rows cols = computeS (map toSymmetric uniform01)
  where
    uniform01 = randomishDoubleArray (Z :. rows :. cols) 0 1 100
    toSymmetric v = 2 * v - 1
-- | Training inputs: four examples with three features each.
--
-- The first two columns enumerate every pair of bits. The third column is a
-- constant 1 (presumably acting as a bias input).
x :: Matrix U Double
x = fromListUnboxed (Z :. 4 :. 3) (concat examples)
  where
    examples =
      [ [0, 0, 1]
      , [0, 1, 1]
      , [1, 0, 1]
      , [1, 1, 1]
      ]
-- | Expected outputs, one row per row of 'x'. They are the XOR of the first two
-- input columns.
y :: Matrix U Double
y = fromListUnboxed (Z :. 4 :. 1) targets
  where
    targets = [0, 1, 1, 0]
-- | Train the synapses (weights) of a 3-layer network: an input layer, one
-- hidden layer of 4 sigmoid units, and a sigmoid output layer.
--
-- The initial weights come from 'randomArray' (fixed seed), so training is
-- deterministic. Each of the @n@ iterations is one full-batch backpropagation
-- step, which does the following:
--
-- * Every input row is fed forward.
-- * The output error (@expected - output@) is scaled by the sigmoid slope.
-- * That error is propagated back through the hidden->output weights.
-- * Both corrections are added to the weight matrices (a learning rate of 1).
--
-- The number of input and output units is taken from the column counts of the
-- two matrices, so inputs and targets of any width are accepted.
--
-- Throws an 'IOError' when the two matrices have different row counts.
-- Otherwise Repa's element-wise operations would silently truncate to the
-- shared extent and train on mismatched data.
train
  :: Matrix U Double -- ^ Input matrix; one training example per row.
  -> Matrix U Double -- ^ Expected output matrix; one target row per example.
  -> Int -- ^ Number of iterations to run (@<= 0@ returns the initial weights).
  -> IO (Two (Matrix U Double)) -- ^ Weight synapses (input->hidden, hidden->output).
train input expected n
  | inRows /= outRows =
      ioError . userError $
        "train: input has " ++ show inRows
          ++ " rows but the expected output has " ++ show outRows
  | otherwise =
      let (syn0, syn1) = trainFor n initial
       in -- Force both matrices so the training work happens inside 'train'.
          syn0 `seq` syn1 `seq` pure (syn0, syn1)
  where
    hiddenUnits = 4 :: Int
    Z :. inRows :. inCols = extent input
    Z :. outRows :. outCols = extent expected
    initial = (randomArray inCols hiddenUnits, randomArray hiddenUnits outCols)

    -- The input matrix never changes, so transpose it once, outside the loop.
    inputT = transpose2S input

    -- Apply 'step' k times, forcing each pair of weight matrices so no chain
    -- of unevaluated updates builds up across iterations.
    trainFor :: Int -> Two (Matrix U Double) -> Two (Matrix U Double)
    trainFor k syns
      | k <= 0 = syns
      | otherwise =
          let (s0, s1) = step syns
           in s0 `seq` s1 `seq` trainFor (k - 1) (s0, s1)

    -- One forward pass plus one backpropagation update of both weight matrices.
    step :: Two (Matrix U Double) -> Two (Matrix U Double)
    step (syn0, syn1) = (syn0', syn1')
      where
        l1 = layer input syn0
        l2 = layer l1 syn1
        l2Delta = computeS $ zipWith (*) (zipWith (-) expected l2) (slope l2)
        -- Uses the pre-update hidden->output weights, as backprop requires.
        l1Delta = computeS $ zipWith (*) (mmultS l2Delta (transpose2S syn1)) (slope l1)
        syn1' = computeS $ zipWith (+) syn1 (mmultS (transpose2S l1) l2Delta)
        syn0' = computeS $ zipWith (+) syn0 (mmultS inputT l1Delta)

    -- Logistic activation, 1 / (1 + e^(-v)), of the weighted inputs to a layer.
    layer :: Matrix U Double -> Matrix U Double -> Matrix U Double
    layer inputs weights =
      computeS $ map (\v -> 1 / (1 + exp (negate v))) (mmultS inputs weights)

    -- Derivative of the logistic function, expressed in terms of its output.
    slope s = map (\v -> v * (1 - v)) s
-- | Feed inputs forward through a trained network. Returns the output layer's
-- activations, one row per input row.
--
-- Each layer applies the logistic function @1 / (1 + e^(-v))@ to the product
-- of the previous layer's values and that layer's weight matrix.
run :: Two (Matrix U Double)
    -> Matrix U Double
    -> Matrix U Double
run (syn0, syn1) input = forward syn1 (forward syn0 input)
  where
    forward :: Matrix U Double -> Matrix U Double -> Matrix U Double
    forward weights layerIn = computeS $ map logistic (mmultS layerIn weights)

    logistic :: Double -> Double
    logistic v = 1 / (1 + exp (negate v))
-- | Train the network on 'x' and 'y' for 60000 iterations. Then print what the
-- trained network predicts for those same inputs.
main :: IO ()
main = train x y iterations >>= report
  where
    iterations = 60000
    report weights = putStrLn ("Results: " ++ show (run weights x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment