-
-
Save mamonu/6845035b0705f91c29624fa07d642e6a to your computer and use it in GitHub Desktop.
Fast weighted sampling using the alias method in numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import numpy.random as npr, numpy as np\nfrom numba import jit", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "@jit(nopython=True)\ndef sample(n, q, J, r1, r2):\n res = np.zeros(n, dtype=np.int32)\n lj = len(J)\n for i in range(n):\n kk = int(np.floor(r1[i]*lj))\n if r2[i] < q[kk]: res[i] = kk\n else: res[i] = J[kk]\n return res\n\nclass AliasSample():\n def __init__(self, probs):\n self.K=K= len(probs)\n self.q=q= np.zeros(K)\n self.J=J= np.zeros(K, dtype=np.int)\n\n smaller,larger = [],[]\n for kk, prob in enumerate(probs):\n q[kk] = K*prob\n if q[kk] < 1.0: smaller.append(kk)\n else: larger.append(kk)\n\n while len(smaller) > 0 and len(larger) > 0:\n small,large = smaller.pop(),larger.pop()\n J[small] = large\n q[large] = q[large] - (1.0 - q[small])\n if q[large] < 1.0: smaller.append(large)\n else: larger.append(large)\n\n def draw_one(self):\n K,q,J = self.K,self.q,self.J\n kk = int(np.floor(npr.rand()*len(J)))\n if npr.rand() < q[kk]: return kk\n else: return J[kk]\n\n def draw_n(self, n):\n r1,r2 = npr.rand(n),npr.rand(n)\n return sample(n,self.q,self.J,r1,r2)", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# some weights to do weighted sampling by\nprs = npr.random(30000)\nprs/=prs.sum()", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "a = AliasSample(prs)", | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = a.draw_n(5000)", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "172 µs ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = np.random.choice(len(prs), 5000, p=prs)", | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "988 µs ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "cum_prs = prs.cumsum()", | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = np.searchsorted(cum_prs, np.random.random(5000))", | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "640 µs ± 7.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.4", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"toc": { | |
"threshold": 4, | |
"number_sections": true, | |
"toc_cell": false, | |
"toc_window_display": false, | |
"toc_section_display": "block", | |
"sideBar": true, | |
"navigate_menu": true, | |
"moveMenuLeft": true, | |
"widenNotebook": false, | |
"colors": { | |
"hover_highlight": "#DAA520", | |
"selected_highlight": "#FFD700", | |
"running_highlight": "#FF0000", | |
"wrapper_background": "#FFFFFF", | |
"sidebar_border": "#EEEEEE", | |
"navigate_text": "#333333", | |
"navigate_num": "#000000" | |
}, | |
"nav_menu": { | |
"width": "252px", | |
"height": "12px" | |
} | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "Fast alias sampling with Numba", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment