Last active
February 16, 2023 23:59
-
-
Save iancze/94bdb59f83df51102f8a238b248611c9 to your computer and use it in GitHub Desktop.
Pyro: using latent variables with a plate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import numpy as np | |
import pyro | |
import pyro.distributions as dist | |
from torch.distributions import constraints | |
from pyro.nn import PyroModule, PyroParam, PyroSample | |
from pyro.infer import Predictive, MCMC, NUTS | |
from pyro.infer.autoguide import AutoDiagonalNormal | |
from pyro.infer import SVI, Trace_ELBO, Predictive | |
def gaussian(x, A_i, mu_i, sigma_i=0.5):
    r"""
    Evaluate a Gaussian profile of the form

    .. math::

        f(x) = A_i \exp \left(- \frac{(x - \mu_i)^2}{2 \sigma_i^2} \right)

    Parameters
    ----------
    x : torch.Tensor or array-like
        Points at which to evaluate the profile. Non-tensor input is
        converted with ``torch.as_tensor`` so that ``torch.exp`` accepts it.
    A_i : float or torch.Tensor
        Peak amplitude, i.e. the value of the profile at ``x == mu_i``.
    mu_i : float or torch.Tensor
        Centre of the profile.
    sigma_i : float, optional
        Standard deviation (width) of the profile. Default 0.5.

    Returns
    -------
    torch.Tensor
        ``A_i * exp(-(x - mu_i)**2 / (2 * sigma_i**2))`` with the usual torch
        broadcasting applied between the arguments.
    """
    # torch.exp rejects Python floats / numpy arrays; a tensor passes through
    # unchanged (autograd history preserved).
    x = torch.as_tensor(x)
    return A_i * torch.exp(-((x - mu_i) ** 2) / (2 * sigma_i**2))
class Spectrum(PyroModule):
    """
    Pyro model of a spectrum: Gaussian lines at fixed centres plus a constant
    baseline, observed with Normal noise.

    Latent sample sites
    -------------------
    log_amplitudes
        Base-10 log amplitude of each line, one independent draw per line
        (sampled inside a ``pyro.plate`` of size ``nlines``).
    baseline
        Scalar offset added to every point.

    Parameters
    ----------
    mus : array-like
        Line centres; stored as ``torch.as_tensor(mus)``.
    """

    def __init__(self, mus):
        super().__init__()
        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)
        # Priors are declared as PyroSample attributes rather than calling
        # pyro.sample here: a pyro.sample in __init__ executes once, outside
        # any model trace, leaving a fixed tensor that the guide never sees.
        # A PyroSample is drawn as a sample site each time it is accessed
        # while the model runs. The per-line batch shape comes from the plate
        # entered in intensities(), not from a plate around this assignment.
        self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))
        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def intensities(self, x):
        """
        Return the summed line profiles (without baseline) evaluated at ``x``.

        Each line ``i`` contributes ``gaussian(x, 10 ** log_amplitudes[i],
        mus[i])`` with the default width of ``gaussian``.
        """
        # Access the PyroSample attribute exactly once, inside the plate, so
        # the site gets batch shape (nlines,) and the same draw is reused for
        # every line below.
        with pyro.plate("lines", self.nlines):
            log_amplitudes = self.log_amplitudes
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, log_amplitudes[i])
            I = I + gaussian(x, A_i, self.mus[i])
        return I

    def forward(self, x, y, yerr):
        """
        Run the model: condition the site ``obs`` on ``y``.

        Parameters
        ----------
        x : torch.Tensor
            Evaluation points.
        y : torch.Tensor
            Observed values, one ``obs`` element per entry.
        yerr : float or torch.Tensor
            Standard deviation of the Normal observation noise.

        Returns
        -------
        torch.Tensor
            The noiseless model intensity (lines + baseline) at ``x``.
        """
        I = self.intensities(x) + self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)
        return I
if __name__ == "__main__":
    np.random.seed(123)

    # ---- Synthetic data -------------------------------------------------
    # Three Gaussian lines on a flat baseline, observed at N sorted uniform
    # positions in [0, 10) with white noise of standard deviation yerr.
    N = 80
    xs = torch.as_tensor(np.sort(np.random.uniform(0, 10, size=N)))
    true_mus = np.array([2.0, 4.5, 7.4])
    true_amplitudes = np.array([0.4, 1.0, 0.6])
    true_log_amplitudes = np.log10(true_amplitudes)
    true_baseline = 0.5
    yerr = 0.05

    ys = torch.zeros_like(xs)
    for amplitude, centre in zip(true_amplitudes, true_mus):
        ys += gaussian(xs, amplitude, centre)
    ys += true_baseline
    noise = np.random.normal(loc=0, scale=yerr, size=N)
    ys += torch.as_tensor(noise)

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(nrows=1)
    ax.plot(xs.numpy(), ys.numpy(), "o")
    fig.savefig("data.png")

    # ---- Variational fit ------------------------------------------------
    model = Spectrum(mus=true_mus)
    guide = AutoDiagonalNormal(model)
    adam = pyro.optim.Adam({"lr": 0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    num_iterations = 10
    pyro.clear_param_store()
    loss_tracker = np.empty(num_iterations)
    for step in range(num_iterations):
        # svi.step returns the loss for this gradient step
        loss_tracker[step] = svi.step(xs, ys, yerr)
        print(step)

    # ---- Report which sites a posterior-predictive draw contains ---------
    predictive = Predictive(model, guide=guide, num_samples=1)(xs, ys, yerr)
    for site_name, site_samples in predictive.items():
        print(f"{site_name}: {site_samples.shape}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import numpy as np | |
import pyro | |
import pyro.distributions as dist | |
from torch.distributions import constraints | |
from pyro.nn import PyroModule, PyroParam, PyroSample | |
from pyro.infer import Predictive, MCMC, NUTS | |
from pyro.infer.autoguide import AutoNormal | |
from pyro.infer import SVI, Trace_ELBO, Predictive | |
import numpy as np | |
def gaussian(x, A_i, mu_i, sigma_i=0.5):
    r"""
    Evaluate a Gaussian profile of the form

    .. math::

        f(x) = A_i \exp \left(- \frac{(x - \mu_i)^2}{2 \sigma_i^2} \right)

    Parameters
    ----------
    x : torch.Tensor or array-like
        Points at which to evaluate the profile. Non-tensor input is
        converted with ``torch.as_tensor`` so that ``torch.exp`` accepts it.
    A_i : float or torch.Tensor
        Peak amplitude, i.e. the value of the profile at ``x == mu_i``.
    mu_i : float or torch.Tensor
        Centre of the profile.
    sigma_i : float, optional
        Standard deviation (width) of the profile. Default 0.5.

    Returns
    -------
    torch.Tensor
        ``A_i * exp(-(x - mu_i)**2 / (2 * sigma_i**2))`` with the usual torch
        broadcasting applied between the arguments.
    """
    # torch.exp rejects Python floats / numpy arrays; a tensor passes through
    # unchanged (autograd history preserved).
    x = torch.as_tensor(x)
    return A_i * torch.exp(-((x - mu_i) ** 2) / (2 * sigma_i**2))
np.random.seed(123)

# Synthetic spectrum: three Gaussian lines on a constant baseline, sampled at
# N sorted uniform positions in [0, 10) and perturbed by white noise whose
# standard deviation is yerr. xs, ys, yerr and true_mus are used further down.
N = 80
xs = torch.as_tensor(np.sort(np.random.uniform(0, 10, size=N)))
true_mus = np.array([2.0, 4.5, 7.4])
true_amplitudes = np.array([0.4, 1.0, 0.6])
true_log_amplitudes = np.log10(true_amplitudes)
true_baseline = 0.5
yerr = 0.05

ys = torch.zeros_like(xs)
for amplitude, centre in zip(true_amplitudes, true_mus):
    ys += gaussian(xs, amplitude, centre)
ys += true_baseline
noise = np.random.normal(loc=0, scale=yerr, size=N)
ys += torch.as_tensor(noise)

# Quick-look plot of the simulated data.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=1)
ax.plot(xs.numpy(), ys.numpy(), "o")
fig.savefig("data.png")
# define the model | |
def model_func(x, y, yerr, mus=None):
    """
    Pyro model: Gaussian lines at fixed centres on a constant baseline.

    Sample sites: ``baseline`` (scalar), ``log_amplitudes`` (one base-10 log
    amplitude per line, inside the plate ``"plate"``) and the observed
    ``obs``.

    Parameters
    ----------
    x : torch.Tensor
        Evaluation points.
    y : torch.Tensor
        Observed values; one ``obs`` element per entry.
    yerr : float or torch.Tensor
        Standard deviation of the Normal observation noise.
    mus : array-like, optional
        Line centres. Defaults to the module-level ``true_mus`` (looked up at
        call time), so existing ``model_func(x, y, yerr)`` calls are unchanged.
    """
    if mus is None:
        mus = true_mus
    # The plate size follows the number of centres instead of a hard-coded 3
    # that had to be kept in sync with true_mus by hand.
    nlines = len(mus)

    baseline = pyro.sample("baseline", dist.Normal(0.0, 1.0))
    with pyro.plate("plate", nlines):
        log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

    I = torch.zeros_like(x)
    for i in range(nlines):
        A_i = torch.pow(10.0, log_amplitudes[i])
        I += gaussian(x, A_i, mus[i])
    I += baseline

    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.Normal(I, yerr), obs=y)
# Fit model_func with stochastic variational inference, using an
# automatically generated AutoNormal guide.
guide = AutoNormal(model_func)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model_func, guide, adam, loss=Trace_ELBO())

num_iterations = 1000
pyro.clear_param_store()
loss_tracker = np.empty(num_iterations)
for step in range(num_iterations):
    # svi.step takes one gradient step and returns its loss
    loss_tracker[step] = svi.step(xs, ys, yerr)
    print(step)

# Draw one posterior-predictive sample and report the shape of every site.
predictive = Predictive(model_func, guide=guide, num_samples=1)(xs, ys, yerr)
for site_name, site_samples in predictive.items():
    print(f"{site_name}: {site_samples.shape}")

# Inspect the fitted guide parameters, following
# https://forum.pyro.ai/t/how-to-access-guide-parameters/3995
print(list(guide.parameters()))
with pyro.poutine.trace(param_only=True) as tr:
    guide(xs, ys, yerr)
constrained_params = [node["value"] for node in tr.trace.nodes.values()]
# map each recorded (constrained) value back to its unconstrained tensor
PARAMS = [value.unconstrained() for value in constrained_params]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment