Unraveling the mysteries of neural networks: a hands-on guide to building a micrograd-style autograd engine from the ground up, exploring backpropagation in detail in Rust. Inspired by Karpathy's micrograd.
// Author: Jeff Asante
// GitHub: https://gist.github.com/jeffasante/

use std::cell::RefCell;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
// utils
fn exp(x: f64) -> f64 {
    x.exp()
}

// Squares its argument (used later for the tanh derivative, 1 - tanh(x)^2).
fn square(x: f64) -> f64 {
    x * x
}
////
impl<'a> Display for Value<'a> {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "{}: {}", self.label, self.data)
    }
}

// Sample print of the topological order:
// for value in &topo {
//     println!("{}", value);
// }
////
struct Value<'a> {
    data: f64,
    grad: RefCell<f64>, // RefCell allows interior mutability, so gradients can be updated through shared references
    _backward: Option<Box<dyn FnMut() + 'a>>, // placeholder for a per-node backward closure (never populated in this demo)
    _prev: Vec<&'a Value<'a>>,                // references to the nodes this value was computed from
    _op: &'a str,                             // the operation that produced this value (optional)
    label: &'a str,                           // human-readable label (optional)
}
// The impl block defines the methods associated with the Value struct.
impl<'a> Value<'a> {
    fn new(data: f64, label: &'a str) -> Value<'a> {
        Value {
            data,
            grad: RefCell::new(0.0),
            _backward: None,   // placeholder for the backward function
            _prev: Vec::new(), // leaf nodes have no parents
            _op: "",           // placeholder for the operation
            label,
        }
    }
    fn multiply(&'a self, other: &'a Value<'a>, label: &'a str) -> Value<'a> {
        Value {
            data: self.data * other.data,
            grad: RefCell::new(0.0),
            _backward: None,
            _prev: vec![self, other], // record both operands as parents
            _op: "*",
            label,
        }
    }

    fn add(&'a self, other: &'a Value<'a>, label: &'a str) -> Value<'a> {
        Value {
            data: self.data + other.data,
            grad: RefCell::new(0.0),
            _backward: None,
            _prev: vec![self, other], // record both operands as parents
            _op: "+",
            label,
        }
    }
    fn tanh(&'a self) -> Value<'a> {
        let x = self.data;
        let t = (exp(2.0 * x) - 1.0) / (exp(2.0 * x) + 1.0);
        let mut out = Value::new(t, "tanh");
        out._prev = vec![self]; // link the input so graph traversal reaches it
        out._op = "tanh";
        out
    }
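    // Note: d/dx tanh(x) = 1 - tanh(x)^2, so the local gradient at this
    // node is 1 - t^2; `main` applies this identity (via `square`) when
    // propagating the gradient by hand.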
    fn backward(&self) {
        // Performs the graph-ordering half of backpropagation from this
        // node: it computes a topological ordering of every reachable node
        // (via the free `build_topo` function below). The per-node
        // `_backward` closures are never populated in this demo, so the
        // gradient propagation itself is done by hand in `main`.
        let mut topo: Vec<&Value> = Vec::new(); // nodes in topological order
        let mut visited: HashSet<*const Value> = HashSet::new(); // raw pointers uniquely identify visited nodes
        build_topo(self, &mut visited, &mut topo);
    }
    fn update_data(&mut self, delta: f64) {
        // Shifts the stored value, e.g. for a gradient-descent step (unused in this demo).
        self.data += delta;
    }

    fn update_grad(&self, delta: f64) {
        // Accumulates into the gradient through the RefCell.
        *self.grad.borrow_mut() += delta;
    }
}
// Builds a post-order (topological) ordering of the graph rooted at `v`:
// every node appears after all of the nodes it was computed from.
fn build_topo<'a, 'b>(
    v: &'a Value<'a>,
    visited: &'b mut HashSet<*const Value<'a>>,
    topo: &'b mut Vec<&'a Value<'a>>,
) {
    let v_ptr: *const Value = v; // raw pointer uniquely identifies the node
    if !visited.contains(&v_ptr) {
        visited.insert(v_ptr);
        for &child in &v._prev {
            build_topo(child, visited, topo);
        }
        topo.push(v);
    }
}
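// A sketch of how full automatic backprop would use this ordering (an
// assumption about the intended design, not something this gist implements):
// seed the output gradient with 1.0, then walk `topo` in reverse and invoke
// each node's `_backward` closure, e.g.
//
//     for v in topo.iter().rev() {
//         if let Some(back) = v._backward.as_mut() { back(); }
//     }
//
// Calling the FnMut closures needs mutable access (e.g. wrapping `_backward`
// in a RefCell), which is one reason this demo propagates gradients manually.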
fn main() {
    // inputs x1, x2
    let x1 = Value::new(2.0, "x1");
    let x2 = Value::new(0.0, "x2");
    // weights w1, w2
    let w1 = Value::new(-3.0, "w1");
    let w2 = Value::new(1.0, "w2");
    // bias of the neuron
    let b = Value::new(6.8813735870195432, "b");

    // forward pass: o = tanh(x1*w1 + x2*w2 + b)
    let x1w1 = x1.multiply(&w1, "x1*w1");
    let x2w2 = x2.multiply(&w2, "x2*w2");
    let x1w1x2w2 = x1w1.add(&x2w2, "x1*w1 + x2*w2");
    let n = x1w1x2w2.add(&b, "n");
    let o = n.tanh();
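    // Concrete values here: n = 2*(-3) + 0*1 + 6.8813... ≈ 0.8814,
    // o = tanh(n) ≈ 0.7071.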
    // Build the topological ordering of the graph rooted at `o`
    // (handy for inspecting the graph, e.g. with the sample print above).
    let mut topo: Vec<&Value> = Vec::new();
    let mut visited: HashSet<*const Value> = HashSet::new();
    build_topo(&o, &mut visited, &mut topo);
    o.backward(); // builds the same ordering internally; with no `_backward`
                  // closures registered, gradients are propagated by hand below

    // Manual backpropagation, mirroring Karpathy's step-by-step derivation.
    // Gradients are seeded from the output back toward the inputs, so each
    // chain-rule product reads an already-computed upstream gradient.
    o.update_grad(1.0); // do/do = 1
    n.update_grad(1.0 - square(o.data)); // dtanh(n)/dn = 1 - o^2 ≈ 0.5
    // '+' routes the gradient through unchanged to both operands.
    x1w1x2w2.update_grad(*n.grad.borrow());
    b.update_grad(*n.grad.borrow());
    x1w1.update_grad(*x1w1x2w2.grad.borrow());
    x2w2.update_grad(*x1w1x2w2.grad.borrow());
    // '*' scales the gradient by the other operand: d(a*b)/da = b, d(a*b)/db = a.
    x1.update_grad(w1.data * (*x1w1.grad.borrow()));
    w1.update_grad(x1.data * (*x1w1.grad.borrow()));
    x2.update_grad(w2.data * (*x2w2.grad.borrow()));
    w2.update_grad(x2.data * (*x2w2.grad.borrow()));

    println!("do/dn (1 - o^2): {}", 1.0 - square(o.data));
} |
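// Expected results for this example (matching Karpathy's neuron demo):
//   o ≈ 0.7071, so 1 - o^2 ≈ 0.5
//   x1.grad = w1 * 0.5 = -1.5    w1.grad = x1 * 0.5 = 1.0
//   x2.grad = w2 * 0.5 =  0.5    w2.grad = x2 * 0.5 = 0.0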