# Reverse-mode automatic differentiation
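#
# Each primitive below returns its value together with a list of weights: the
# partial derivatives of the value with respect to each argument. Node objects
# record these in a graph, and gradients() runs one reverse sweep that
# accumulates adjoints, yielding every partial derivative in a single extra
# pass over the graph.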

import math

# d(-x) = -dx
def func_neg(x):
    return -x, [-1]

# d(x + y) = dx + dy
def func_add(x, y):
    return x + y, [1, 1]

# d(x - y) = dx - dy
def func_sub(x, y):
    return x - y, [1, -1]

# d(x y) = y dx + x dy
def func_mul(x, y):
    return x * y, [y, x]

# d(x / y) = d(x 1/y) = 1/y dx - x/y^2 dy
def func_div(x, y):
    return x / y, [1/y, -x/(y*y)]

# d(cos(x)) = -sin(x) dx
def func_cos(x):
    return cos(x), [-sin(x)]

# d(sin(x)) = cos(x) dx
def func_sin(x):
    return sin(x), [cos(x)]

# d(exp(x)) = exp(x) dx
def func_exp(x):
    exp_x = exp(x)
    return exp_x, [exp_x]

# d(log(x)) = 1/x dx
def func_log(x):
    return log(x), [1/x]

# d(x**y) = d(exp(log(x) y)) = x**y y/x dx + x**y log(x) dy
def func_pow(x, y):
    pow_xy = x**y
    # log(x) is undefined for a nonpositive numeric base (where d/dy is not
    # real-valued anyway), so fall back to a zero weight for the exponent.
    dy = pow_xy * log(x) if isinstance(x, Node) or x > 0 else 0
    return pow_xy, [x**(y-1) * y, dy]

# d(when(c, y, z)) = c dy + (1-c) dz; the branch condition is treated as
# locally constant, so it gets weight 0.
def func_when(x, y, z):
    return when(x, y, z), [0, x, 1-x]

# Comparisons have zero derivative almost everywhere.
def func_le(x, y):
    return x <= y, [0, 0]

def func_ge(x, y):
    return x >= y, [0, 0]

# Evaluation record for one node: its value and the weights (local partial
# derivatives) with respect to its arguments.
class State:
    def __init__(self, value=0, weights=()):
        self.value = value
        self.weights = weights

# A graph node: a primitive function applied to argument nodes. The operator
# overloads build up the expression graph.
class Node:
    def __init__(self, func, args):
        self.func = func
        self.args = args

    def __neg__(self):
        return make_node(func_neg, self)
    def __add__(self, other):
        return make_node(func_add, self, other)
    def __radd__(self, other):
        return make_node(func_add, other, self)
    def __sub__(self, other):
        return make_node(func_sub, self, other)
    def __rsub__(self, other):
        return make_node(func_sub, other, self)
    def __mul__(self, other):
        return make_node(func_mul, self, other)
    def __rmul__(self, other):
        return make_node(func_mul, other, self)
    def __pow__(self, other):
        return make_node(func_pow, self, other)
    def __rpow__(self, other):
        return make_node(func_pow, other, self)
    def __truediv__(self, other):
        return make_node(func_div, self, other)
    def __rtruediv__(self, other):
        return make_node(func_div, other, self)
    def __le__(self, other):
        return make_node(func_le, self, other)
    def __ge__(self, other):
        return make_node(func_ge, self, other)

    def evaluate(self, bindings={}):
        # Forward pass: visit the graph in post-order, computing each node's
        # value and weights. Bound nodes take their value from `bindings`.
        # Dict insertion order leaves `states` topologically sorted.
        states = {node: State(value) for node, value in bindings.items()}
        def visit(node):
            if node in states:
                return states[node].value
            value, weights = node.func(*(visit(arg) for arg in node.args))
            states[node] = State(value, weights)
            return value
        visit(self)
        return states

    def gradients(self, bindings={}):
        # Reverse pass: seed the root's adjoint with 1 and propagate adjoints
        # backward through the weights in reverse topological order.
        states = self.evaluate(bindings)
        gradients = {node: 0 for node in states}
        gradients[self] = 1
        for node, state in reversed(list(states.items())):
            gradient = gradients[node]
            for arg, weight in zip(node.args, state.weights):
                gradients[arg] += weight * gradient
        return gradients

# Memoize a function on its (hashable) argument tuple.
def memo(func):
    cache = {}
    def wrapped(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    wrapped.__name__ = func.__name__
    return wrapped

@memo
def const(value):
    return make_node(lambda: (value, ()))

# Build a graph node, wrapping non-Node arguments as constants. Memoization
# makes structurally identical subexpressions share a single node.
@memo
def make_node(func, *args):
    return Node(func, [arg if isinstance(arg, Node) else const(arg) for arg in args])

# A constant node with value 0 and no inputs.
none = make_node(lambda: (0, ()))

# A leaf node whose value can be (re)assigned between evaluations.
class Var(Node):
    def __init__(self, value=None):
        super().__init__(self._func, ())
        self.value = value

    def _func(self):
        if self.value is None:
            raise ValueError("Unassigned variable")
        return self.value, ()

def var(value=None):
    return Var(value)

# Wrap a math function so it builds a graph node for Node arguments and
# evaluates numerically otherwise.
def wrap_unary(math_func, node_func):
    def wrapper(x):
        return make_node(node_func, x) if isinstance(x, Node) else math_func(x)
    wrapper.__name__ = math_func.__name__
    return wrapper

cos = wrap_unary(math.cos, func_cos)
sin = wrap_unary(math.sin, func_sin)
exp = wrap_unary(math.exp, func_exp)
log = wrap_unary(math.log, func_log)

# Differentiable branch: y if x else z.
def when(x, y, z):
    if isinstance(x, Node) or isinstance(y, Node) or isinstance(z, Node):
        return make_node(func_when, x, y, z)
    else:
        return y if x else z

# Tests

x = var(2)
y = var(3)
f = exp(sin(x * y) / y)
x0 = x.value
y0 = y.value
gradients = f.gradients()
print(gradients[x])  # 0.8747797595113476
print(exp(sin(x0*y0)/y0)*cos(x0*y0))  # 0.8747797595113477
print(gradients[y])  # 0.6114716536871055
print(exp(sin(x0*y0)/y0)*(x0*y0*cos(x0*y0) - sin(x0*y0))/(y0**2))  # 0.6114716536871057
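
# Added sanity check (not in the original gist): compare the reverse-mode
# gradients against a central finite difference. `finite_diff` is a helper
# introduced here for illustration only.
def finite_diff(node, v, h=1e-6):
    saved = v.value
    v.value = saved + h
    hi = node.evaluate()[node].value
    v.value = saved - h
    lo = node.evaluate()[node].value
    v.value = saved
    return (hi - lo) / (2*h)

print(gradients[x], finite_diff(f, x))  # should agree to roughly 6 digits
print(gradients[y], finite_diff(f, y))  # should agree to roughly 6 digits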

# Second derivatives: binding each variable to itself makes the gradient pass
# symbolic, so the first-order gradients come back as Nodes that can be
# differentiated again.
def hessian(node, *args):
    gradients = node.gradients({arg: arg for arg in args})
    return {arg: gradients[arg].gradients() for arg in args}

gradients = hessian(f, x, y)
print(gradients[x][x])  # 1.603636510793541
print(exp(sin(x0*y0)/y0)*(cos(x0*y0)**2 - y0*sin(x0*y0)))  # 1.603636510793541

g = when(x >= 0, x**2, -x**3)
print(g.gradients({x: 2})[x])  # 4
print(g.gradients({x: -2})[x])  # -12
print(g.gradients({x: x})[x].gradients({x: 1})[x])  # 2
print(g.gradients({x: x})[x].gradients({x: -1})[x])  # 6

class LinearModel:
    def __init__(self, a=1, b=0):
        self.a = var(a)
        self.b = var(b)

    def __call__(self, x):
        return self.a*x + self.b

    # Mean squared error over (x, y) pairs.
    def loss(self, data):
        return sum((self(x) - y)**2 for x, y in data) / len(data)

    # One gradient-descent step on the loss.
    def train(self, data, rate):
        gradients = self.loss(data).gradients()
        self.a.value -= rate * gradients[self.a]
        self.b.value -= rate * gradients[self.b]

import random

# Fit y = 2x - 3 by per-sample stochastic gradient descent.
model = LinearModel()
n = 1000
data = [(1 + 10*i/n, 2*(1 + 10*i/n) - 3) for i in range(n)]
random.shuffle(data)
for i in range(10):
    for point in data:
        model.train([point], 0.01)
print(model.a.value, model.b.value)  # 1.999999999999998 -2.9999999999999893
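
# Added check (not in the original gist): the fitted model's mean squared
# error should be near zero. A 100-point slice keeps the summed expression
# graph shallow enough for the recursive evaluator.
final_loss = model.loss(data[:100])
print(final_loss.evaluate()[final_loss].value)  # ~0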