Last active
April 27, 2022 03:27
-
-
Save yknishidate/75b6c25acec759413342668def475bf5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import random | |
import matplotlib.pyplot as plt | |
class UniformDistribution:
    """Continuous uniform distribution on the interval spanned by ``a`` and ``b``."""

    def __init__(self, a: float, b: float) -> None:
        """Store the two endpoints of the interval.

        Args:
            a: One endpoint of the support.
            b: The other endpoint of the support.
        """
        self.a = a
        self.b = b

    def sample(self) -> float:
        """Draw one value via ``random.uniform(a, b)``."""
        return random.uniform(self.a, self.b)

    def pdf(self, x: float) -> float:
        """Return the probability density at ``x``.

        The density is ``1 / |b - a|`` inside the support and ``0.0``
        outside of it.

        Args:
            x: Point at which the density is evaluated.

        Returns:
            The density value at ``x``.

        Raises:
            ZeroDivisionError: If ``a == b`` and ``x`` is that single point;
                a zero-width interval has no finite density.
        """
        # random.uniform accepts the endpoints in either order, so normalise
        # them here to keep the density positive when a > b.
        lo, hi = min(self.a, self.b), max(self.a, self.b)
        if x < lo or x > hi:
            return 0.0
        return 1.0 / (hi - lo)
class SinDistribution:
    """Target density proportional to ``sin(x)`` on ``[0, pi]``.

    The density is unnormalized (``sin`` integrates to 2 over ``[0, pi]``)
    and only supports evaluation, not direct sampling.
    """

    def sample(self) -> None:
        """Always fail: this distribution cannot be sampled directly.

        Raises:
            AssertionError: Always.
        """
        # Raised explicitly instead of `assert False` so the failure is not
        # stripped when Python runs with -O.
        raise AssertionError('cannot sample from this distribution!')

    def pdf(self, x: float) -> float:
        """Return the unnormalized density at ``x``.

        Args:
            x: Point at which the density is evaluated.

        Returns:
            ``sin(x)`` for ``x`` in ``[0, pi]`` and ``0.0`` elsewhere.
        """
        # Outside [0, pi] sin(x) can be negative, which is not a valid
        # density value; treat that region as having no mass.
        if x < 0.0 or x > math.pi:
            return 0.0
        return math.sin(x)
class Reservoir:
    """Single-element weighted reservoir.

    Candidates are streamed in through :meth:`update`; each one replaces
    the stored element with probability ``w / weight_sum``.
    """

    def __init__(self) -> None:
        """Create an empty reservoir."""
        # Currently selected candidate; None until one has been accepted.
        self.y = None
        # Running total of the weights passed to update().
        self.weight_sum = 0.0

    def update(self, x, w) -> None:
        """Offer candidate ``x`` with weight ``w`` to the reservoir.

        Args:
            x: The candidate value.
            w: The candidate's resampling weight.
        """
        self.weight_sum += w
        # While the running total is not positive no candidate can be chosen,
        # and dividing by it would raise ZeroDivisionError (e.g. when the
        # first candidate has weight 0).
        if self.weight_sum <= 0.0:
            return
        if random.random() < w / self.weight_sum:
            self.y = x
def reservoir_sampling(num_candidates) -> float:
    """Pick one sample approximately distributed like the sine target.

    Draws ``num_candidates`` proposals uniformly from ``[0, pi]``, weights
    each by ``target.pdf(x) / source.pdf(x)`` and streams them through a
    ``Reservoir``.

    Args:
        num_candidates: Number of proposals to stream through the reservoir.

    Returns:
        The reservoir's selected value ``y``; this is still ``None`` if no
        candidate was ever accepted (e.g. ``num_candidates`` is 0).
    """
    proposal = UniformDistribution(0.0, math.pi)
    desired = SinDistribution()
    picker = Reservoir()

    for _draw in range(num_candidates):
        candidate = proposal.sample()
        # Importance weight: target density over proposal density.
        target_density = desired.pdf(candidate)
        proposal_density = proposal.pdf(candidate)
        picker.update(candidate, target_density / proposal_density)

    return picker.y
if __name__ == "__main__":
    # Demo: repeat the resampling many times and plot a histogram of the
    # selected values.
    num_candidates = 32
    num_trials = 50000
    samples = [reservoir_sampling(num_candidates) for _ in range(num_trials)]

    plt.hist(samples, bins=20)
    plt.title(f"Weighted Reservoir Sampling, M={num_candidates}")
    plt.show()
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
WRSを使ってSin分布からサンプリングしてみた
