@algal
Created December 20, 2024 06:46
Shuffling in torch vs numpy-public.ipynb
{
"cells": [
{
"metadata": {},
"id": "3631be3b",
"cell_type": "markdown",
"source": "# Torch vs Numpy shuffling\n\nThe purpose of this notebook is\n\n1. to confirm Oskar's analysis of problems with `torch.randperm`, and\n\n2. to show that the problem can be fixed by using numpy's default random number generator (`np.random.default_rng()`) or, specifically, numpy's PCG64DXSM generator.\n\n\nOskar's analysis implied that `torch.randperm` does not shuffle elements uniformly when the number of elements becomes very large, e.g., a billion or more ([discord link](https://discord.com/channels/1200111522916094103/1276053252571533313/1278328850803064865)). We confirm this below.\n\nSpecifically, this notebook checks the following. Start with 3 billion elements, tag each one as belonging to one of ten \"deciles\" (the first 10% of elements, the second 10%, and so on), shuffle all the elements, and then look at the first 10,000 elements of the result. With a uniform shuffle, each decile should make up roughly 10% (about 1,000) of those first 10,000 elements. With `torch.randperm` it does not: the earlier elements (the lower deciles) are more likely to land in the first 10,000 positions of the shuffled collection. This corresponds to the bottom-left figure in Oskar's original post.\n\nWe then found that switching to numpy's default random number generator, or to numpy's PCG64DXSM generator, resolves this issue and seems to produce an approximately uniform distribution of deciles within the first 10,000 elements of the shuffled collection.\n\nWe noticed that OLMo encountered a similar issue and resolved it by using a specific numpy generator ([OLMo link](https://discord.com/channels/1200111522916094103/1276053252571533313/1278402521391562752)).\n"
},
{
"metadata": {},
"id": "48ff4441",
"cell_type": "markdown",
"source": "## How uniform is torch's shuffling with randperm?"
},
{
"metadata": {
"trusted": false
},
"id": "816fad78",
"cell_type": "markdown",
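"source": "Let's generate an array representing a random permutation of the integers from 0 to (3 * 10^9 - 1). Note that `torch.randperm` returns an `int64` tensor by default, so 3 billion elements take about 24 GB of memory."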
},
{
"metadata": {
"trusted": true
},
"id": "cd483264",
"cell_type": "code",
"source": "import torch\n\nshuffled = torch.randperm(3 * 10**9)",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "76553ee0",
"cell_type": "code",
"source": "shuffled_interval = shuffled[:10_000]",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "f7042107",
"cell_type": "code",
"source": "shuffled_interval[:50]",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "tensor([1994940797, 2347954103, 1691579500, 262720464, 2748298413, 2222743612,\n 1653995, 2342555504, 1440586107, 1077345365, 668539175, 2788525165,\n 117518658, 699722858, 1290958150, 2665873790, 2552246067, 2283160201,\n 177004620, 31564517, 2704208229, 1047164862, 25139448, 2216018010,\n 1004277474, 2440298876, 1240531966, 326584590, 943397255, 34365751,\n 150611451, 129402432, 2867900352, 1395077156, 256310869, 1292414480,\n 209935101, 440610241, 848544906, 407409817, 2578392564, 1809067203,\n 1297671095, 1108743574, 1086617589, 1632128034, 681494780, 1207082789,\n 1001392368, 879276196])"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "e776543d",
"cell_type": "code",
"source": "[x.item() for x in shuffled_interval[:10]]",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "[1994940797,\n 2347954103,\n 1691579500,\n 262720464,\n 2748298413,\n 2222743612,\n 1653995,\n 2342555504,\n 1440586107,\n 1077345365]"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "c97d3956",
"cell_type": "code",
"source": "shuffled.shape[0]",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "3000000000"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"id": "82671979",
"cell_type": "markdown",
"source": "Define a `decile` function that tells us which of the ten \"decile\" bins an element belonged to in the original, unshuffled collection: the first 10%, the second 10%, and so on."
},
{
"metadata": {
"trusted": false
},
"id": "d8528af6",
"cell_type": "code",
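"source": "def decile(index, collection_size):\n    \"Returns 1 to 10, for the decile of `index` within `collection_size`\"\n    # Integer arithmetic avoids float rounding: in Python, 1 // (2 / 10) == 4.0,\n    # so a float version would put index 1 of 2 in decile 5 instead of 6.\n    return 1 + (10 * index) // collection_size",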
"execution_count": 6,
"outputs": []
},
{
"metadata": {},
"id": "525e11ae",
"cell_type": "markdown",
"source": "Verify the above function works as expected:"
},
{
"metadata": {
"trusted": false
},
"id": "6e6ec9cd",
"cell_type": "code",
"source": "n = 20\n[decile(x,n) for x in range(n)]",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10]"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "bde78acd",
"cell_type": "code",
"source": "del n",
"execution_count": 8,
"outputs": []
},
{
"metadata": {},
"id": "de916fa2",
"cell_type": "markdown",
"source": "Compute the deciles for the first 10,000 shuffled values."
},
{
"metadata": {
"trusted": false
},
"id": "ba80ed86",
"cell_type": "code",
"source": "deciles = [decile(x.item(),shuffled.shape[0]) for x in shuffled_interval]",
"execution_count": 9,
"outputs": []
},
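{
"metadata": {},
"id": "vectorized-deciles-md",
"cell_type": "markdown",
"source": "The list comprehension above calls `.item()` once per element. As a minimal sketch (an addition, not part of the original analysis), the same deciles and their per-decile counts can be computed in a vectorized way with torch, using the integer formula `1 + (10 * index) // collection_size`. The names `size`, `deciles_t` and `counts_t` are introduced here for illustration only."
},
{
"metadata": {},
"id": "vectorized-deciles",
"cell_type": "code",
"source": "# Vectorized deciles: the integer formula applied to the whole tensor at once\nsize = shuffled.shape[0]\ndeciles_t = 1 + (shuffled_interval * 10) // size\n\n# Sanity check: agrees with the list built with `decile` above\nassert deciles_t.tolist() == deciles\n\n# Number of the first 10,000 shuffled elements that fall in each decile (1 to 10)\ncounts_t = torch.bincount(deciles_t, minlength=11)[1:]\ncounts_t",
"execution_count": null,
"outputs": []
},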
{
"metadata": {
"trusted": false
},
"id": "4845849e",
"cell_type": "markdown",
"source": "Let's graph a histogram of the values in deciles."
},
{
"metadata": {
"trusted": false
},
"id": "d983a9eb",
"cell_type": "code",
"source": "%matplotlib inline\nimport matplotlib.pyplot as plt\n\nplt.figure(figsize=(10, 6))\nplt.hist(deciles, bins=range(1, 12), align='left', rwidth=0.8)\nplt.xlabel('Decile')\nplt.ylabel('Frequency')\nplt.title('Histogram of Deciles, with torch.randperm')\nplt.xticks(range(1, 11))\nplt.show()",
"execution_count": 11,
"outputs": [
{
"data": {
"image/png": "",
"text/plain": "<Figure size 1000x600 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"metadata": {},
"id": "cd811ebd",
"cell_type": "markdown",
"source": "The histogram above shows that `torch.randperm` is not producing a uniform shuffle over a collection of 3 billion elements: the lower deciles are overrepresented among the first 10,000 shuffled elements."
},
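{
"metadata": {},
"id": "chi-square-check-md",
"cell_type": "markdown",
"source": "To put a number on \"not uniform\", here is a minimal sketch of Pearson's chi-square goodness-of-fit statistic for the decile counts (this check is an addition, not part of the original analysis). Under a uniform shuffle, each decile is expected to hold 10,000 / 10 = 1,000 of the first 10,000 elements. With 10 bins there are 9 degrees of freedom, and the 0.999 quantile of the chi-square distribution with 9 degrees of freedom is about 27.88, so a statistic far above that is strong evidence against uniformity. The same cell can be re-run after `deciles` is recomputed in each of the numpy sections below."
},
{
"metadata": {},
"id": "chi-square-check",
"cell_type": "code",
"source": "import numpy as np\n\n# Observed count per decile among the first 10,000 shuffled elements (deciles are 1..10)\ncounts = np.bincount(deciles, minlength=11)[1:]\n# Expected count per decile under a uniform shuffle\nexpected = len(deciles) / 10\n# Pearson's chi-square statistic (10 bins, so 9 degrees of freedom)\nchi2_stat = ((counts - expected) ** 2 / expected).sum()\nprint(\"counts:\", counts.tolist())\nprint(f\"chi-square statistic: {chi2_stat:.1f} (0.999 critical value with 9 dof is about 27.88)\")",
"execution_count": null,
"outputs": []
},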
{
"metadata": {},
"id": "ac04d4fe",
"cell_type": "markdown",
"source": "## Now let's try numpy with the PCG64DXSM generator."
},
{
"metadata": {
"trusted": false
},
"id": "da392490",
"cell_type": "markdown",
"source": "Let's generate an array representing a random permutation of the integers from 0 to (3 * 10^9 - 1) again, this time using numpy's `numpy.random.PCG64DXSM` bit generator."
},
{
"metadata": {
"trusted": false
},
"id": "359b2270",
"cell_type": "code",
"source": "import numpy as np\n\nrng = np.random.Generator(np.random.PCG64DXSM())\nshuffled = rng.permutation(3 * 10**9)\nshuffled = torch.from_numpy(shuffled)\nshuffled_interval = shuffled[:10_000]",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "12e73269",
"cell_type": "code",
"source": "deciles = [decile(x.item(),shuffled.shape[0]) for x in shuffled_interval]",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "c1e19f7c",
"cell_type": "code",
"source": "%matplotlib inline\nimport matplotlib.pyplot as plt\n\nplt.figure(figsize=(10, 6))\nplt.hist(deciles, bins=range(1, 12), align='left', rwidth=0.8)\nplt.xlabel('Decile')\nplt.ylabel('Frequency')\nplt.title('Histogram of Deciles with PCG64DXSM randomization')\nplt.xticks(range(1, 11))\nplt.show()",
"execution_count": 14,
"outputs": [
{
"data": {
"image/png": "",
"text/plain": "<Figure size 1000x600 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"metadata": {},
"id": "6c6ad989",
"cell_type": "markdown",
"source": "This is much closer to a uniform distribution."
},
{
"metadata": {},
"id": "3df62049",
"cell_type": "markdown",
"source": "## Now let's try numpy's default generator (`np.random.default_rng`)."
},
{
"metadata": {
"trusted": false
},
"id": "85f9db59",
"cell_type": "code",
"source": "import numpy as np\n\nrng = np.random.default_rng()\nshuffled = rng.permutation(3 * 10**9)\nshuffled = torch.from_numpy(shuffled)\nshuffled_interval = shuffled[:10_000]",
"execution_count": 15,
"outputs": []
},
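{
"metadata": {},
"id": "default-rng-bitgen-md",
"cell_type": "markdown",
"source": "Note that `np.random.default_rng()` is backed by the PCG64 bit generator, not PCG64DXSM, so this section tests a different generator from the previous one. A quick way to confirm which bit generator `rng` uses:"
},
{
"metadata": {},
"id": "default-rng-bitgen",
"cell_type": "code",
"source": "# Name of the bit generator class behind `rng`; expected to be 'PCG64' for default_rng()\ntype(rng.bit_generator).__name__",
"execution_count": null,
"outputs": []
},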
{
"metadata": {
"trusted": false
},
"id": "02839455",
"cell_type": "code",
"source": "deciles = [decile(x.item(),shuffled.shape[0]) for x in shuffled_interval]",
"execution_count": 16,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "52b2a6b4",
"cell_type": "code",
"source": "%matplotlib inline\nimport matplotlib.pyplot as plt\n\nplt.figure(figsize=(10, 6))\nplt.hist(deciles, bins=range(1, 12), align='left', rwidth=0.8)\nplt.xlabel('Decile')\nplt.ylabel('Frequency')\nplt.title('Histogram of Deciles with numpy default rng')\nplt.xticks(range(1, 11))\nplt.show()",
"execution_count": 17,
"outputs": [
{
"data": {
"image/png": "",
"text/plain": "<Figure size 1000x600 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"metadata": {},
"id": "feb5266b",
"cell_type": "markdown",
"source": "## Version check"
},
{
"metadata": {
"trusted": false
},
"id": "8ed42c35",
"cell_type": "code",
"source": "import torch\nprint(f\"PyTorch version: {torch.__version__}\")",
"execution_count": 18,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "PyTorch version: 2.4.0\n"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "34b45e3e",
"cell_type": "code",
"source": "print(f\"Numpy version: {np.__version__}\")",
"execution_count": 20,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Numpy version: 2.0.1\n"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "5d26a2d8",
"cell_type": "code",
"source": "import os\n\ncpu_count = os.cpu_count()\ntry:\n    # The first line of /proc/meminfo reports total RAM (Linux only)\n    with open(\"/proc/meminfo\", \"r\") as mem:\n        meminfo = next(mem)\nexcept (OSError, StopIteration):\n    meminfo = \"unknown\"\nprint(f\"CPU: {cpu_count} logical cores\")\nprint(meminfo)",
"execution_count": 26,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "CPU: 16 logical cores\nMemTotal: 131177156 kB\n\n"
}
]
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.12.0",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Shuffling in torch vs numpy-public.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}