@dfm
Created March 2, 2021 23:37
deterministic-op.ipynb

In [1]:
!python -m pip install -U aesara

Requirement already up-to-date: aesara in /usr/local/lib/python3.7/dist-packages (2.0.1)
Requirement already satisfied, skipping upgrade: filelock in /usr/local/lib/python3.7/dist-packages (from aesara) (3.0.12)
Requirement already satisfied, skipping upgrade: scipy>=0.14 in /usr/local/lib/python3.7/dist-packages (from aesara) (1.4.1)
Requirement already satisfied, skipping upgrade: numpy>=1.9.1 in /usr/local/lib/python3.7/dist-packages (from aesara) (1.19.5)

In [2]:
import aesara
aesara.__version__

Out[2]: '2.0.1'

In [3]:
import numpy as np
import jax
import jax.numpy as jnp
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.jax_dispatch import jax_funcify

class JaxOp(Op):
    __props__ = ("jax_fn",)

    def __init__(self, jax_fn, itypes, otypes):
        self.jax_fn = jax_fn
        self.itypes = itypes
        self.otypes = otypes
        super().__init__()

    def perform(self, node, inputs, outputs):
        results = self.jax_fn(*(jnp.asarray(x) for x in inputs))
        if len(outputs) == 1:
            outputs[0][0] = np.asarray(results)
            return
        for i, r in enumerate(results):
            outputs[i][0] = np.asarray(r)

@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op):
    func = op.jax_fn
    return func

In [4]:
import aesara
import aesara.tensor as aet
from aesara.compile.mode import Mode
from aesara.link.jax import JAXLinker

def func(x):
    return jnp.exp(x)

op = JaxOp(jax.jit(func), [aet.fvector], [aet.fvector])

x = aet.fvector()
y = op(x)

jax_mode = Mode(JAXLinker())
aesara_jax_fn = aesara.function([x], [y], mode=jax_mode)

In [7]:
aesara_jax_fn(np.random.randn(5).astype(np.float32))

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Out[7]: [DeviceArray([0.23700568, 4.216926  , 2.1559253 , 1.9963377 , 0.7241312 ], dtype=float32)]
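
The perform method above also handles functions with multiple outputs, although the notebook never exercises that branch. As a hedged illustration (my own addition, not from the gist; compiled in the default mode so the call goes through perform):

def func2(x):
    return jnp.exp(x), jnp.sin(x)

# Two output types, so perform's multi-output loop is used.
op2 = JaxOp(jax.jit(func2), [aet.fvector], [aet.fvector, aet.fvector])

x2 = aet.fvector()
y2a, y2b = op2(x2)

# Default (non-JAX) mode: evaluation goes through JaxOp.perform.
f2 = aesara.function([x2], [y2a, y2b])
print(f2(np.random.randn(5).astype(np.float32)))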
@peterroelants

This is interesting, thank you for sharing.

I'm wondering about two things:

  1. What are the implications of converting from Aesara to JAX and back again? Do the jnp.asarray and np.asarray calls introduce any memory overhead?
  2. Why call jax.jit on the function passed to JaxOp? For some reason I was assuming Aesara would compile down to JAX (in JAX mode) and would take care of this. What compilation does Aesara provide?

@dfm

dfm commented Mar 3, 2021

@peterroelants: these questions are both moot if you're only using the JAX-ified version of the function. perform is only called when the op is evaluated through Aesara's own backends. This means that you could use this op as a deterministic with either the original PyMC3 backend or the JAX backend, and on the JAX backend it would reduce directly to just the JAX function.

But to answer them directly:

  1. I don't think the asarray calls are strictly necessary, and I don't expect them to introduce overhead, since I believe that conversion would happen behind the scenes anyway, but I could well be wrong.
  2. Again, the jit only matters if you also want to incorporate this into an Aesara model that doesn't otherwise use JAX. If you're using the JAX backend, I don't think it would hurt to do this (?) but it's definitely not necessary in that case. (A sketch contrasting the two evaluation paths follows.)
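
To make the two paths concrete, here's a minimal sketch reusing op, x, and y from the notebook above (the default-mode call is my own illustration, not something the gist runs):

import numpy as np
import aesara
from aesara.compile.mode import Mode
from aesara.link.jax import JAXLinker

inp = np.random.randn(5).astype(np.float32)

# Default backend: evaluation goes through JaxOp.perform, which
# round-trips the data through jnp.asarray and np.asarray.
f_default = aesara.function([x], [y])
print(f_default(inp))

# JAX backend: jax_funcify returns op.jax_fn directly, so the op
# reduces to the wrapped JAX function and perform is never called.
f_jax = aesara.function([x], [y], mode=Mode(JAXLinker()))
print(f_jax(inp))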

@bmorris3

bmorris3 commented Jun 25, 2021

Thanks for this! One note I found while experimenting with this on aesara 2.0.12: jax_funcify_JaxOp seems to require two extra keyword arguments, node and storage_map, so this tweak makes the code above work for me:

@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, *args, **kwargs):
    func = op.jax_fn
    return func

I hope that's sensible.
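
For reference, an equivalent variant that names the extra arguments explicitly; the signature is inferred from bmorris3's report rather than checked against the 2.0.12 source, so treat it as a sketch:

@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, node=None, storage_map=None, **kwargs):
    # Extra context passed by the dispatcher is accepted but unused here.
    return op.jax_fn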

@dfm

dfm commented Jun 25, 2021

@bmorris3: Yeah, this interface has been a moving target and I haven't been following it too closely, so I'm not sure I know enough to comment, but it seems sensible enough :D
