Last active: June 28, 2017 09:50
Save rgommers/3d05325678af5e7d1fb5818e28b8fc9d to your computer and use it in GitHub Desktop.
A test of autograd for automatic differentiation of scipy.special functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A test of autograd for automatic differentiation of scipy.special functions.
import numpy as np
import matplotlib.pyplot as plt
from autograd import grad
import autograd.scipy.special as special

plt.style.use('ggplot')

# Evaluation grid and reference values for the Bessel functions.
x = np.linspace(-10, 10, num=1000)
y_j0 = special.j0(x)
y_j1 = special.j1(x)
# Order n used for jn(n, x).
order = 1
y_jn = special.jn(order, x)
def gradvec(fun, x):
    """
    Evaluate the derivative of a one-argument function at every point of `x`.

    grad() differentiates functions of a scalar input, so use a loop for now
    (inefficient, but it works fine) - no need for speed just yet.

    Parameters
    ----------
    fun : callable
        Function of a single scalar argument, differentiable by autograd.
    x : array_like
        Points at which dfun/dx is evaluated. May have any shape.

    Returns
    -------
    dfun_dx : ndarray of float
        Array with the same shape as `x` holding dfun/dx at each point.
    """
    x = np.asarray(x)
    # Build the gradient function once rather than once per element.
    dfun = grad(fun)
    # Always store floats, so an integer-typed `x` cannot truncate results.
    dfun_dx = np.empty(x.shape, dtype=float)
    for idx in np.ndindex(x.shape):
        # autograd does not differentiate w.r.t. integers; pass a float.
        dfun_dx[idx] = dfun(float(x[idx]))
    return dfun_dx
def gradvec2(fun, arg1, x):
    """
    Same as `gradvec` but for two-argument functions like ``jn(n, x)``.

    The derivative is taken with respect to the SECOND argument (`x`), with
    the first argument `arg1` (e.g. the order n of jn) held fixed.

    Parameters
    ----------
    fun : callable
        Function ``fun(arg1, x)``, differentiable by autograd in `x`.
    arg1 : object
        Fixed first argument passed unchanged to `fun`.
    x : array_like
        Points at which dfun/dx is evaluated. May have any shape.

    Returns
    -------
    dfun_dx : ndarray of float
        Array with the same shape as `x` holding dfun/dx at each point.
    """
    x = np.asarray(x)
    # grad() differentiates w.r.t. argument 0 by default, which would be
    # `arg1` (the order of jn) and yields zeros; argnum=1 selects `x`.
    dfun = grad(fun, argnum=1)
    # Always store floats, so an integer-typed `x` cannot truncate results.
    dfun_dx = np.empty(x.shape, dtype=float)
    for idx in np.ndindex(x.shape):
        # autograd does not differentiate w.r.t. integers; pass a float.
        dfun_dx[idx] = dfun(arg1, float(x[idx]))
    return dfun_dx
# Derivatives of the Bessel functions, evaluated point by point on `x`.
dy_j0 = gradvec(special.j0, x)
dy_j1 = gradvec(special.j1, x)
# jn(n, x) takes the order n as its FIRST argument, and grad() differentiates
# w.r.t. the first argument by default -- i.e. w.r.t. the order, not x, which
# is why that derivative came out as zeros. Bind the order so that x is the
# only free variable.
dy_jn = gradvec(lambda xval: special.jn(order, xval), x)

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x, y_j0, '-', color='C0', label='$j_0$')
ax.plot(x, dy_j0, '--', color='C0', label='$dj_0/dx$')
ax.plot(x, y_j1, '-', color='C1', label='$j_1$')
ax.plot(x, dy_j1, '--', color='C1', label='$dj_1/dx$')
# With order == 1 these curves coincide with j_1 and dj_1/dx, which gives a
# direct visual check of the jn gradient; distinct line styles keep both
# visible where they overlap.
ax.plot(x, y_jn, ':', color='C2', label='$j_{n,%i}$' % order)
ax.plot(x, dy_jn, '-.', color='C2', label='$dj_{n,%i}/dx$' % order)
ax.set_xlabel('x')
ax.set_ylabel('f(x), df(x)/dx')
ax.legend(loc='upper right')

# Create a second figure for psi, the vertical scale is quite different.
# The grid starts at 0.1 because psi has a pole at x = 0.
x2 = np.linspace(0.1, 10, num=1000)
y_psi = special.psi(x2)
dy_psi = gradvec(special.psi, x2)
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
# Raw strings: '\p' is an invalid escape sequence in a normal string literal.
ax2.plot(x2, y_psi, '-', color='C0', label=r'$\psi$')
ax2.plot(x2, dy_psi, '--', color='C0', label=r'$d\psi/dx$')
ax2.set_xlabel('x')
ax2.set_ylabel(r'$\psi(x), d\psi(x)/dx$')
ax2.legend()

fig.savefig('bessel.png')
fig2.savefig('psi.png')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.