Skip to content

Instantly share code, notes, and snippets.

@adrn
Created September 26, 2024 21:27
Show Gist options
  • Save adrn/17c4d9f829ba5cb31f24849c2022e40a to your computer and use it in GitHub Desktop.
Save adrn/17c4d9f829ba5cb31f24849c2022e40a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3 (ipykernel)","language":"python"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from dataclasses import dataclass\nfrom functools import partial\nfrom typing import Callable\n\nimport jax\nimport jax.numpy as jnp\n\njax.config.update(\"jax_enable_x64\", True)\nimport diffrax\nimport matplotlib.pyplot as plt","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:29.522209Z","iopub.execute_input":"2024-09-26T21:27:29.522420Z","iopub.status.idle":"2024-09-26T21:27:30.439956Z","shell.execute_reply.started":"2024-09-26T21:27:29.522394Z","shell.execute_reply":"2024-09-26T21:27:30.439660Z"}},"outputs":[],"execution_count":1},{"cell_type":"markdown","source":"First, an example of how JAX can be used to auto-differentiate a function.","metadata":{}},{"cell_type":"code","source":"def some_func(pars):\n return pars[\"a\"] ** 2\n\n\nsome_func({\"a\": 1.0})","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.441024Z","iopub.execute_input":"2024-09-26T21:27:30.441162Z","iopub.status.idle":"2024-09-26T21:27:30.445024Z","shell.execute_reply.started":"2024-09-26T21:27:30.441152Z","shell.execute_reply":"2024-09-26T21:27:30.444757Z"}},"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"1.0"},"metadata":{}}],"execution_count":2},{"cell_type":"code","source":"some_func_grad = jax.grad(some_func, argnums=0)\nsome_func_grad({\"a\": 
1.0})","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.445440Z","iopub.execute_input":"2024-09-26T21:27:30.445526Z","iopub.status.idle":"2024-09-26T21:27:30.530005Z","shell.execute_reply.started":"2024-09-26T21:27:30.445517Z","shell.execute_reply":"2024-09-26T21:27:30.529620Z"}},"outputs":[{"execution_count":3,"output_type":"execute_result","data":{"text/plain":"{'a': Array(2., dtype=float64, weak_type=True)}"},"metadata":{}}],"execution_count":3},{"cell_type":"markdown","source":"Now we want to compute an orbit, or a trajectory, of a test particle in a gravitational potential. \n\nSee also the diffrax documentation here: https://docs.kidger.site/diffrax/usage/getting-started/","metadata":{}},{"cell_type":"code","source":"class PotentialHelper:\n    def __init__(self, potential):\n        self.potential = jax.jit(potential)\n        self.potential_grad = jax.jit(jax.grad(self.potential, argnums=1))\n\n    @partial(jax.jit, static_argnums=(0,))\n    def w_dot(self, t, w):\n        ndim = w.shape[0] // 2\n        accel = -self.potential_grad(t, w[:ndim])\n        return jnp.concatenate((w[ndim:], accel))\n\n    @partial(jax.jit, static_argnums=(0,))\n    def __call__(self, t, w, *args):\n        return self.w_dot(t, w)\n\n\n@jax.jit\ndef constant_potential(t, x):\n    return 0.0\n\n\nhelper = PotentialHelper(constant_potential)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.530852Z","iopub.execute_input":"2024-09-26T21:27:30.531003Z","iopub.status.idle":"2024-09-26T21:27:30.536114Z","shell.execute_reply.started":"2024-09-26T21:27:30.530992Z","shell.execute_reply":"2024-09-26T21:27:30.535661Z"}},"outputs":[],"execution_count":4},{"cell_type":"markdown","source":"Set up the ODE solver:","metadata":{}},{"cell_type":"code","source":"term = diffrax.ODETerm(helper)\nsolver = 
diffrax.Dopri5()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.536668Z","iopub.execute_input":"2024-09-26T21:27:30.536861Z","iopub.status.idle":"2024-09-26T21:27:30.540869Z","shell.execute_reply.started":"2024-09-26T21:27:30.536848Z","shell.execute_reply":"2024-09-26T21:27:30.540407Z"}},"outputs":[],"execution_count":5},{"cell_type":"markdown","source":"Tell diffrax how often to output the computed phase-space coordinates:","metadata":{}},{"cell_type":"code","source":"sol = diffrax.diffeqsolve(\n term,\n solver,\n t0=0.0,\n t1=10.0,\n dt0=0.1,\n y0=jnp.array([0.0, 0.0, 1.0, 0.0]),\n saveat=diffrax.SaveAt(ts=jnp.arange(0, 10, 1.0))\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.541458Z","iopub.execute_input":"2024-09-26T21:27:30.541568Z","iopub.status.idle":"2024-09-26T21:27:30.884211Z","shell.execute_reply.started":"2024-09-26T21:27:30.541557Z","shell.execute_reply":"2024-09-26T21:27:30.883853Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"sol.ts","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.885879Z","iopub.execute_input":"2024-09-26T21:27:30.886029Z","iopub.status.idle":"2024-09-26T21:27:30.889157Z","shell.execute_reply.started":"2024-09-26T21:27:30.886015Z","shell.execute_reply":"2024-09-26T21:27:30.888692Z"}},"outputs":[{"execution_count":7,"output_type":"execute_result","data":{"text/plain":"Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float64, weak_type=True)"},"metadata":{}}],"execution_count":7},{"cell_type":"code","source":"x, y, vx, vy = sol.ys.T\n\nfig, ax = plt.subplots()\nax.scatter(x, y, c=sol.ts)\nax.set(xlim=(-10, 10), ylim=(-10, 
10))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-09-26T21:27:30.889600Z","iopub.execute_input":"2024-09-26T21:27:30.889700Z","iopub.status.idle":"2024-09-26T21:27:31.078722Z","shell.execute_reply.started":"2024-09-26T21:27:30.889689Z","shell.execute_reply":"2024-09-26T21:27:31.078203Z"}},"outputs":[{"execution_count":8,"output_type":"execute_result","data":{"text/plain":"[(-10.0, 10.0), (-10.0, 10.0)]"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<Figure size 432x432 with 1 Axes>","image/png":""},"metadata":{"image/png":{"width":440,"height":440}}}],"execution_count":8},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment