-
-
Save AlexMRuch/d2e4fc61cf93a6971dc62bb5e59fd43c to your computer and use it in GitHub Desktop.
Fast weighted sampling using the alias method in numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import numpy.random as npr, numpy as np\nfrom numba import jit", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "@jit(nopython=True)\ndef sample(n, q, J, r1, r2):\n res = np.zeros(n, dtype=np.int32)\n lj = len(J)\n for i in range(n):\n kk = int(np.floor(r1[i]*lj))\n if r2[i] < q[kk]: res[i] = kk\n else: res[i] = J[kk]\n return res\n\nclass AliasSample():\n def __init__(self, probs):\n self.K=K= len(probs)\n self.q=q= np.zeros(K)\n self.J=J= np.zeros(K, dtype=np.int)\n\n smaller,larger = [],[]\n for kk, prob in enumerate(probs):\n q[kk] = K*prob\n if q[kk] < 1.0: smaller.append(kk)\n else: larger.append(kk)\n\n while len(smaller) > 0 and len(larger) > 0:\n small,large = smaller.pop(),larger.pop()\n J[small] = large\n q[large] = q[large] - (1.0 - q[small])\n if q[large] < 1.0: smaller.append(large)\n else: larger.append(large)\n\n def draw_one(self):\n K,q,J = self.K,self.q,self.J\n kk = int(np.floor(npr.rand()*len(J)))\n if npr.rand() < q[kk]: return kk\n else: return J[kk]\n\n def draw_n(self, n):\n r1,r2 = npr.rand(n),npr.rand(n)\n return sample(n,self.q,self.J,r1,r2)", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# some weights to do weighted sampling by\nprs = npr.random(30000)\nprs/=prs.sum()", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "a = AliasSample(prs)", | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = a.draw_n(5000)", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "172 µs ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = np.random.choice(len(prs), 5000, p=prs)", | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "988 µs ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "cum_prs = prs.cumsum()", | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%timeit t = np.searchsorted(cum_prs, np.random.random(5000))", | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "640 µs ± 7.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.4", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"toc": { | |
"threshold": 4, | |
"number_sections": true, | |
"toc_cell": false, | |
"toc_window_display": false, | |
"toc_section_display": "block", | |
"sideBar": true, | |
"navigate_menu": true, | |
"moveMenuLeft": true, | |
"widenNotebook": false, | |
"colors": { | |
"hover_highlight": "#DAA520", | |
"selected_highlight": "#FFD700", | |
"running_highlight": "#FF0000", | |
"wrapper_background": "#FFFFFF", | |
"sidebar_border": "#EEEEEE", | |
"navigate_text": "#333333", | |
"navigate_num": "#000000" | |
}, | |
"nav_menu": { | |
"width": "252px", | |
"height": "12px" | |
} | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "Fast alias sampling with Numba", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment