Skip to content

Instantly share code, notes, and snippets.

@YouJiacheng
Last active February 24, 2025 12:16
Show Gist options
  • Save YouJiacheng/393c90cbdc23b09d5688815ba382288b to your computer and use it in GitHub Desktop.
Save YouJiacheng/393c90cbdc23b09d5688815ba382288b to your computer and use it in GitHub Desktop.
from functools import partial
import jax
import jax.numpy as jnp
import optax
def poly(x: jnp.ndarray, w: jnp.ndarray):
    """Evaluate the odd quintic ``w[0]*x + w[1]*x**3 + w[2]*x**5``.

    Args:
        x: Point(s) at which the polynomial is evaluated.
        w: Exactly three coefficients (shape ``(3,)``) for the degree-1,
            degree-3 and degree-5 terms, in that order.

    Returns:
        The polynomial value at ``x``, computed with the coefficients cast
        to float32.
    """
    assert w.shape == (3,)
    # Arithmetic is done with float32 coefficients regardless of the dtype
    # the caller stored them in.
    coeffs = w.astype(jnp.float32)
    linear_term = coeffs[0] * x
    cubic_term = coeffs[1] * x**3
    quintic_term = coeffs[2] * x**5
    return linear_term + cubic_term + quintic_term
def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray):
    """Apply the polynomials in ``w_seq`` one after another, keeping every value.

    Args:
        x: Starting value fed to the first polynomial.
        w_seq: Iterable of coefficient triples; each is passed to ``poly``.

    Returns:
        A list ``[x, p1(x), p2(p1(x)), ...]`` with ``len(w_seq) + 1`` entries:
        the input followed by the output of every stage.
    """
    trajectory = [x]
    current = x
    for coeffs in w_seq:
        # Each stage consumes the previous stage's output.
        current = poly(current, coeffs)
        trajectory.append(current)
    return trajectory
def min_of_polys(x: jnp.ndarray, w_seq: jnp.ndarray):
    """Return the pointwise minimum over all polynomials in ``w_seq`` at ``x``.

    Unlike ``poly_chain``, every polynomial is evaluated directly on ``x``
    (no composition).

    Args:
        x: Point(s) at which each polynomial is evaluated.
        w_seq: Iterable of coefficient triples; each is passed to ``poly``.

    Returns:
        The element-wise minimum of ``poly(x, w)`` over ``w_seq``. With an
        empty ``w_seq`` the result is ``+inf`` everywhere.
    """
    # Start from +inf so the first polynomial always wins the comparison.
    lowest = jnp.full_like(x, jnp.inf)
    for coeffs in w_seq:
        candidate = poly(x, coeffs)
        lowest = jnp.minimum(lowest, candidate)
    return lowest
@partial(jax.jit, static_argnums=(2, 3))
def optimize_w(
    w_seq: jnp.ndarray,
    lr: float,
    n: int,
    debug: bool = False,
):
    """Tune a sequence of polynomial coefficients with ``n`` Adam steps.

    The function is jit-compiled with ``n`` and ``debug`` as static
    arguments. A fresh optimizer state is created on every call.

    Args:
        w_seq: Coefficient triples, one row per ``poly`` stage of the chain.
        lr: Learning rate handed to ``optax.adam``.
        n: Number of optimisation steps executed inside ``jax.lax.scan``.
        debug: When true, the objectives tuple is printed through
            ``jax.debug.print`` each time the loss is evaluated.

    Returns:
        ``(w_seq, objectives)``: the updated coefficients, and the five
        objectives collected at every step (stacked by ``scan``). Each step's
        objectives are evaluated *before* that step's update is applied.
    """

    def objective_fn(coeffs: jnp.ndarray):
        """Return ``(scalar loss, objectives tuple)`` for ``coeffs``."""
        # The loss is evaluated on coefficients rounded to bfloat16
        # (``poly`` casts them back to float32 for the arithmetic).
        coeffs = coeffs.astype(jnp.bfloat16)

        # Push a uniform grid over (0, 1] through the whole chain.
        grid = (jnp.arange(2048) + 1) / 2048
        *stage_inputs, final = jax.vmap(poly_chain, in_axes=(0, None))(grid, coeffs)
        peak = jnp.amax(final)
        # Minimum of the final output, ignoring grid points at or below 1/128.
        trough = jnp.amin(jnp.where(grid > 1 / 128, final, jnp.inf))
        diff_ratio = (peak - trough) / jnp.clip(peak, min=1e-3)

        # Smallest p(x) / x over all individual polynomials p, sampled on a
        # grid over (0, 1.25].
        slope_grid = (jnp.arange(320) + 1) / 256
        lower_envelope = jax.vmap(min_of_polys, in_axes=(0, None))(slope_grid, coeffs)
        min_slope = jnp.amin(lower_envelope / slope_grid)

        # For every stage: how far that stage's polynomial, evaluated 1/16
        # above the largest input the stage received, exceeds that largest
        # input (negative excess is clipped to zero).
        stage_peaks = [jnp.amax(z) for z in stage_inputs]
        max_next_excess = sum(
            jnp.clip(poly(z_peak + 1 / 16, w) - z_peak, min=0)
            for z_peak, w in zip(stage_peaks, coeffs)
        )

        objectives = (
            final[0] / peak,  # obj_0: larger is better
            peak,  # obj_1: closer to 1 is better
            jnp.log2(diff_ratio),  # obj_2: smaller is better
            min_slope,  # obj_3: larger is better
            max_next_excess,  # obj_4: smaller is better
        )
        if debug:
            jax.debug.print("{x}", x=objectives)

        obj_0, obj_1, obj_2, obj_3, obj_4 = objectives
        # Weighted sum of the objectives; the clips cap how much reward
        # obj_2 and obj_3 can contribute.
        total = (
            -4.0 * obj_0
            + 16.0 * jnp.square(obj_1 - 1)
            + 2.0 * jnp.clip(obj_2, min=-10)
            + -4.0 * jnp.clip(obj_3, max=1 / 2)
            + 64.0 * obj_4
        )
        return total, objectives

    value_and_grad = jax.value_and_grad(objective_fn, argnums=0, has_aux=True)

    # NOTE(review): the global-norm clip comes after Adam in the chain, so it
    # acts on Adam's updates rather than on the raw gradients — confirm this
    # ordering is intended.
    optimizer = optax.chain(
        optax.adam(learning_rate=lr),
        optax.clip_by_global_norm(1.0),
    )

    def step(carry: tuple[jnp.ndarray, optax.OptState], _):
        """One optimisation step; emits the objectives seen at this step."""
        params, state = carry
        (_loss, objectives), grads = value_and_grad(params)
        updates, state = optimizer.update(grads, state)
        params = optax.apply_updates(params, updates)
        return (params, state), objectives

    initial_carry = (w_seq, optimizer.init(w_seq))
    (w_seq, _), objectives = jax.lax.scan(step, initial_carry, length=n)
    return w_seq, objectives
def main():
    """Run the staged optimisation and print progress after every round.

    Starts from six identical coefficient triples and calls ``optimize_w``
    repeatedly (100000 steps per call) under a decreasing learning-rate
    schedule. After each call it prints the bfloat16-rounded coefficients
    multiplied by ``BASE``, then the round index together with the
    last recorded value of every objective.
    """
    BASE = 128
    STEPS_PER_ROUND = 100000
    # (learning rate, number of rounds) pairs, executed in order.
    schedule = (
        (2e-3, 5),
        (1e-3, 5),
        (5e-4, 5),
        (1e-4, 20),
    )

    w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6)
    for lr, rounds in schedule:
        for i in range(rounds):
            w_seq, objectives = optimize_w(w_seq, lr=lr, n=STEPS_PER_ROUND)
            print(w_seq.astype(jnp.bfloat16) * BASE)
            print(i, [obj[-1].item() for obj in objectives])
# Script entry point: run the optimisation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment