Skip to content

Instantly share code, notes, and snippets.

@AlexanderFabisch
Created March 22, 2025 14:45
Show Gist options
  • Save AlexanderFabisch/e1314400d3dfd11aafec6f62c8e1b972 to your computer and use it in GitHub Desktop.
Save AlexanderFabisch/e1314400d3dfd11aafec6f62c8e1b972 to your computer and use it in GitHub Desktop.
arccos approximations
import math

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy.special import factorial
def approx(x, n_terms):
    """Approximate arccos(x) with a truncated power series for arcsin(x)**2.

    Uses the Maclaurin series
        arcsin(x)**2 = sum_{n>=1} 2**(2n-1) / (n**2 * C(2n, n)) * x**(2n),
    keeps the first ``n_terms`` terms, and returns
        arccos(x) = pi/2 - arcsin(x) = pi/2 - sign(x) * sqrt(arcsin(x)**2).

    The coefficients decay only like n**-1.5, so convergence is slow as
    |x| -> 1.

    Source:
    https://www.physicsforums.com/threads/what-is-the-best-formula-for-calculating-arccos-near-x-1.778045/

    Parameters
    ----------
    x : array_like
        Input values, expected in [-1, 1].
    n_terms : int
        Number of series terms to keep. ``0`` yields pi/2 everywhere.

    Returns
    -------
    jax.Array
        Approximation of ``arccos(x)`` with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    x2 = x * x
    arcsin_sq = jnp.zeros_like(x2)
    for n in range(1, n_terms + 1):
        # Exact integer arithmetic in Python. The former float32
        # factorial(2n) overflowed to inf (giving inf/inf = NaN) for n >= 17.
        coef = 2 ** (2 * n - 1) / (n ** 2 * math.comb(2 * n, n))
        arcsin_sq = arcsin_sq + coef * x2 ** n
    # sqrt(arcsin(x)**2) is |arcsin(x)|; restore the sign so that negative
    # inputs map to (pi/2, pi] instead of being reflected onto arccos(|x|).
    return 0.5 * jnp.pi - jnp.sign(x) * jnp.sqrt(arcsin_sq)
def arccos_approx_near_1(x):
    """Approximate arccos(x) for x close to +1.

    Evaluates the expansion
        arccos(x) ~= sqrt(2) * sqrt(1 - x) * (1 + (1 - x)/12 + 3/160 * (1 - x)**2)
    from https://www.johndcook.com/blog/2022/09/06/inverse-cosine-near-1/
    """
    y = 1 - x
    correction = 1 + y / 12 + (3 / 160) * y ** 2
    return jnp.sqrt(2) * jnp.sqrt(y) * correction


# Element-wise derivative of the approximation. jax.grad needs a scalar
# output, so vmap applies it across the leading axis (e.g. a 1-D array).
grad_near_1 = jax.vmap(jax.grad(arccos_approx_near_1))
def arccos_approx_near_m1(x):
    """Approximate arccos(x) for x close to -1.

    Mirrors the expansion about +1 through the identity
    arccos(x) = pi - arccos(-x): substitutes -x into
        sqrt(2) * sqrt(1 - t) * (1 + (1 - t)/12 + 3/160 * (1 - t)**2)
    and subtracts the result from pi.
    """
    z = 1 + x
    mirrored = jnp.sqrt(2) * jnp.sqrt(z) * (1 + z / 12 + (3 / 160) * z ** 2)
    return jnp.pi - mirrored


# Element-wise derivative of the approximation. jax.grad needs a scalar
# output, so vmap applies it across the leading axis (e.g. a 1-D array).
grad_near_m1 = jax.vmap(jax.grad(arccos_approx_near_m1))
# Reference derivative of the exact arccos, evaluated element-wise.
grad_arccos = jax.vmap(jax.grad(jnp.arccos))

# Opt-in extra curves for the comparison plot (previously commented-out code).
PLOT_GRADIENTS = False
PLOT_ENDPOINT_APPROXIMATIONS = False

# Sample grid on [-1, 1]: 20_000 log-spaced magnitudes in [1e-18, 1], mirrored
# so that the first half runs -1 ... -1e-18 and the second half 1e-18 ... 1.
# Zero itself is not included.
x = jnp.logspace(-18, 0, 20_000)
x = jnp.hstack((-x[::-1], x))

print(approx(x, 5))
# The first entries of x lie closest to -1 and the last entries closest to +1,
# i.e. where each endpoint approximation is meant to be used.
print(grad_near_m1(x)[:10])
print(grad_near_1(x)[-10:])
print(grad_arccos(x)[:10])
print(grad_arccos(x)[-10:])

if PLOT_GRADIENTS:
    plt.plot(x, grad_near_m1(x), label="grad_near_m1")
    plt.plot(x, grad_near_1(x), label="grad_near_1")
    plt.plot(x, grad_arccos(x), label="grad_arccos")
for n_terms in range(3, 15):
    plt.plot(x, approx(x, n_terms), label=f"approx {n_terms}")
# Exact curve drawn thick and translucent so the approximation curves
# underneath stay visible.
plt.plot(x, jnp.arccos(x), label="arccos", alpha=0.3, lw=10)
if PLOT_ENDPOINT_APPROXIMATIONS:
    plt.plot(x, arccos_approx_near_1(x), label="arccos_approx_near_1")
    plt.plot(x, arccos_approx_near_m1(x), label="arccos_approx_near_m1")
plt.legend(loc="best")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment