Created
December 26, 2024 06:41
-
-
Save keisukefukuda/b2d265d5b1574e11ac10f3a69368dde1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from abc import ABC, abstractmethod\n", | |
"import random\n", | |
"import math\n", | |
"\n", | |
"\n", | |
"# 離散確率分布の抽象クラス(現在は1次元のみ)\n", | |
"class DiscreteProb(ABC):\n", | |
" @abstractmethod\n", | |
" def pmf(self, x: int):\n", | |
" # 確率質量関数\n", | |
" pass\n", | |
"\n", | |
" @abstractmethod\n", | |
" def __call__(self):\n", | |
" # 確率変数\n", | |
" pass\n", | |
"\n", | |
"\n", | |
"class Bern(DiscreteProb):\n", | |
" def __init__(self, mu: float):\n", | |
" self.mu = mu\n", | |
"\n", | |
" def pmf(self, x: int):\n", | |
" if x == 0:\n", | |
" return 1 - self.mu\n", | |
" elif x == 1:\n", | |
" return self.mu\n", | |
" else:\n", | |
" return 0\n", | |
"\n", | |
" def __call__(self):\n", | |
" return 1 if random.random() < self.mu else 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# サンプリングによる期待値計算とエントロピー計算\n", | |
"\n", | |
"\n", | |
"class SimpleSampler(object):\n", | |
" def __init__(self):\n", | |
" pass\n", | |
"\n", | |
" def __call__(self, p, L: int, f=None) -> float:\n", | |
" # 単純なサンプリングを用いて f(x) の期待値を計算する\n", | |
" # p: 確率分布\n", | |
" # f: 関数\n", | |
" # L: サンプル数\n", | |
" if f is None:\n", | |
" f = lambda x: x\n", | |
" return sum([f(p()) for _ in range(L)]) / L\n", | |
"\n", | |
"\n", | |
"class Entropy(object):\n", | |
" def __init__(self, sampler=SimpleSampler()):\n", | |
" self.sampler = sampler\n", | |
"\n", | |
" def __call__(self, p, n_samples: int = 10000) -> float:\n", | |
" return -1 * self.sampler(p, n_samples, lambda x: math.log(p.pmf(x)))\n", | |
"\n", | |
"\n", | |
"class KLDiv(object):\n", | |
" def __init__(self, sampler=SimpleSampler()):\n", | |
" self.sampler = sampler\n", | |
"\n", | |
" def __call__(self, *, p, q, L: int = 10000) -> float:\n", | |
" # p(X) と q(X) のKL距離\n", | |
" # KL[q(x) || p(x)]\n", | |
" # を計算する\n", | |
" return -1 * self.sampler(\n", | |
" q, f=lambda x: math.log(p.pmf(x) / q.pmf(x)), L=L\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"近似値:0.63642175\n", | |
"理論値:0.63651417\n" | |
] | |
} | |
], | |
"source": [ | |
"# ================================================\n", | |
"# 2.1\n", | |
"# ================================================\n", | |
"\n", | |
"\n", | |
"# 2.1.5 サンプリングによる期待値の近似計算\n", | |
"# 例題:単純な確率分布のエントロピー計算\n", | |
"\n", | |
"from math import log\n", | |
"\n", | |
"\n", | |
"class MyDist(DiscreteProb):\n", | |
" def pmf(self, x: int) -> float:\n", | |
" assert x in [0, 1]\n", | |
" if x == 0:\n", | |
" return 2.0 / 3\n", | |
" else:\n", | |
" return 1.0 / 3\n", | |
"\n", | |
" def __call__(self) -> int:\n", | |
" if random.random() < 2.0 / 3:\n", | |
" return 0\n", | |
" else:\n", | |
" return 1\n", | |
"\n", | |
"\n", | |
"p = MyDist()\n", | |
"H = Entropy()\n", | |
"print(f\"近似値:{H(p):.8f}\")\n", | |
"print(f\"理論値:{-(1/3 * log(1/3) + 2/3 * log(2/3)):.8f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0]\n", | |
"[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", | |
"E[Bern(x|0.5)] = 0.4968\n", | |
"E[Bern(x|0.9)] = 0.9048\n", | |
"H[Bern(x|0.5)] = 0.6931471805600546\n", | |
"H[Bern(x|0.9)] = 0.32662103059565345\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0, 0.5, 'H[Bern(x|μ)]')" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# ================================================\n", | |
"# 2.2 離散確率分布\n", | |
"# ================================================\n", | |
"\n", | |
"b1 = Bern(0.5)\n", | |
"print([b1() for _ in range(20)])\n", | |
"\n", | |
"b2 = Bern(0.9)\n", | |
"print([b2() for _ in range(20)])\n", | |
"\n", | |
"sampler = SimpleSampler()\n", | |
"print(f\"E[Bern(x|0.5)] = {sampler(b1, 10000)}\")\n", | |
"print(f\"E[Bern(x|0.9)] = {sampler(b2, 10000)}\")\n", | |
"\n", | |
"H = Entropy()\n", | |
"print(f\"H[Bern(x|0.5)] = {H(b1)}\")\n", | |
"print(f\"H[Bern(x|0.9)] = {H(b2)}\")\n", | |
"\n", | |
"from matplotlib import pyplot as plt\n", | |
"import numpy as np\n", | |
"\n", | |
"x = np.linspace(0, 1, 100)\n", | |
"y = [H(Bern(mu), 100000) for mu in x]\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(x, y)\n", | |
"ax.set_ylim(0, 0.7)\n", | |
"ax.set_xlim(0, 1.0)\n", | |
"ax.set_xlabel(\"μ\")\n", | |
"ax.set_ylabel(\"H[Bern(x|μ)]\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment