Unraveling the mysteries of neural networks: a hands-on guide to building a micrograd-style autograd engine from the ground up, exploring backpropagation in detail using Rust. Inspired by Karpathy's micrograd.
// Authors: Jeff Asante
// Github: https://gist.github.com/jeffasante/
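//
// The graph built in `main` below is a single neuron:
//     n = x1*w1 + x2*w2 + b,    o = tanh(n)
// Backpropagation applies the chain rule from `o` down to the leaves, e.g.
//     do/dx1 = do/dn * dn/d(x1w1) * d(x1w1)/dx1 = (1 - o^2) * 1 * w1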
use std::cell::RefCell;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
// utils
fn exp(x: f64) -> f64 {
    x.exp()
}

fn square(x: f64) -> f64 {
    x * x
}
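// The identities used below, written out for reference:
//     tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
//     d/dx tanh(x) = 1 - tanh(x)^2
// which is why `square` shows up in the gradient computation in `main`.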
////
impl<'a> Display for Value<'a> {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "{}: {}", self.label, self.data)
    }
}
// sample print:
// for value in &topo {
//     println!("{}", value);
// }
////
struct Value<'a> {
    data: f64,
    grad: RefCell<f64>, // RefCell allows the grad to be mutated through shared references (interior mutability).
    _backward: Option<Box<dyn FnMut() + 'a>>, // Placeholder for a per-node backward closure.
    _prev: Vec<&'a Value<'a>>, // References to the child nodes that produced this value.
    _op: &'a str,   // The operation that produced this node (optional).
    label: &'a str, // A human-readable label (optional).
}
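// For example, `a.multiply(&b, "a*b")` (defined below) produces a node whose
// `_prev` is `vec![&a, &b]` and whose `_op` is "*", so the expression graph is
// a DAG of borrowed references, with leaf nodes created by `Value::new`.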
// The impl block defines the methods associated with the Value struct.
impl<'a> Value<'a> {
    fn new(data: f64, label: &'a str) -> Value<'a> {
        Value {
            data,
            grad: RefCell::new(0.0),
            _backward: None,   // Placeholder for the backward function
            _prev: Vec::new(), // Leaf nodes have no children
            _op: "",           // Placeholder for the operation
            label,
        }
    }
    fn multiply(&'a self, other: &'a Value<'a>, label: &'a str) -> Value<'a> {
        Value {
            data: self.data * other.data,
            grad: RefCell::new(0.0),
            _backward: None,
            _prev: vec![self, other], // The new node records references to both operands.
            _op: "*",
            label,
        }
    }

    fn add(&'a self, other: &'a Value<'a>, label: &'a str) -> Value<'a> {
        Value {
            data: self.data + other.data,
            grad: RefCell::new(0.0),
            _backward: None,
            _prev: vec![self, other],
            _op: "+",
            label,
        }
    }
    fn tanh(&'a self, label: &'a str) -> Value<'a> {
        let x = self.data;
        let t = (exp(2.0 * x) - 1.0) / (exp(2.0 * x) + 1.0);
        Value {
            data: t,
            grad: RefCell::new(0.0),
            _backward: None,
            _prev: vec![self], // Track the parent so the graph stays connected through tanh.
            _op: "tanh",
            label,
        }
    }
    fn backward(&self) {
        // Backpropagation entry point: build a topological ordering of the
        // graph rooted at this node (see the free `build_topo` below). The
        // per-node `_backward` closures are still placeholders, so this
        // version only computes the ordering; the chain rule is applied
        // manually in `main`.
        let mut topo: Vec<&Value> = Vec::new();
        let mut visited: HashSet<*const Value> = HashSet::new();
        build_topo(self, &mut visited, &mut topo);
    }
    fn update_data(&mut self, delta: f64) {
        self.data += delta;
    }

    fn update_grad(&self, delta: f64) {
        // Accumulate (+=) rather than assign: a node used by several
        // consumers sums its gradient contributions.
        *self.grad.borrow_mut() += delta;
    }
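    // Illustration: if a hypothetical node `v` feeds two consumers, each
    // contribution accumulates:
    //     v.update_grad(0.25);
    //     v.update_grad(0.25); // v.grad is now 0.5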
}
// Depth-first post-order traversal: each node's children are pushed before
// the node itself, so `topo` ends at the root and can be walked in reverse
// for backpropagation. Raw pointers (*const Value) uniquely identify nodes
// in the visited set.
fn build_topo<'a, 'b>(
    v: &'a Value<'a>,
    visited: &'b mut HashSet<*const Value<'a>>,
    topo: &'b mut Vec<&'a Value<'a>>,
) {
    let v_ptr: *const Value = v;
    if !visited.contains(&v_ptr) {
        visited.insert(v_ptr);
        for &child in &v._prev {
            build_topo(child, visited, topo);
        }
        topo.push(v);
    }
}
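// For the neuron built in `main`, build_topo(&o, ...) should yield children
// before parents, i.e.:
//     [x1, w1, x1*w1, x2, w2, x2*w2, x1*w1 + x2*w2, b, n, o]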
fn main() {
    // inputs x1, x2
    let x1 = Value::new(2.0, "x1");
    let x2 = Value::new(0.0, "x2");
    // weights w1, w2
    let w1 = Value::new(-3.0, "w1");
    let w2 = Value::new(1.0, "w2");
    // bias of the neuron
    let b = Value::new(6.8813735870195432, "b");

    // Forward pass: n = x1*w1 + x2*w2 + b, o = tanh(n)
    let x1w1 = x1.multiply(&w1, "x1*w1");
    let x2w2 = x2.multiply(&w2, "x2*w2");
    let x1w1x2w2 = x1w1.add(&x2w2, "x1*w1 + x2*w2");
    let n = x1w1x2w2.add(&b, "n");
    let o = n.tanh("o");
    // Build the topological ordering of the whole graph rooted at o.
    // (o.backward() does the same internally but discards the result.)
    let mut topo: Vec<&Value> = Vec::new();
    let mut visited: HashSet<*const Value> = HashSet::new();
    build_topo(&o, &mut visited, &mut topo);

    // Manual backpropagation, applying the chain rule node by node.
    o.update_grad(1.0); // seed: do/do = 1

    // tanh: do/dn = 1 - o^2 = 0.5
    n.update_grad(1.0 - square(o.data));

    // Addition routes the gradient unchanged to both operands.
    x1w1x2w2.update_grad(*n.grad.borrow());
    b.update_grad(*n.grad.borrow());
    x1w1.update_grad(*x1w1x2w2.grad.borrow());
    x2w2.update_grad(*x1w1x2w2.grad.borrow());

    // Multiplication: d(a*b)/da = b and d(a*b)/db = a.
    x1.update_grad(w1.data * *x1w1.grad.borrow());
    w1.update_grad(x1.data * *x1w1.grad.borrow());
    x2.update_grad(w2.data * *x2w2.grad.borrow());
    w2.update_grad(x2.data * *x2w2.grad.borrow());

    // Print every node with its accumulated gradient.
    for value in &topo {
        println!("{} | grad: {}", value, *value.grad.borrow());
    }
}
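// With the values above, the gradients should come out to (matching the
// hand-computed neuron example in Karpathy's micrograd lecture):
//     o.grad = 1.0, n.grad = 0.5, b.grad = 0.5,
//     x1.grad = -1.5, w1.grad = 1.0, x2.grad = 0.5, w2.grad = 0.0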