Created
April 20, 2026 09:23
-
-
Save MikuroXina/40040069bfec0712753a1b461dd56ab1 to your computer and use it in GitHub Desktop.
Back propagation combinators with Python.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from collections.abc import Callable | |
| from typing import TypeVar | |
| import numpy as np | |
| from numpy.typing import ArrayLike | |
| T = TypeVar("T") | |
| Cont = Callable[[Callable[[ArrayLike], T]], T] | |
| Flow = Callable[[ArrayLike], Cont] | |
| Lens = Callable[[Flow], Flow] | |
class Function:
    """
    Base class for a differentiable step that can be used as a lens.

    Subclasses supply `forward` and `backward`. Calling an instance with the
    flow that should consume its output produces a new flow:

    ```
    x ------+--[ forward ]---> y ----+
            |                        v
      [ callback ]               [ next ]
         ^  |                        |
         +- dx <--[ backward ]<-- dy +
    ```
    """

    def __call__(self, next: Flow) -> Flow:
        """
        Wraps `next` so that this function runs in front of it.

        The returned flow takes `x` and yields a continuation. When that
        continuation is given a `callback`, it computes `forward(x)`, passes
        the result to `next`, and hands the resulting continuation a function
        that turns an incoming `dy` into `callback(backward(x, dy))`.
        """

        def run(x: ArrayLike) -> Cont:
            def resume(callback: Callable[[ArrayLike], T]) -> T:
                # The forward pass is only evaluated once a callback arrives.
                y = self.forward(x)
                downstream = next(y)
                # `x` is kept in the closure so the backward pass can use it.
                return downstream(lambda dy: callback(self.backward(x, dy)))

            return resume

        return run

    def forward(self, x: ArrayLike) -> ArrayLike:
        """Computes the output for `x`. Must be overridden by subclasses."""
        raise NotImplementedError()

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        """
        Computes `dx` from the original input `x` and the incoming `dy`.
        Must be overridden by subclasses.
        """
        raise NotImplementedError()
def pipe(first: Lens, second: Lens) -> Lens:
    """
    Pipes two lenses so that data flows through `first`, then `second`.

    For lenses that pass their output on to the flow they are given, the
    forward pass of the result applies `first` and then `second`, and the
    backward pass runs in the opposite order (`second`, then `first`).
    """

    def composed(downstream: Flow) -> Flow:
        # A lens is handed the flow that consumes its output, so the wrapping
        # order is the reverse of the data order: `second` wraps `downstream`,
        # and `first` wraps that. Writing `second(first(downstream))` would
        # run `second` before `first`.
        return first(second(downstream))

    return composed
def back_propagate(
    lens: Lens, loss: Callable[[ArrayLike], ArrayLike], input: ArrayLike
) -> ArrayLike:
    """
    Runs `lens` on `input` and propagates `loss` back through it.

    The output `y` of the lens is passed to `loss`, and `loss(y)` is handed
    to the callback chain as the incoming `dy`. Whatever reaches the
    outermost callback is returned unchanged.
    """

    def terminal(y: ArrayLike) -> Cont:
        def feed_back(callback):
            # `loss` is evaluated only when the callback is supplied.
            return callback(loss(y))

        return feed_back

    def identity(dx: ArrayLike) -> ArrayLike:
        return dx

    flow = lens(terminal)
    return flow(input)(identity)
def back_propagate_const(
    lens: Lens, loss_value: ArrayLike, input: ArrayLike
) -> ArrayLike:
    """
    Runs `back_propagate` with a loss that ignores the lens output and
    always yields `loss_value`.
    """

    def constant_loss(_: ArrayLike) -> ArrayLike:
        return loss_value

    return back_propagate(lens, constant_loss, input)
def evaluate(lens: Lens, input: ArrayLike) -> ArrayLike | None:
    """
    Runs only the forward direction of `lens` on `input`.

    The terminal flow returns the value it receives and never invokes the
    callback. If the lens calls the supplied callback instead of reaching
    the terminal flow, the result is `None`.
    """

    def capture(y):
        def ignore_callback(_):
            return y

        return ignore_callback

    def discard(_):
        return None

    continuation = lens(capture)(input)
    return continuation(discard)
class Square(Function):
    """Element-wise square: `forward(x)` is `x` raised to the power 2."""

    def forward(self, x: ArrayLike) -> ArrayLike:
        """Returns `x` squared element-wise."""
        # `np.power` exists in every NumPy release; the `np.pow` alias was
        # only added in NumPy 2.0 and raises AttributeError on 1.x.
        return np.power(x, 2)

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        """Returns `2 * x * dy`, the incoming `dy` scaled by d(x^2)/dx."""
        return 2 * np.multiply(x, dy)
class Exp(Function):
    """Element-wise exponential: `forward(x)` is `np.exp(x)`."""

    def forward(self, x: ArrayLike) -> ArrayLike:
        """Returns `np.exp(x)`."""
        result = np.exp(x)
        return result

    def backward(self, x: ArrayLike, dy: ArrayLike) -> ArrayLike:
        """Returns `np.exp(x)` multiplied element-wise by the incoming `dy`."""
        # The exponential is recomputed from `x` rather than cached.
        local_slope = np.exp(x)
        return np.multiply(local_slope, dy)
if __name__ == "__main__":
    # Demo: build a lens from Square, Exp and Square, then print its output
    # and the value propagated back for a unit `dy` at the same point.
    square_in = Square()
    exponential = Exp()
    square_out = Square()
    network = pipe(square_in, pipe(exponential, square_out))

    point = np.array(0.5)
    print(evaluate(network, point))

    seed = np.array(1.0)
    print(back_propagate_const(network, seed, point))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment