Created
July 14, 2020 20:20
-
-
Save DanielWeitzenfeld/61d797df025dbb632d5467483c049426 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import pandas as pd\n", | |
"import pymc3 as pm\n", | |
"from pymc3.distributions.dist_math import normal_lccdf\n", | |
"import numpy as np\n", | |
"import theano.tensor as T" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
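"Fitting a line whose valid interval `[lower, upper]` is itself a model parameter, following the PyMC Discourse thread [Fit interval as a model parameter](https://discourse.pymc.io/t/fit-interval-as-a-model-parameter/5453)." | |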
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def weird_function(x, gate1, gate2, A, B):\n", | |
"    '''Piecewise test signal that is linear on the closed interval [gate1, gate2].\n", | |
"\n", | |
"    - x < gate1: oscillates around B (amplitude B / 10).\n", | |
"    - gate1 <= x <= gate2: the line A * (x - gate1) + B.\n", | |
"    - x > gate2: continues from the end value of the line with a sine ripple\n", | |
"      and a quadratic fall-off.\n", | |
"\n", | |
"    The middle interval is closed so that points exactly on a gate are not\n", | |
"    rejected by all three masks (which would return 0 there). The pieces agree\n", | |
"    at both gates, so the curve is continuous. Works elementwise on arrays.\n", | |
"    '''\n", | |
"    in_line = (gate1 <= x) & (x <= gate2)\n", | |
"    return (\n", | |
"        in_line * (A * (x - gate1) + B)\n", | |
"        + (x < gate1) * B * (1 - np.sin((gate1 - x) * 3) / 10)\n", | |
"        + (x > gate2) * (B + (gate2 - gate1) * A + A / 5 * np.sin((x - gate2) * 10) - (x - gate2) ** 2 / 4)\n", | |
"    )\n", | |
"\n", | |
"# Ground-truth parameters for the simulated data. Note that the model cell\n", | |
"# later rebinds A and B to PyMC3 random variables.\n", | |
"A = 2\n", | |
"B = 3\n", | |
"gate1 = 3\n", | |
"gate2 = 7\n", | |
"xdata = np.linspace(0, 10, 50)\n", | |
"ydata = weird_function(xdata, gate1, gate2, A, B)\n", | |
"\n", | |
"# Seeded generator so Restart & Run All reproduces the same noisy observations.\n", | |
"SEED = 42\n", | |
"rng = np.random.default_rng(SEED)\n", | |
"yerror = rng.normal(scale=.2, size=len(ydata))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.collections.PathCollection at 0x117d29cc0>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"image/png": { | |
"height": 248, | |
"width": 369 | |
} | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Noisy observations on top of the noise-free signal they were generated from.\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.scatter(xdata, ydata + yerror, label='noisy observations')\n", | |
"ax.plot(xdata, ydata, color='k', lw=1, label='true signal')\n", | |
"ax.set(xlabel='x', ylabel='y', title='Simulated data')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class BoundedLine(pm.Continuous):\n", | |
"    '''Likelihood for observations that follow a line only inside [lower, upper].\n", | |
"\n", | |
"    Observation x[i] at covariate xdata[i] is scored with\n", | |
"    Normal(A * (xdata[i] - lower) + B, sigma) when lower <= xdata[i] <= upper,\n", | |
"    and with a broad background Normal(outside_mu, outside_sigma) otherwise.\n", | |
"    logp returns the sum over all points.\n", | |
"\n", | |
"    Parameters: A (slope), B (mean at xdata == lower), sigma (noise scale),\n", | |
"    lower / upper (interval end points), xdata (covariate values aligned with\n", | |
"    the observations). outside_mu / outside_sigma (keyword-only) set the\n", | |
"    background Normal and default to the previously hard-coded 5 and 30.\n", | |
"\n", | |
"    NOTE(review): which points count as on the line is a step function of\n", | |
"    lower and upper, so that choice contributes no gradient; upper enters logp\n", | |
"    only through it. This likely explains the NUTS max-treedepth / rhat\n", | |
"    warnings seen when sampling; a smooth (e.g. sigmoid) weighting would give\n", | |
"    NUTS a usable gradient.\n", | |
"    '''\n", | |
"\n", | |
"    def __init__(self, A, B, sigma, lower, upper, xdata, *args,\n", | |
"                 outside_mu=5, outside_sigma=30, **kwargs):\n", | |
"        super().__init__(*args, **kwargs)\n", | |
"        self.A = T.as_tensor_variable(A)\n", | |
"        self.B = T.as_tensor_variable(B)\n", | |
"        self.sigma = T.as_tensor_variable(sigma)\n", | |
"        self.lower = T.as_tensor_variable(lower)\n", | |
"        self.upper = T.as_tensor_variable(upper)\n", | |
"        self.xdata = T.as_tensor_variable(xdata)\n", | |
"        # Background distribution for points outside [lower, upper];\n", | |
"        # passed straight to pm.Normal.dist, which does its own conversion.\n", | |
"        self.outside_mu = outside_mu\n", | |
"        self.outside_sigma = outside_sigma\n", | |
"\n", | |
"    def logp(self, x):\n", | |
"        '''Sum of per-point log-densities of the observations x.'''\n", | |
"        in_line = T.and_(T.ge(self.xdata, self.lower), T.le(self.xdata, self.upper))\n", | |
"        line = pm.Normal.dist(mu=self.A * (self.xdata - self.lower) + self.B, sigma=self.sigma)\n", | |
"        outside = pm.Normal.dist(mu=self.outside_mu, sigma=self.outside_sigma)\n", | |
"        return T.sum(T.switch(in_line, line.logp(x), outside.logp(x)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Auto-assigning NUTS sampler...\n", | |
"Initializing NUTS using jitter+adapt_diag...\n", | |
"Multiprocess sampling (4 chains in 4 jobs)\n", | |
"NUTS: [upper, lower, sigma, B, A]\n", | |
"Sampling 4 chains, 0 divergences: 100%|██████████| 6000/6000 [04:43<00:00, 9.33draws/s]\n", | |
"The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.\n", | |
"The acceptance probability does not match the target. It is 0.8913020215112039, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
"The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.\n", | |
"The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.\n", | |
"The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.\n", | |
"The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n", | |
"The estimated number of effective samples is smaller than 200 for some parameters.\n" | |
] | |
} | |
], | |
"source": [ | |
"with pm.Model() as model:\n", | |
"    # Priors centred on the values used to simulate the data.\n", | |
"    A = pm.Normal('A', mu=2, sigma=5)\n", | |
"    B = pm.Normal('B', mu=3, sigma=5)\n", | |
"    sigma = pm.HalfNormal('sigma', sigma=.2)\n", | |
"\n", | |
"    # Interval end points. Nothing enforces lower < upper; an inverted draw\n", | |
"    # simply assigns every point to the background distribution.\n", | |
"    lower = pm.Normal('lower', mu=3, sigma=1, testval=3)\n", | |
"    upper = pm.Normal('upper', mu=7, sigma=1, testval=7)\n", | |
"\n", | |
"    y = BoundedLine('obs', A=A, B=B,\n", | |
"                    sigma=sigma,\n", | |
"                    lower=lower, upper=upper, xdata=xdata,\n", | |
"                    observed=ydata + yerror)\n", | |
"\n", | |
"    # Fixed seed so the posterior draws are reproducible across re-runs.\n", | |
"    trace = pm.sample(draws=1000, tune=500, random_seed=42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment