"""Minimal example of DLR (diagonal linear recurrent) layer in JAX | |
https://arxiv.org/pdf/2212.00768.pdf | |
""" | |
from typing import Any, Callable, Sequence, Tuple | |
from flax import linen | |
import jax | |
import jax.numpy as jnp | |
from numga.algebra.algebra import Algebra | |
from numga.backend.jax.context import JaxContext | |
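
# In state-space terms, the layer below implements the DLR recurrence
#   x' = A(x) + B(u),    y = C(x) + D(u)
# where the transition A is block-diagonal: each of the n_rotors state blocks
# (an even-grade multivector) is multiplied by a fixed geometric-algebra rotor,
# scaled by a decay factor in (0, 1).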
class GADLR(linen.Module):
    """GA-DLR module."""
    ga: object          # numga context (e.g. JaxContext)
    n_rotors: int       # number of parallel rotor states
    n_outputs: int      # output feature dimension
    log_decay: Tuple[float, float] = (-4., 0.)  # sampling range of log decay rates
    bias: bool = False
    kernel_init: Callable[..., Any] = jax.nn.initializers.glorot_normal()
    # kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()
    @linen.compact
    def __call__(self, x, u):
        even = self.ga.subspace.even_grade()
        bivec = self.ga.subspace.bivector()
        # fixed key: the decay and rotation parameters are (for now) not learned
        key = jax.random.PRNGKey(1)

        def make_log_scalar(key):
            log_scalar = jax.random.uniform(
                key, shape=(self.n_rotors,),
                minval=self.log_decay[0], maxval=self.log_decay[1])
            return log_scalar
        # log_scalar = self.param("log_scalar", make_log_scalar)
        log_scalar = make_log_scalar(key)

        def make_bivec(key) -> "Bivector":
            # sample bivector coefficients in the unit ball,
            # reweighted by their own norm
            b = jax.random.ball(key, shape=(self.n_rotors,), d=len(bivec))
            bn = jnp.linalg.norm(b, axis=1, keepdims=True)
            return self.ga.multivector(bivec, b * bn)
        # for the time being, let's leave these as fixed parameters?
        # according to gateloop/mamba, we want to condition these on u
        # bivector = self.param("bivector", make_bivec)
        bivector = make_bivec(key)

        # transition element: a rotor (pure rotation) times a decay factor;
        # exp(-exp(log_scalar)) lies in (0, 1), so the recurrence is stable
        l: "Even" = (bivector * jnp.pi).exp() * jnp.exp(-jnp.exp(log_scalar))

        # wrap recurrent state x here in mv type and unpack again after the
        # product; alternatively we would leak the multivector type into the
        # broader codebase / arch, which may or may not be desirable
        A = lambda x: (l * self.ga.multivector(even, x)).values
        B = linen.Dense(
            len(even) * self.n_rotors,
            name='hidden_B',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        C = linen.Dense(
            self.n_outputs,
            name='hidden_C',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        D = linen.Dense(
            self.n_outputs,
            name='hidden_D',
            kernel_init=self.kernel_init,
            use_bias=self.bias)

        # rotor transition plus input injection; C reads out the flattened
        # state, D is a direct feedthrough of the input
        x = A(x) + B(u).reshape(x.shape)
        y = C(x.flatten()) + D(u)
        # y = C(x[:,0]) + D(u)
        return x, y
    def init_carry(self):
        """Zero initial recurrent state of shape (n_rotors, len(even))."""
        even = self.ga.subspace.even_grade()
        c = jnp.zeros((self.n_rotors, len(even)))
        return c
        # return ga.multivector(c, subspace=even)
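
# --- Hypothetical sketch (not part of the original gist) ---
# The comments in GADLR.__call__ note that, following gateloop/mamba-style
# selectivity, the decay and rotation could be conditioned on the input u
# rather than fixed. One assumed way to do that is to predict log_scalar and
# the bivector coefficients from u with small Dense layers; the layer names
# ('gate_decay', 'gate_bivec') and shapes are illustrative assumptions.
class GatedGADLR(GADLR):
    """Sketch: GADLR variant with input-conditioned (selective) transition."""

    @linen.compact
    def __call__(self, x, u):
        even = self.ga.subspace.even_grade()
        bivec = self.ga.subspace.bivector()

        # predict per-rotor decay and bivector coefficients from the input;
        # exp(-exp(.)) keeps the decay in (0, 1) for any predicted value
        log_scalar = linen.Dense(self.n_rotors, name='gate_decay')(u)
        b = linen.Dense(self.n_rotors * len(bivec), name='gate_bivec')(u)
        bivector = self.ga.multivector(bivec, b.reshape(self.n_rotors, len(bivec)))

        # same transition construction and readout as in GADLR.__call__
        l = (bivector * jnp.pi).exp() * jnp.exp(-jnp.exp(log_scalar))
        A = lambda x: (l * self.ga.multivector(even, x)).values
        B = linen.Dense(len(even) * self.n_rotors, name='hidden_B',
                        kernel_init=self.kernel_init, use_bias=self.bias)
        C = linen.Dense(self.n_outputs, name='hidden_C',
                        kernel_init=self.kernel_init, use_bias=self.bias)
        D = linen.Dense(self.n_outputs, name='hidden_D',
                        kernel_init=self.kernel_init, use_bias=self.bias)

        x = A(x) + B(u).reshape(x.shape)
        y = C(x.flatten()) + D(u)
        return x, y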
def test():
    key = jax.random.PRNGKey(1)
    dlr = GADLR(
        # algebra=Algebra.from_pqr(3, 0, 0),
        ga=JaxContext(Algebra.from_pqr(3, 0, 0)),
        n_outputs=32,
        n_rotors=32,
        bias=False,
    )
    x = dlr.init_carry()
    u = jnp.ones((1,))
    params = dlr.init(key, x, u)
    # print(params)
    # quit()

    # drive the recurrence for 100 steps with a constant input
    # and plot the step response
    apply = jax.jit(dlr.apply)
    r = []
    for i in range(100):
        x, y = apply(params, x, u)
        r.append(y)
    r = jnp.array(r)

    import matplotlib.pyplot as plt
    plt.plot(r)
    plt.show()
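
# --- Hypothetical usage sketch (not part of the original gist) ---
# The Python loop in test() works, but the idiomatic way to unroll a
# recurrence in JAX is jax.lax.scan. The sketch below reuses the GADLR
# interface defined above; the constant input sequence `us` is an
# illustrative assumption.
def scan_example():
    key = jax.random.PRNGKey(1)
    dlr = GADLR(
        ga=JaxContext(Algebra.from_pqr(3, 0, 0)),
        n_outputs=32,
        n_rotors=32,
    )
    x0 = dlr.init_carry()
    us = jnp.ones((100, 1))            # (time, input_dim) sequence of inputs
    params = dlr.init(key, x0, us[0])

    def step(x, u):
        # dlr.apply returns (new_state, output), matching scan's (carry, y)
        return dlr.apply(params, x, u)

    _, ys = jax.lax.scan(step, x0, us)  # ys: (time, n_outputs)
    return ys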
if __name__ == '__main__':
    test()