Proof-of-concept autograd implementation in OCaml
(* TODO: multicore and/or gpu computation would be fun *)
(* TODO: mnist sample *)
module Matrix = struct
  module FA = Float.Array

  (* row-major storage; the row count is implied by length / cols *)
  type t = { data : FA.t; cols : int }

  let make m n value =
    let data = FA.make (m * n) value in
    { data; cols = n }

  let zeros m n = make m n 0.
  let ones m n = make m n 1.
  let null = zeros 0 0
  (* Box-Muller transform for standard-normal draws; the original
     sampled uniformly on [-1, 1) despite the name *)
  let randn m n =
    let sample () =
      let u = Random.float 1. +. epsilon_float in
      Float.sqrt (-2. *. Float.log u) *. Float.cos (2. *. Float.pi *. Random.float 1.)
    in
    { data = FA.init (m * n) (fun _ -> sample ()); cols = n }
  let shape x =
    let nelts = FA.length x.data in
    (nelts / x.cols, x.cols)

  let index x i j = (i * x.cols) + j
  let reduce ~f init x = FA.fold_left f init x.data

  let map ~f x =
    let data = FA.map f x.data in
    { x with data }

  let map2 ~f x y =
    let data = FA.map2 f x.data y.data in
    { x with data }

  let get x i j = FA.get x.data (index x i j)
  let set x i j v = FA.set x.data (index x i j) v
  (* naive triple-loop multiply *)
  let matmul x y =
    let x_rows, x_cols = shape x in
    let y_rows, y_cols = shape y in
    assert (x_cols = y_rows);
    let shared = x_cols in
    let out = zeros x_rows y_cols in
    for i = 0 to x_rows - 1 do
      for j = 0 to y_cols - 1 do
        let sum = ref 0. in
        for k = 0 to shared - 1 do
          sum := !sum +. (get x i k *. get y k j)
        done;
        set out i j !sum
      done
    done;
    out
  let transpose x =
    let m, n = shape x in
    let out = zeros n m in
    for i = 0 to m - 1 do
      for j = 0 to n - 1 do
        set out j i (get x i j)
      done
    done;
    out
end
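(* Illustrative sanity check of the Matrix module (not part of the
   original gist): a 2x3 of ones times a 3x2 of ones gives a 2x2 of 3s,
   and transposing swaps indices. _matrix_demo is a hypothetical helper;
   call it from the entry point if you want to run it. *)
let _matrix_demo () =
  let a = Matrix.ones 2 3 in
  let b = Matrix.ones 3 2 in
  let c = Matrix.matmul a b in
  assert (Matrix.shape c = (2, 2));
  assert (Matrix.get c 0 0 = 3.);
  assert (Matrix.get (Matrix.transpose c) 1 0 = Matrix.get c 0 1)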
type matrix = Matrix.t

module CompGraph = struct
  type t =
    | Const of matrix * matrix
    | Input of string * matrix
    | Add of t * t * matrix
    | Neg of t * matrix
    | Pow of t * float * matrix
    | Exp of t * matrix
    | Matmul of t * t * matrix
    | Mean of t * matrix
    | Clamp of t * float * float * matrix
    | Dropout of t * float * matrix
    | Normalize of t * matrix
      (* by the sum of entries, which is the L1 norm when they are
         nonnegative, e.g. after exp *)

  (* convenience functions because the variant constructors are a bit
     unwieldy with the associated evaluation/gradient matrix *)
  let ( + ) x y = Add (x, y, Matrix.null)
  let ( ~- ) x = Neg (x, Matrix.null)
  let ( - ) x y = x + -y
  let ( ^ ) x e = Pow (x, e, Matrix.null)
  let ( @ ) x y = Matmul (x, y, Matrix.null)
  let const x = Const (x, Matrix.null)
  let input x = Input (x, Matrix.null)
  let exp x = Exp (x, Matrix.null)
  let mean x = Mean (x, Matrix.null)
  let clamp x min max = Clamp (x, min, max, Matrix.null)
  let dropout x r = Dropout (x, r, Matrix.null)
  let normalize x = Normalize (x, Matrix.null)
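  (* Illustrative only (not in the original gist): with the aliases
     above, graph construction reads like ordinary arithmetic. This
     builds mean((x + x)^2) symbolically; nothing is evaluated until
     forward runs. *)
  let _graph_demo = mean ((input "x" + input "x") ^ 2.)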
  (* this will be either null, the forward-evaluated matrix, or the gradient *)
  let assoc_matrix = function
    | Const (_, x)
    | Input (_, x)
    | Add (_, _, x)
    | Neg (_, x)
    | Pow (_, _, x)
    | Exp (_, x)
    | Matmul (_, _, x)
    | Mean (_, x)
    | Clamp (_, _, _, x)
    | Dropout (_, _, x)
    | Normalize (_, x) ->
        x
  let rec forward node inputs =
    let forward node = forward node inputs in
    match node with
    | Const (x, _) -> Const (x, x)
    | Input (x, _) -> Input (x, List.assoc x inputs)
    | Add (x, y, _) ->
        let x', y' = (forward x, forward y) in
        Add (x', y', Matrix.map2 ~f:( +. ) (assoc_matrix x') (assoc_matrix y'))
    | Neg (x, _) ->
        let x' = forward x in
        Neg (x', Matrix.map ~f:( ~-. ) (assoc_matrix x'))
    | Pow (x, e, _) ->
        let x' = forward x in
        Pow (x', e, Matrix.map ~f:(fun x -> Float.pow x e) (assoc_matrix x'))
    | Exp (x, _) ->
        let x' = forward x in
        Exp (x', Matrix.map ~f:Float.exp (assoc_matrix x'))
    | Matmul (x, y, _) ->
        let x', y' = (forward x, forward y) in
        Matmul (x', y', Matrix.matmul (assoc_matrix x') (assoc_matrix y'))
    | Mean (x, _) ->
        let x' = forward x in
        let m, n = Matrix.shape (assoc_matrix x') in
        let sum = Matrix.reduce ~f:( +. ) 0. (assoc_matrix x') in
        let mean = sum /. Float.of_int (m * n) in
        Mean (x', Matrix.make 1 1 mean)
    | Clamp (x, min, max, _) ->
        let x' = forward x in
        let eval =
          assoc_matrix x'
          |> Matrix.map ~f:(Float.min max)
          |> Matrix.map ~f:(Float.max min)
        in
        Clamp (x', min, max, eval)
    | Dropout (x, r, _) ->
        (* keeps each element with probability 1 - r; note there is no
           1/(1-r) rescaling, so activations shrink in expectation *)
        let x' = forward x in
        let eval =
          Matrix.map
            ~f:(fun x -> if Random.float 1. > r then x else 0.)
            (assoc_matrix x')
        in
        Dropout (x', r, eval)
    | Normalize (x, _) ->
        let x' = forward x in
        let sum = Matrix.reduce ~f:( +. ) 0. (assoc_matrix x') in
        let eval = Matrix.map ~f:(fun x -> x /. sum) (assoc_matrix x') in
        Normalize (x', eval)
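  (* Illustrative only: running forward on _graph_demo from above with a
     2x2 input of ones. (1 + 1)^2 = 4 everywhere, so the mean is 4. *)
  let _forward_demo () =
    let g = forward _graph_demo [ ("x", Matrix.ones 2 2) ] in
    assert (Matrix.get (assoc_matrix g) 0 0 = 4.)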
  (* needs to be called after a forward pass: the matrices stored by
     forward are reused here, and each node's slot is replaced by the
     gradient of the root with respect to that node *)
  let backward root _inputs =
    let rec helper node grad =
      match node with
      | Const (x, _) -> Const (x, grad)
      | Input (x, _) -> Input (x, grad)
      | Add (x, y, _) -> Add (helper x grad, helper y grad, grad)
      | Neg (x, _) ->
          let grad' = Matrix.map ~f:( ~-. ) grad in
          Neg (helper x grad', grad)
      | Pow (x, e, _) ->
          (* d(x^e)/dx = e * x^(e-1), using the child's forward value *)
          let grad' =
            Matrix.map2
              ~f:(fun g xv -> g *. e *. Float.pow xv (e -. 1.))
              grad (assoc_matrix x)
          in
          Pow (helper x grad', e, grad)
      | Exp (x, eval) ->
          (* d(exp x)/dx = exp x, which forward already stored as eval *)
          let grad' = Matrix.map2 ~f:( *. ) grad eval in
          Exp (helper x grad', grad)
      | Matmul (x, y, _) ->
          (* for Z = XY: dL/dX = dL/dZ Y^T and dL/dY = X^T dL/dZ *)
          let x_mat, y_mat = (assoc_matrix x, assoc_matrix y) in
          let grad_x = Matrix.matmul grad (Matrix.transpose y_mat) in
          let grad_y = Matrix.matmul (Matrix.transpose x_mat) grad in
          Matmul (helper x grad_x, helper y grad_y, grad)
      | Mean (x, _) ->
          let el = Matrix.get grad 0 0 in
          let m, n = Matrix.shape (assoc_matrix x) in
          let nelts = m * n in
          let grad' = Matrix.make m n (el /. Float.of_int nelts) in
          Mean (helper x grad', grad)
      | Clamp (x, min, max, eval) ->
          (* gradient passes through only where the value was not clamped *)
          let grad' =
            Matrix.map2
              ~f:(fun g e -> if e > min && e < max then g else 0.)
              grad eval
          in
          Clamp (helper x grad', min, max, grad)
      | Dropout (x, r, eval) ->
          (* gradient passes through the elements that were kept *)
          let grad' =
            Matrix.map2 ~f:(fun g e -> if e <> 0. then g else 0.) grad eval
          in
          Dropout (helper x grad', r, grad)
      | Normalize (x, eval) ->
          (* for y_i = x_i / s with s = sum_k x_k:
             dL/dx_j = (g_j - sum_i g_i * y_i) / s *)
          let sum = Matrix.reduce ~f:( +. ) 0. (assoc_matrix x) in
          let dot =
            Matrix.reduce ~f:( +. ) 0. (Matrix.map2 ~f:( *. ) grad eval)
          in
          let grad' = Matrix.map ~f:(fun g -> (g -. dot) /. sum) grad in
          Normalize (helper x grad', grad)
    in
    (* seed the recursion with dL/dL = 1 everywhere *)
    let root_rows, root_cols = assoc_matrix root |> Matrix.shape in
    let seed_grad = Matrix.ones root_rows root_cols in
    helper root seed_grad
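  (* Illustrative only: one forward/backward round trip on a tiny graph.
     For l = mean(w^2) with w a 1x1 const holding 3., dl/dw = 2*3 = 6. *)
  let _backward_demo () =
    let w = const (Matrix.make 1 1 3.) in
    let l = forward (mean (w ^ 2.)) [] in
    match backward l [] with
    | Mean (Pow (Const (_, gw), _, _), _) -> assert (Matrix.get gw 0 0 = 6.)
    | _ -> assert false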
  (* rebuilds a node from its (transformed) children; this also clears
     the associated matrix back to null *)
  let apply_on_children f = function
    | Const (x, _) -> const x
    | Input (x, _) -> input x
    | Add (x, y, _) -> f x + f y
    | Neg (x, _) -> -f x
    | Pow (x, e, _) -> f x ^ e
    | Exp (x, _) -> exp (f x)
    | Matmul (x, y, _) -> f x @ f y
    | Mean (x, _) -> mean (f x)
    | Clamp (x, min, max, _) -> clamp (f x) min max
    | Dropout (x, r, _) -> dropout (f x) r
    | Normalize (x, _) -> normalize (f x)
  (* for fun, see if we can optimize and simplify graphs;
     there are more rewrites that could be added *)
  let compile node =
    let rec pass = function
      (* eagerly evaluate when all arguments are const *)
      | (Add (Const _, Const _, _) as node)
      | (Neg (Const _, _) as node)
      | (Pow (Const _, _, _) as node)
      | (Exp (Const _, _) as node)
      | (Matmul (Const _, Const _, _) as node)
      | (Mean (Const _, _) as node)
      | (Clamp (Const _, _, _, _) as node)
      | (Normalize (Const _, _) as node) ->
          const (forward node [] |> assoc_matrix)
      (* dropout with rate 0 is the identity *)
      | Dropout (x, 0., _) -> x
      (* pow x 1. = x *)
      | Pow (x, 1., _) -> x
      (* -(-x) = x *)
      | Neg (Neg (x, _), _) -> x
      (* no rewrite at this node; recurse into the children *)
      | node -> apply_on_children pass node
    in
    (* repeat until the graph stops changing, so rewrites can cascade *)
    let rec apply_until_idempotent f x =
      let x' = f x in
      if x' = x then x' else apply_until_idempotent f x'
    in
    apply_until_idempotent pass node
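  (* Illustrative only: -(-c) + c should fold, via the double-negation
     rewrite and then eager const evaluation, into a single Const of 2s. *)
  let _compile_demo () =
    let c = const (Matrix.ones 2 2) in
    match compile (-(-c) + c) with
    | Const (x, _) -> assert (Matrix.get x 0 0 = 2.)
    | _ -> assert false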
end
module MlOps = struct
  module CG = CompGraph
  open CG

  let relu x = clamp x 0. Float.infinity

  (* note: dropout is applied to the weight matrix itself
     (DropConnect-style) rather than to the activations *)
  let linear a b ?(dropout = 0.) x = (CG.dropout a dropout @ x) + b |> relu
  let softmax x = x |> exp |> normalize
  let mse y y' = (y - y') ^ 2. |> mean

  let rec gd ?(lr = 0.05) = function
    (* only perform the update step on consts, i.e. the weights *)
    | Const (x, g) -> const (Matrix.map2 ~f:(fun x g -> x -. (lr *. g)) x g)
    (* pass through everything else *)
    | node -> apply_on_children (gd ~lr) node
end
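(* Illustrative only: a single gradient-descent step on one 1x1 weight,
   minimizing mse against a fixed target of 1. Starting from w = 3.,
   dloss/dw = 2 * (3 - 1) = 4, so one step at lr = 0.1 lands on 2.6.
   The pattern match mirrors the tree shape mse happens to build. *)
let _gd_demo () =
  let open CompGraph in
  let w = const (Matrix.make 1 1 3.) in
  let loss = MlOps.mse w (const (Matrix.make 1 1 1.)) in
  let grads = backward (forward loss []) [] in
  match MlOps.gd ~lr:0.1 grads with
  | Mean (Pow (Add (Const (w', _), _, _), _, _), _) ->
      assert (Float.abs (Matrix.get w' 0 0 -. 2.6) < 1e-9)
  | _ -> assert false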
open CompGraph
open MlOps

(* three-layer MLP: l1 -> l2 -> l3 -> l4, with dropout on the first two
   weight matrices and a softmax over the output *)
let init_nn l1 l2 l3 l4 input =
  input
  |> linear
       (const (Matrix.randn l2 l1))
       (const (Matrix.randn l2 1))
       ~dropout:0.01
  |> linear
       (const (Matrix.randn l3 l2))
       (const (Matrix.randn l3 1))
       ~dropout:0.01
  |> linear (const (Matrix.randn l4 l3)) (const (Matrix.randn l4 1))
  |> softmax
let () =
  Random.self_init ();
  let inputs = [ ("x", Matrix.randn 20 1); ("y", Matrix.randn 10 1) ] in
  let nn = input "x" |> init_nn 20 50 50 10 |> compile in
  let optim = mse (input "y") nn |> compile in
  (* one full step: forward pass, loss, backward pass, gradient descent *)
  let _y' = forward nn inputs |> assoc_matrix in
  let optim = forward optim inputs in
  let err = assoc_matrix optim in
  Printf.printf "Error: %f\n" (Matrix.get err 0 0);
  let gradients = backward optim inputs in
  let _optim = gd ~lr:0.025 gradients in
  ()
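(* A sketch (not in the original gist) of how the single step above could
   become a training loop: repeat forward, backward, and gd, feeding the
   updated graph from each step into the next. The step count and
   learning rate here are arbitrary placeholders. *)
let _train inputs graph steps =
  let rec loop graph n =
    if n = 0 then graph
    else
      let evaluated = forward graph inputs in
      let grads = backward evaluated inputs in
      loop (gd ~lr:0.025 grads) (n - 1)
  in
  loop graph steps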