Skip to content

Instantly share code, notes, and snippets.

@MikuroXina
Created April 20, 2026 09:23
Show Gist options
  • Select an option

  • Save MikuroXina/40040069bfec0712753a1b461dd56ab1 to your computer and use it in GitHub Desktop.

Select an option

Save MikuroXina/40040069bfec0712753a1b461dd56ab1 to your computer and use it in GitHub Desktop.
Back propagation combinators with Python.
from collections.abc import Callable
from typing import TypeVar
import numpy as np
from numpy.typing import ArrayLike
# Type variable for whatever final result a continuation chain produces.
T = TypeVar("T")
# A continuation: takes a callback (array -> T) and returns that callback's
# eventual result.
Cont = Callable[[Callable[[ArrayLike], T]], T]
# A flow: maps an input array to a continuation. In `Function.__call__` the
# callback handed to that continuation receives the gradient for the input.
# NOTE: `Cont` is used here without a type argument, so its `T` is
# effectively `Any` for static type checkers.
Flow = Callable[[ArrayLike], Cont]
# A lens: wraps a downstream flow into an upstream one. `Function` instances
# (via `__call__`) and the result of `pipe` have this shape.
Lens = Callable[[Flow], Flow]
class Function:
    """
    Differentiable building block that turns into a lens when called.

    Subclasses provide ``forward`` (compute ``y`` from ``x``) and
    ``backward`` (map the gradient ``dy`` arriving from downstream to the
    gradient for ``x``, given the original ``x``). Calling an instance on a
    downstream flow ``next`` returns a new flow. That flow runs ``forward``,
    feeds ``y`` to ``next``, and routes the gradient that ``next`` reports
    back through ``backward`` before handing it to the caller's callback::

            x --+--[ forward ]--> y ---------> [ next ]
                |                                 |
                v                                 v
        dx <--[ backward(x, dy) ] <------------- dy
         |
         v
        [ callback ]
    """

    def __call__(self, next: Flow) -> Flow:
        def run_from(x: ArrayLike) -> Cont:
            def continue_with(callback: Callable[[ArrayLike], T]) -> T:
                def pull_back(dy: ArrayLike) -> T:
                    # Convert the downstream gradient into the gradient
                    # w.r.t. this stage's input, then report it upstream.
                    dx = self.backward(x, dy)
                    return callback(dx)

                downstream = next(self.forward(x))
                return downstream(pull_back)

            return continue_with

        return run_from

    def forward(self, x: ArrayLike) -> ArrayLike:
        """Compute the output for input ``x``; must be overridden."""
        raise NotImplementedError()

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        """Return the gradient for ``x`` given downstream gradient ``dy``; must be overridden."""
        raise NotImplementedError()
def pipe(first: Lens, second: Lens, *rest: Lens) -> Lens:
    """
    Pipe lenses so that data flows through `first`, then `second`, then each
    lens in `rest` in order.

    A lens wraps the *downstream* flow, so the stage that must run first has
    to be the outermost wrapper: ``pipe(f, g)(k) == f(g(k))``. Applying them
    the other way round, ``g(f(k))``, would run ``g``'s forward pass before
    ``f``'s. Gradients travel back through the stages in reverse order.

    Extra lenses may be passed positionally; ``pipe(f, g, h)`` is equivalent
    to ``pipe(f, pipe(g, h))`` and to ``pipe(pipe(f, g), h)``.
    """
    stages = (first, second, *rest)

    def piped(downstream: Flow) -> Flow:
        flow = downstream
        # Wrap from the last stage outwards so the first stage ends up
        # outermost and therefore runs its forward pass first.
        for stage in reversed(stages):
            flow = stage(flow)
        return flow

    return piped
def back_propagate(
    lens: Lens, loss: Callable[[ArrayLike], ArrayLike], input: ArrayLike
) -> ArrayLike:
    """
    Run `lens` forward on `input`, then back-propagate through it.

    `loss` is called with the output ``y`` of `lens`. Its return value is
    passed into the backward pass as ``dy``, so it should be the gradient of
    the loss with respect to ``y`` (``dL/dy``), not the scalar loss itself.
    The gradient that arrives back at `input` is returned unchanged.
    """

    def seed_gradient(y: ArrayLike) -> Cont:
        # Terminal flow: turn the output into the upstream gradient and start
        # the backward callback chain with it.
        return lambda callback: callback(loss(y))

    def identity(dx: ArrayLike) -> ArrayLike:
        # Named `identity` rather than `id` to avoid shadowing the builtin.
        return dx

    return lens(seed_gradient)(input)(identity)
def back_propagate_const(
    lens: Lens, loss_value: ArrayLike, input: ArrayLike
) -> ArrayLike:
    """
    Back-propagate through `lens` from `input` with a fixed upstream gradient.

    `loss_value` is fed into the backward pass as ``dy`` regardless of the
    output `lens` produces. The gradient that reaches `input` is returned.
    """

    def constant_gradient(_output: ArrayLike) -> ArrayLike:
        return loss_value

    return back_propagate(lens, constant_gradient, input)
def evaluate(lens: Lens, input: ArrayLike) -> ArrayLike | None:
    """
    Run `lens` forward on `input` and return its output.

    The terminal flow captures the output ``y`` and never calls the gradient
    callback, so `Function` stages never run their ``backward``. If `lens`
    calls its callback instead of reaching the terminal flow, the
    discarding callback's ``None`` is returned.
    """

    def capture(y: ArrayLike):
        def ignore_callback(_callback):
            return y

        return ignore_callback

    def discard(_dx: ArrayLike) -> None:
        return None

    return lens(capture)(input)(discard)
class Square(Function):
    """Element-wise square, ``y = x**2``, whose derivative is ``2 * x``."""

    def forward(self, x: ArrayLike) -> ArrayLike:
        # `np.power` rather than `np.pow`: the latter is only an alias added
        # in NumPy 2.0 and raises AttributeError on NumPy 1.x.
        return np.power(x, 2)

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        # Chain rule: gradient for x is (dy/dx) * dy = 2 * x * dy.
        return 2 * np.multiply(x, dy)
class Exp(Function):
    """Element-wise exponential, ``y = exp(x)``; its derivative is ``exp(x)`` itself."""

    def forward(self, x: ArrayLike) -> ArrayLike:
        return np.exp(x)

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        # d/dx exp(x) = exp(x), so scale the incoming gradient by exp(x).
        local_grad = np.exp(x)
        return np.multiply(local_grad, dy)
if __name__ == "__main__":
    # Square -> Exp -> Square, i.e. F(x) = exp(x**2) ** 2 = exp(2 * x**2).
    A, B, C = Square(), Exp(), Square()
    F = pipe(pipe(A, B), C)

    x = np.array(0.5)
    y = evaluate(F, x)
    print(y)  # exp(0.5) ~ 1.6487

    # Seed the backward pass with dy = 1 to get dF/dx = 4x * exp(2 * x**2).
    dy = np.array(1.0)
    dx = back_propagate_const(F, dy, x)
    print(dx)  # 2 * exp(0.5) ~ 3.2974
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment