Created
November 7, 2024 20:44
-
-
Save Per48edjes/3bc289815191c4348711d977e8fd85fb to your computer and use it in GitHub Desktop.
jane_street_puzzle_202411.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyNuALs/gGT6kVhEXnxzYael", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Per48edjes/3bc289815191c4348711d977e8fd85fb/jane_street_puzzle_202411.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Jane Street Puzzle (November 2024)\n", | |
"\n", | |
"" | |
], | |
"metadata": { | |
"id": "2zxXE9fHXHD6" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"from typing import Callable\n", | |
"\n", | |
"\n", | |
"def generate_point(is_blue: bool) -> np.array:\n", | |
"    \"\"\"Draw a uniformly random point in the unit square.\n", | |
"\n", | |
"    When ``is_blue`` is True, redraw until the point lies in the wedge\n", | |
"    x + y >= 1, x >= y: the triangle with vertices (0.5, 0.5), (1, 0) and\n", | |
"    (1, 1), whose closest side of the square is the right edge x = 1.\n", | |
"    \"\"\"\n", | |
"    while True:\n", | |
"        x, y = np.random.rand(), np.random.rand()\n", | |
"        outside_wedge = (2 * x) < (x + y) or (x + y) < 1\n", | |
"        if not (is_blue and outside_wedge):\n", | |
"            return np.array([x, y])\n", | |
"\n", | |
"def plot_decorator(func: Callable) -> Callable:\n", | |
"    \"\"\"Wrap a sampler so that every call also plots its points.\n", | |
"\n", | |
"    ``func`` must return ``(is_satisfied, c, r, b)``. The wrapper scatters r (red),\n", | |
"    b (blue) and c (black), joins c to r and to b with dashed lines when\n", | |
"    is_satisfied is truthy, shows the figure, and returns the same four values.\n", | |
"    \"\"\"\n", | |
"    def wrapper(*args, **kwargs):\n", | |
"        is_satisfied, c, r, b = func(*args, **kwargs)\n", | |
"\n", | |
"        fig, ax = plt.subplots(figsize=(6, 6))\n", | |
"        for point, colour, tag in ((r, 'red', 'r'), (b, 'blue', 'b'), (c, 'black', 'c')):\n", | |
"            ax.scatter(point[0], point[1], color=colour, label=tag)\n", | |
"        if is_satisfied:\n", | |
"            # Dashed segments joining c to r and to b\n", | |
"            for end, colour in ((r, 'red'), (b, 'blue')):\n", | |
"                ax.plot([end[0], c[0]], [end[1], c[1]], color=colour, linestyle='--')\n", | |
"        ax.set_xlim(0, 1)\n", | |
"        ax.set_ylim(0, 1)\n", | |
"        ax.set_xlabel('x')\n", | |
"        ax.set_ylabel('y')\n", | |
"        ax.set_title(f'Satisfied? {is_satisfied}')\n", | |
"        ax.legend()\n", | |
"        ax.grid(True)\n", | |
"        plt.show()\n", | |
"        return is_satisfied, c, r, b\n", | |
"\n", | |
"    return wrapper\n", | |
"\n", | |
"def check_satisfiability(b: np.ndarray, r: np.ndarray) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray]:\n", | |
"    \"\"\"Find the point c on the right edge (x = 1) that is equidistant from b and r.\n", | |
"\n", | |
"    Writing x = r - b, r' = c - r = (1 - r_x, r'_y) and b'_x = 1 - b_x, the\n", | |
"    condition |c - b|^2 = |c - r|^2 reduces to\n", | |
"        r'_y = (r'_x^2 - b'_x^2 - x_y^2) / (2 x_y).\n", | |
"\n", | |
"    Returns (is_satisfied, c, r, b), where is_satisfied is True iff c lies\n", | |
"    strictly inside the edge (0 < c_y < 1). Only the right edge is checked, so\n", | |
"    b is assumed to be closest to that side (as produced by generate_point(True)).\n", | |
"    \"\"\"\n", | |
"    x = r - b\n", | |
"    r_x_prime, b_x_prime = 1 - r[0], 1 - b[0]\n", | |
"    x_x, x_y = x[0], x[1]\n", | |
"\n", | |
"    if x_y == 0:\n", | |
"        # Vertical bisector: the formula above would divide by zero.\n", | |
"        if x_x == 0:\n", | |
"            # b == r: every point of the edge is equidistant; report the one level with r.\n", | |
"            c = np.array([1.0, r[1]])\n", | |
"            return 0 < c[1] < 1, c, r, b\n", | |
"        # The bisector x = (b_x + r_x) / 2 is parallel to the edge and, for distinct\n", | |
"        # x-coordinates in the unit square, lies strictly left of it: no solution.\n", | |
"        return False, np.array([1.0, np.inf]), r, b\n", | |
"\n", | |
"    r_y_prime = ((r_x_prime ** 2) - (b_x_prime ** 2) - (x_y ** 2)) / (2 * x_y)\n", | |
"    r_prime = np.array([r_x_prime, r_y_prime])\n", | |
"    c = r + r_prime\n", | |
"\n", | |
"    return 0 < c[1] < 1, c, r, b\n", | |
"\n", | |
"def generate_sample(b: np.array, r: np.array):\n", | |
"    \"\"\"Evaluate one (blue, red) pair and return (is_satisfied, c, r, b).\n", | |
"\n", | |
"    Kept as its own module-level name so it can be wrapped, e.g. by plot_decorator.\n", | |
"    \"\"\"\n", | |
"    return check_satisfiability(b, r)\n", | |
"\n", | |
"def simulation(n: int, plot: bool = False) -> float:\n", | |
"    \"\"\"Monte Carlo estimate of the probability that the condition holds.\n", | |
"\n", | |
"    Draws n pairs (blue point from the wedge, red point uniform in the square)\n", | |
"    and returns the fraction for which generate_sample reports success. With\n", | |
"    plot=True, each sample is also drawn via plot_decorator.\n", | |
"\n", | |
"    Raises ValueError if n is not positive.\n", | |
"    \"\"\"\n", | |
"    if n <= 0:\n", | |
"        raise ValueError(f'n must be positive, got {n}')\n", | |
"    # Wrap locally: rebinding the module-level generate_sample would make every later\n", | |
"    # caller plot too, and repeated plot=True calls would stack wrappers.\n", | |
"    sampler = plot_decorator(generate_sample) if plot else generate_sample\n", | |
"    total_satisfied = 0\n", | |
"    for _ in range(n):\n", | |
"        b, r = generate_point(True), generate_point(False)\n", | |
"        is_satisfied, _, _, _ = sampler(b, r)\n", | |
"        if is_satisfied:\n", | |
"            total_satisfied += 1\n", | |
"\n", | |
"    return total_satisfied / n" | |
], | |
"metadata": { | |
"id": "yFjI5ncyFn_u" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def empirical_sampling_distribution(M: int, N: int) -> None:\n", | |
"    \"\"\"Plot the sampling distribution of the Monte Carlo estimator.\n", | |
"\n", | |
"    Runs simulation(N) M times, shows a density histogram of the M estimates\n", | |
"    and prints their mean.\n", | |
"    \"\"\"\n", | |
"    sample_means = [simulation(N) for _ in range(M)]\n", | |
"\n", | |
"    fig, ax = plt.subplots(figsize=(10, 6))\n", | |
"    ax.hist(sample_means, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')\n", | |
"    ax.set_title(f'Empirical Sampling Distribution of the Sample Mean (M={M}, N={N})')\n", | |
"    ax.set_xlabel('Sample Mean')\n", | |
"    ax.set_ylabel('Density')\n", | |
"    ax.grid(True)\n", | |
"    plt.show()\n", | |
"\n", | |
"    print(f\"Mean of sample means: {np.mean(sample_means)}\")" | |
], | |
"metadata": { | |
"id": "qvBcoF85LQuV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# M independent repetitions of a Monte Carlo run with N samples each\n", | |
"# (N is reused by the wedge-averaging cell at the end of the notebook).\n", | |
"M = 1_000\n", | |
"N = 10_000\n", | |
"\n", | |
"empirical_sampling_distribution(M, N)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 581 | |
}, | |
"id": "4u6iNwtqDm5A", | |
"outputId": "0aaa76a4-f74b-4156-eeae-9088b653c28f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1000x600 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Mean of sample means: 0.4914361\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fixed blue point to inspect; generate_point(True) would give a random one from the wedge.\n", | |
"blue_point = np.array([.9, .5])\n", | |
"\n", | |
"# generate_sample only tests the right edge (x = 1), so that edge must be the side\n", | |
"# closest to the blue point. Distances are to the left, bottom, right and top sides.\n", | |
"distances = [blue_point[0], blue_point[1], 1 - blue_point[0], 1 - blue_point[1]]\n", | |
"assert distances[2] == min(distances), 'blue_point must be closest to the right edge (x = 1)'\n", | |
"\n", | |
"# Test every point of a regular grid as a candidate red point\n", | |
"grid_size = 1000\n", | |
"x = np.linspace(0, 1, grid_size)\n", | |
"y = np.linspace(0, 1, grid_size)\n", | |
"X, Y = np.meshgrid(x, y)\n", | |
"\n", | |
"red_points = []\n", | |
"for red_point in np.column_stack([X.ravel(), Y.ravel()]):\n", | |
"    if generate_sample(blue_point, red_point)[0]:\n", | |
"        red_points.append(red_point)\n", | |
"# reshape keeps a (0, 2) array when nothing qualifies, so the column slices below still work\n", | |
"red_points = np.array(red_points).reshape(-1, 2)\n", | |
"\n", | |
"# Plot\n", | |
"fig, ax = plt.subplots(figsize=(6, 6))\n", | |
"ax.scatter(red_points[:, 0], red_points[:, 1], color='red', s=1, label='Possible Red Points')\n", | |
"ax.scatter(blue_point[0], blue_point[1], color='blue', label='Blue Point', s=10)\n", | |
"ax.set_xlim(0, 1)\n", | |
"ax.set_ylim(0, 1)\n", | |
"ax.set_xlabel(\"X\")\n", | |
"ax.set_ylabel(\"Y\")\n", | |
"ax.set_title('Red points with an equidistant point on the right edge')\n", | |
"ax.legend()\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 547 | |
}, | |
"id": "f-MldlgGP2r3", | |
"outputId": "3157cb56-262c-4eb3-cc85-9bf9cfc00c41" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 600x600 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
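"For a blue point $\\vec{b}$ in the wedge (so its closest side is $x = 1$), the function below gives the probability that a uniformly random red point $\\vec{r}$ satisfies the problem's condition, i.e. that some point on the side $x = 1$ is equidistant from $\\vec{b}$ and $\\vec{r}$:\n", | |
"\n", | |
"$$\n", | |
"f(b_x, b_y) = \\frac{\\pi r_0^2}{4} + \\frac{\\pi r_1^2}{4} - 2 \\int_0^{1 - b_x} \\left( \\sqrt{r_0^2 - x^2} + \\sqrt{r_1^2 - x^2} - 1 \\right) dx\n", | |
"$$\n", | |
"\n", | |
"where $r_{0} = \\sqrt{(1-b_{x})^{2} + b_{y}^{2}}$ and $r_{1} = \\sqrt{(1-b_{x})^{2} + (1-b_{y})^{2}}$ are the distances from $\\vec{b}$ to the corners $(1, 0)$ and $(1, 1)$.\n", | |
"\n", | |
"Why: the condition holds exactly when $\\vec{r}$ lies in one, but not both, of the quarter-disks of radius $r_0$ centred at $(1, 0)$ and radius $r_1$ centred at $(1, 1)$. Both radii are at most 1, so each quarter-disk lies inside the square. The integral is the area of their overlap, with $x$ measuring horizontal distance from the side $x = 1$.\n", | |
"\n", | |
"What we're looking for is the average of $f$ over the \"wedge\" of the unit square given by $1 < b_{x} + b_{y} < 2b_{x}$ (by symmetry, the other three wedges give the same answer).\n" | |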
], | |
"metadata": { | |
"id": "wyb-1lqriQVS" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Below, I estimate the final probability by drawing a uniform sample $T$ of blue points from the \"wedge\" and computing the sample mean $\\frac{1}{|T|} \\sum_{(b_x, b_y) \\in T} f(b_x,b_y)$, which converges to the average of $f$ over the wedge." | |
], | |
"metadata": { | |
"id": "jaG5f80EvOJV" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from mpmath import mp, quad\n", | |
"\n", | |
"# Working precision (decimal digits) for mpmath arithmetic and quadrature\n", | |
"mp.dps = 20\n", | |
"\n", | |
"def f(b_x, b_y):\n", | |
"    \"\"\"Probability that a uniform red point satisfies the condition, given blue point (b_x, b_y).\n", | |
"\n", | |
"    Computed as the areas of two quarter-disks centred at the corners (1, 0) and\n", | |
"    (1, 1), with radii equal to the distances from b to those corners, minus twice\n", | |
"    the area of their overlap. Assumes b lies in the wedge, so both radii are <= 1.\n", | |
"    \"\"\"\n", | |
"    gap = 1 - b_x  # distance from b to the right edge x = 1\n", | |
"    r0 = mp.sqrt(gap**2 + b_y**2)\n", | |
"    r1 = mp.sqrt(gap**2 + (1 - b_y)**2)\n", | |
"\n", | |
"    def lens_height(t):\n", | |
"        # Vertical extent of the overlap at horizontal distance t from the edge\n", | |
"        return mp.sqrt(r0**2 - t**2) + mp.sqrt(r1**2 - t**2) - 1\n", | |
"\n", | |
"    quarter_disks = (mp.pi * r0**2 / 4) + (mp.pi * r1**2 / 4)\n", | |
"    lens_area = quad(lens_height, [0, gap])\n", | |
"\n", | |
"    return quarter_disks - 2 * lens_area\n" | |
], | |
"metadata": { | |
"id": "eiOxne9Eka1K" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Average f over a uniform sample of blue points from the wedge\n", | |
"# (N is set in the cell that calls empirical_sampling_distribution).\n", | |
"T = [generate_point(True) for _ in range(N)]\n", | |
"p = sum(f(b_x, b_y) for b_x, b_y in T) / len(T)\n", | |
"\n", | |
"print(f\"Empirical total probability over the region ({N} samples):\", p)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "DUm5Ly3EwFe2", | |
"outputId": "a2176a0d-4e92-4807-e560-242a5903667c" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Empirical total probability over the region (10000 samples): 0.49179896501722237376\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment