# Source: GitHub gist YouJiacheng/393c90cbdc23b09d5688815ba382288b
# (last active February 24, 2025 12:16).
# NOTE: GitHub flagged this file as containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than it appears; review
# it in an editor that reveals hidden Unicode characters.
from functools import partial | |
import jax | |
import jax.numpy as jnp | |
import optax | |
def poly(x: jnp.ndarray, w: jnp.ndarray):
    """Evaluate the odd quintic ``w[0]*x + w[1]*x**3 + w[2]*x**5``.

    Args:
        x: Point(s) at which the polynomial is evaluated.
        w: Coefficient vector of shape ``(3,)``, ordered as the x, x**3 and
            x**5 coefficients. It is cast to float32 before use.

    Returns:
        The polynomial value at ``x``.

    Raises:
        AssertionError: If ``w`` does not have shape ``(3,)``.
    """
    assert w.shape == (3,)
    coeffs = w.astype(jnp.float32)
    # Terms are built separately but summed left to right, exactly as in the
    # single-expression form, so floating-point results are unchanged.
    linear = coeffs[0] * x
    cubic = coeffs[1] * x**3
    quintic = coeffs[2] * x**5
    return linear + cubic + quintic
def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray):
    """Apply the polynomials in ``w_seq`` to ``x`` one after another.

    Args:
        x: Starting value.
        w_seq: Iterable of coefficient vectors, each accepted by ``poly``.

    Returns:
        A list with ``len(w_seq) + 1`` entries: ``x`` itself, followed by the
        value after each successive polynomial has been applied.
    """
    stages = [x]
    current = x
    for coeffs in w_seq:
        # Each stage consumes the output of the previous one.
        current = poly(current, coeffs)
        stages.append(current)
    return stages
def min_of_polys(x: jnp.ndarray, w_seq: jnp.ndarray):
    """Return the element-wise minimum of every polynomial evaluated at ``x``.

    Unlike ``poly_chain``, each polynomial is applied to ``x`` directly and
    the results are not composed.

    Args:
        x: Point(s) at which every polynomial is evaluated.
        w_seq: Iterable of coefficient vectors, each accepted by ``poly``.

    Returns:
        An array shaped like ``x`` holding the smallest polynomial value; it
        stays at ``+inf`` if ``w_seq`` is empty.
    """
    lowest = jnp.full_like(x, jnp.inf)  # identity element for ``minimum``
    for coeffs in w_seq:
        candidate = poly(x, coeffs)
        lowest = jnp.minimum(lowest, candidate)
    return lowest
@partial(jax.jit, static_argnums=(2, 3))
def optimize_w(
    w_seq: jnp.ndarray,
    lr: float,
    n: int,
    debug: bool = False,
):
    """Run ``n`` optimizer steps on the polynomial coefficients ``w_seq``.

    The whole routine is jitted; ``n`` and ``debug`` are static arguments
    (positions 2 and 3), while ``w_seq`` and ``lr`` are traced.

    Args:
        w_seq: Coefficient rows, one ``(3,)`` vector per polynomial stage.
        lr: Learning rate handed to ``optax.adam``.
        n: Number of ``jax.lax.scan`` iterations (optimizer steps).
        debug: When true, print the objectives tuple on every loss evaluation
            via ``jax.debug.print``.

    Returns:
        ``(w_seq, objectives)``: the updated coefficients and the tuple of
        five objectives emitted at every step, as collected by
        ``jax.lax.scan``.
    """

    def _loss_fn(weights: jnp.ndarray):
        """Return ``(scalar loss, objectives tuple)`` for ``weights``."""
        # The loss is evaluated on bfloat16-rounded coefficients.
        weights = weights.astype(jnp.bfloat16)

        # Composite map sampled on the grid 1/2048, 2/2048, ..., 1.
        grid = (jnp.arange(2048) + 1) / 2048
        chained = jax.vmap(poly_chain, in_axes=(0, None))(grid, weights)
        stage_inputs, final = chained[:-1], chained[-1]

        peak = jnp.amax(final)
        # The minimum is only taken over grid points with x > 1/128.
        trough = jnp.amin(jnp.where(grid > 1 / 128, final, jnp.inf))
        spread_ratio = (peak - trough) / jnp.clip(peak, min=1e-3)

        # Smallest p(x) / x over all polynomials, on the grid 1/256 ... 320/256.
        slope_grid = (jnp.arange(320) + 1) / 256
        lower_envelope = jax.vmap(min_of_polys, in_axes=(0, None))(
            slope_grid, weights
        )
        worst_slope = jnp.amin(lower_envelope / slope_grid)

        # For each stage: how far the polynomial, evaluated 1/16 above the
        # largest input that stage sees on the grid, exceeds that largest
        # input (clipped at zero), summed over the stages.
        stage_peaks = [jnp.amax(stage) for stage in stage_inputs]
        overshoot = 0
        for stage_peak, coeffs in zip(stage_peaks, weights):
            overshoot = overshoot + jnp.clip(
                poly(stage_peak + 1 / 16, coeffs) - stage_peak, min=0
            )

        objectives = (
            final[0] / peak,  # larger is better
            peak,  # closer to 1 is better
            jnp.log2(spread_ratio),  # smaller is better
            worst_slope,  # larger is better
            overshoot,  # smaller is better
        )
        if debug:
            jax.debug.print("{x}", x=objectives)

        first_ratio, top, log_spread, slope, excess = objectives
        # Weighted sum of the objectives; terms are accumulated in a fixed
        # order so the floating-point result is reproducible.
        total = -4.0 * first_ratio
        total = total + 16.0 * jnp.square(top - 1)
        total = total + 2.0 * jnp.clip(log_spread, min=-10)
        total = total + -4.0 * jnp.clip(slope, max=1 / 2)
        total = total + 64.0 * excess
        return total, objectives

    value_and_grad = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True)

    # NOTE(review): clip_by_global_norm comes after adam in the chain, so it
    # acts on Adam's updates rather than on the raw gradients -- confirm this
    # ordering is intended.
    optimizer = optax.chain(
        optax.adam(learning_rate=lr),
        optax.clip_by_global_norm(1.0),
    )

    def _step(carry: tuple[jnp.ndarray, optax.OptState], _unused):
        """One optimizer step; emits this step's objectives as scan output."""
        params, state = carry
        (_loss_value, objectives), grads = value_and_grad(params)
        updates, state = optimizer.update(grads, state)
        params = optax.apply_updates(params, updates)
        return (params, state), objectives

    initial_carry = (w_seq, optimizer.init(w_seq))
    (w_seq, _), history = jax.lax.scan(_step, initial_carry, length=n)
    return w_seq, history
def main():
    """Optimize six identical starting polynomials over a stepped LR schedule.

    Runs ``optimize_w`` for 100000 steps per round: five rounds each at
    learning rates 2e-3, 1e-3 and 5e-4, then twenty rounds at 1e-4. After
    every round it prints the coefficients (cast to bfloat16 and multiplied
    by 128) and the round index with the last recorded value of each
    objective.
    """
    scale = 128
    initial_row = [3.5, -6.04444444444, 2.84444444444]
    w_seq = jnp.array([initial_row] * 6)

    # (learning rate, number of rounds) for each phase, in execution order.
    schedule = (
        (2e-3, 5),
        (1e-3, 5),
        (5e-4, 5),
        (1e-4, 20),
    )
    for lr, rounds in schedule:
        # The printed round index restarts at 0 for every phase.
        for i in range(rounds):
            w_seq, objectives = optimize_w(w_seq, lr=lr, n=100000)
            print(w_seq.astype(jnp.bfloat16) * scale)
            print(i, [obj[-1].item() for obj in objectives])


if __name__ == "__main__":
    main()
# (GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment." -- not part of the gist source.)