Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created January 23, 2026 16:00
Show Gist options
  • Select an option

  • Save ricardoV94/5c0936025f62cbb582a799f21c8a3d6d to your computer and use it in GitHub Desktop.

Select an option

Save ricardoV94/5c0936025f62cbb582a799f21c8a3d6d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"id": "63d9bc06-f486-4e34-ac5c-c9ce79853090",
"metadata": {
"ExecuteTime": {
"end_time": "2026-01-23T16:00:04.776661738Z",
"start_time": "2026-01-23T16:00:04.131802442Z"
}
},
"source": [
"import numpy as np\n",
"from numpy.random import Generator, SeedSequence, PCG64\n",
"\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph import Op\n",
"from pytensor.tensor.type import TensorType\n",
"from pytensor.tensor.random.type import random_generator_type, RandomType\n",
"from pytensor import scan\n",
"from pytensor.gradient import null_type\n",
"\n",
"uint128 = TensorType(shape=(2,), dtype=\"uint64\")\n",
"\n",
"\n",
"class StateFromGenerator(Op):\n",
"    itypes = [random_generator_type]\n",
"    otypes = [uint128, uint128]\n",
"\n",
"    def perform(self, node, inputs, outputs):\n",
"        [generator] = inputs\n",
"\n",
"        state_dict = generator.bit_generator.state[\"state\"]\n",
"        s = state_dict[\"state\"]\n",
"        i = state_dict[\"inc\"]\n",
"        mask_64 = 0xFFFFFFFFFFFFFFFF\n",
"        s_hi = (s >> 64) & mask_64\n",
"        s_lo = s & mask_64\n",
"        i_hi = (i >> 64) & mask_64\n",
"        i_lo = i & mask_64\n",
"\n",
"        outputs[0][0] = np.array([s_hi, s_lo], dtype=np.uint64)\n",
"        outputs[1][0] = np.array([i_hi, i_lo], dtype=np.uint64)\n",
"\n",
"    def L_op(self, inputs, outputs, output_gradients):\n",
"        return [null_type()]\n",
"\n",
"\n",
"class GeneratorFromState(Op):\n",
"    itypes = [uint128, uint128]\n",
"    otypes = [random_generator_type]\n",
"\n",
"    def perform(self, node, inputs, outputs, _seed_seq=SeedSequence(0)):\n",
"        state_arr, inc_arr = inputs\n",
"        bit_gen = PCG64(_seed_seq)\n",
"        state = bit_gen.state # returns a copy\n",
"        state[\"state\"] = {\n",
"            \"state\": (int(state_arr[0]) << 64) | int(state_arr[1]),\n",
"            \"inc\": (int(inc_arr[0]) << 64) | int(inc_arr[1]),\n",
"        }\n",
"        bit_gen.state = state\n",
"        outputs[0][0] = Generator(bit_gen)\n",
"\n",
"\n",
"    def L_op(self, inputs, outputs, output_gradients):\n",
"        return [null_type(), null_type()]\n",
"\n",
"\n",
"state_from_generator = StateFromGenerator()\n",
"generator_from_state = GeneratorFromState()"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"id": "ef8e1142-58d5-4d8b-a68c-3bcdc1d8ceb3",
"metadata": {
"ExecuteTime": {
"end_time": "2026-01-23T16:00:04.834364760Z",
"start_time": "2026-01-23T16:00:04.778779458Z"
}
},
"source": [
"def standard_normal(rng):\n",
"    return pt.random.normal(rng=rng).owner.outputs\n",
"\n",
"sigma = pt.scalar(name=\"sigma\")\n",
"rng = random_generator_type(\"rng\")\n",
"state, inc = state_from_generator(rng)\n",
"\n",
"def step(prev_state, inc, sigma):\n",
"    rng = generator_from_state(prev_state, inc)\n",
"    next_rng, x = standard_normal(rng)\n",
"    return state_from_generator(next_rng)[0], x * sigma\n",
"\n",
"[states, xs] = scan(\n",
"    step,\n",
"    outputs_info=[state, None],\n",
"    non_sequences=[inc, sigma],\n",
"    n_steps=10,\n",
"    return_updates=False\n",
")\n",
"\n",
"last_x = xs[-1]"
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "code",
"id": "7e9e9520-25df-4454-8d70-c523cc40c8f2",
"metadata": {
"ExecuteTime": {
"end_time": "2026-01-23T16:00:05.581777227Z",
"start_time": "2026-01-23T16:00:04.841802091Z"
}
},
"source": [
"last_x.eval({sigma: 1.0, rng: np.random.default_rng(3)})"
],
"outputs": [
{
"data": {
"text/plain": [
"array(3.32299952)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "9745f5f2-7233-4ec5-8876-6fa7072c0633",
"metadata": {
"ExecuteTime": {
"end_time": "2026-01-23T16:00:05.923196309Z",
"start_time": "2026-01-23T16:00:05.646226514Z"
}
},
"source": [
"pt.grad(last_x, sigma).eval({sigma: 1.0, rng: np.random.default_rng(3)})"
],
"outputs": [
{
"data": {
"text/plain": [
"array(3.32299952)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"cell_type": "code",
"id": "2a74247a-1946-4514-ac58-1de6e42ed8ca",
"metadata": {
"ExecuteTime": {
"end_time": "2026-01-23T16:00:05.944416676Z",
"start_time": "2026-01-23T16:00:05.935340994Z"
}
},
"source": [],
"outputs": [],
"execution_count": 4
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment