Created June 4, 2021 20:10
-
-
Save Joshuaalbert/39e5e44a06cb00e7e154b37504e30fa1 to your computer and use it in GitHub Desktop.
Regression test: BFGS speed across jax and jaxlib versions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _mean_runtime(fn, repeats):
    """Call ``fn()`` ``repeats`` times and return ``(mean_seconds, last_result)``.

    Any warm-up (e.g. JIT compilation) must be done by the caller *before*
    calling this helper so that it is excluded from the measurement.

    Raises:
        ValueError: if ``repeats`` is less than 1 (the mean would be undefined).
    """
    from timeit import default_timer

    if repeats < 1:
        raise ValueError("repeats must be >= 1, got {}".format(repeats))
    result = None
    t0 = default_timer()
    for _ in range(repeats):
        result = fn()
    return (default_timer() - t0) / repeats, result


def speed_test_jax(num_repeats=3, N_array=None):
    """Benchmark BFGS with scipy+numpy, scipy+jitted JAX, and pure JAX.

    For each problem size N this minimises the least-squares + L1 objective
    ``sum((u - A @ x) ** 2) + 0.1 * sum(|x|)``, divided by the gradient norm at
    ``x0``, with four back-ends. It prints per-method diagnostics and saves a
    log-scale plot of mean run time vs N. The plot file name includes the
    installed jax/jaxlib versions.

    Args:
        num_repeats: number of timed repetitions per method; the mean is reported.
        N_array: problem sizes to benchmark; defaults to [2, 10, 50, 100, 200, 400].

    Returns:
        dict mapping each plotted method label to its list of mean run times
        (seconds), one entry per N in ``N_array``.
    """
    import numpy as np
    import jax
    import jaxlib
    import matplotlib.pyplot as plt
    from jax import jit, value_and_grad, random, numpy as jnp
    from jax.scipy.optimize import minimize as minimize_jax
    from scipy.optimize import minimize as minimize_np

    JAX_VERSION = jax.__version__
    JAXLIB_VERSION = jaxlib.__version__
    if N_array is None:
        N_array = [2, 10, 50, 100, 200, 400]
    N_array = list(N_array)

    t_scipy_halfjax, t_scipy_jax, t_jax, t_numpy = [], [], [], []
    for N in N_array:
        print("Working on N={}".format(N))
        A = random.normal(random.PRNGKey(0), shape=(N, N))
        # NumPy copy of A for the pure-NumPy baseline. `A @ x` with the JAX array
        # would dispatch through JAX op-by-op and skew the baseline timing.
        A_np = np.asarray(A)
        u = jnp.ones(N)
        x0 = -2. * jnp.ones(N)

        def f_prescale(x, u):
            dx = u - A @ x
            return jnp.sum(dx ** 2) + 0.1 * jnp.sum(jnp.abs(x))

        # Due to https://github.com/google/jax/issues/4594 we scale the loss
        # so that scipy and jax linesearch perform similarly.
        jac_norm = jnp.linalg.norm(value_and_grad(f_prescale)(x0, u)[1])
        jac_norm_np = np.array(jac_norm)

        def f(x, u):
            return f_prescale(x, u) / jac_norm

        def f_np(x, u):
            dx = u - A_np @ x
            return (np.sum(dx ** 2) + 0.1 * np.sum(np.abs(x))) / jac_norm_np

        print("Testing scipy+numpy")
        args = (np.array(x0), (np.array(u),))
        # Nothing to compile here, so no warm-up: time exactly num_repeats runs.
        t, results_np = _mean_runtime(
            lambda: minimize_np(f_np, *args, method='BFGS'), num_repeats)
        t_numpy.append(t)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + numpy", t_numpy[-1])

        print("Testing scipy + jitted function and numeric grad")
        f_jit = jit(f)
        f_jit(x0, u).block_until_ready()  # compile outside the timed region
        t, results_np = _mean_runtime(
            lambda: minimize_np(f_jit, x0, (u,), method='BFGS'), num_repeats)
        t_scipy_halfjax.append(t)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + jitted function and numeric grad", t_scipy_halfjax[-1])

        print("Testing scipy + jitted function and grad")
        f_and_grad_jit = jit(value_and_grad(f))
        f_and_grad_jit(x0, u)[1].block_until_ready()  # compile outside the timed region
        t, results_np = _mean_runtime(
            lambda: minimize_np(f_and_grad_jit, x0, (u,), method='BFGS', jac=True),
            num_repeats)
        t_scipy_jax.append(t)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + jitted function and grad", t_scipy_jax[-1])

        print("Testing pure JAX implementation")

        @jit
        def do_minimize_jax(x0, u):
            return minimize_jax(f, x0, args=(u,), method='BFGS').x

        # Un-jitted run, used only to report the optimum and evaluation counts.
        results_jax = minimize_jax(f, x0, args=(u,), method='BFGS')
        print("JAX f(optimal)", results_jax.fun, "scipy+jax f(optimal)", results_np.fun)
        do_minimize_jax(x0, u).block_until_ready()  # compile outside the timed region
        # block_until_ready() so asynchronous dispatch does not hide the work.
        t, _ = _mean_runtime(
            lambda: do_minimize_jax(x0, u).block_until_ready(), num_repeats)
        t_jax.append(t)
        print("nfev", results_jax.nfev, "njev", results_jax.njev)
        print("Time for pure JAX implementation", t_jax[-1])

    timings = {
        'scipy+jitted(func and grad)': t_scipy_jax,
        'scipy+jitted(func)': t_scipy_halfjax,
        'pure JAX': t_jax,
        'scipy+numpy': t_numpy,
    }
    fig = plt.figure(figsize=(8, 5))
    for label, times in timings.items():
        plt.plot(N_array, times, label=label)
    plt.yscale('log')
    plt.legend()
    plt.title(f"(jax: {JAX_VERSION}, jaxlib: {JAXLIB_VERSION}) Run time of BFGS on N-D Least squares + L1 regularisation.")
    plt.ylabel('Time [s]')
    plt.xlabel("N")
    plt.savefig(f"speed_results_jax-{JAX_VERSION}_jaxlib-{JAXLIB_VERSION}.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return timings
if __name__ == "__main__":
    # Script entry point: run the benchmark with its default settings.
    speed_test_jax()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run bfgs_speed_test.py once per jax/jaxlib release pair.
# JAX_VERSIONS[i] is installed together with JAXLIB_VERSIONS[i], so both arrays
# must have the same length.
JAX_VERSIONS=(0.1.74 0.1.75 0.1.76 0.1.77 0.2.0 0.2.1 0.2.2 0.2.3 0.2.4 0.2.5 0.2.6 0.2.7 0.2.8 0.2.9 0.2.10 0.2.11 0.2.12 0.2.13)
JAXLIB_VERSIONS=(0.1.52 0.1.52 0.1.52 0.1.55 0.1.55 0.1.55 0.1.56 0.1.56 0.1.56 0.1.56 0.1.57 0.1.57 0.1.58 0.1.59 0.1.61 0.1.64 0.1.64 0.1.65)

if [ "${#JAX_VERSIONS[@]}" -ne "${#JAXLIB_VERSIONS[@]}" ]; then
    echo "JAX_VERSIONS and JAXLIB_VERSIONS must have the same length" >&2
    exit 1
fi

# `conda activate` only works in a non-interactive script after the conda
# shell hook has been loaded; -y avoids blocking on the confirmation prompt.
eval "$(conda shell.bash hook)"
conda create -y -n test_env python=3.8
conda activate test_env
pip install matplotlib scipy numpy

for index in "${!JAX_VERSIONS[@]}"; do
    jax_version="${JAX_VERSIONS[index]}"
    jaxlib_version="${JAXLIB_VERSIONS[index]}"
    echo Running $((index+1))/${#JAX_VERSIONS[@]} with jax=="${jax_version}" and jaxlib=="${jaxlib_version}"
    # On install failure, skip this pair instead of re-benchmarking whatever
    # version happens to still be installed.
    if ! pip install --ignore-installed jax=="${jax_version}" jaxlib=="${jaxlib_version}"; then
        echo "Failed to install jax==${jax_version} jaxlib==${jaxlib_version}; skipping" >&2
        continue
    fi
    python bfgs_speed_test.py
done
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment