-
-
Save tboggs/8778945 to your computer and use it in GitHub Desktop.
| '''Functions for drawing contours of Dirichlet distributions. | |
| MIT License | |
| Copyright (c) 2014 Thomas Boggs | |
| Permission is hereby granted, free of charge, to any person obtaining a copy | |
| of this software and associated documentation files (the "Software"), to deal | |
| in the Software without restriction, including without limitation the rights | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| copies of the Software, and to permit persons to whom the Software is | |
| furnished to do so, subject to the following conditions: | |
| The above copyright notice and this permission notice shall be included in all | |
| copies or substantial portions of the Software. | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| SOFTWARE. | |
| ''' | |
| from __future__ import division, print_function | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.tri as tri | |
# Vertices of the reference triangle: unit-length base on the x axis,
# apex above its midpoint.
_corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
# Area of the reference triangle (half of base 1 times height 0.75**0.5).
_AREA = 0.5 * 1 * 0.75**0.5
# Triangulation built from the x and y columns of the corner array.
_triangle = tri.Triangulation(*_corners.T)

# For corner i, the two remaining corners in cyclic order (i + 1, i + 2).
_pairs = [_corners[[(i + 1) % 3, (i + 2) % 3]] for i in range(3)]
def tri_area(xy, pair):
    '''Returns the area of the triangle formed by `xy` and a pair of points.

    Arguments:

        `xy`: A length-2 sequence containing the x and y value of one vertex.

        `pair`: A 2x2 array-like whose rows are the other two vertices.
    '''
    # The 2-D cross product is written out explicitly because calling
    # np.cross on 2-component vectors is deprecated as of NumPy 2.0.
    offsets = np.asarray(pair, dtype=float) - np.asarray(xy, dtype=float)
    (ux, uy), (vx, vy) = offsets
    # abs() makes the result independent of the vertex ordering.
    return 0.5 * abs(ux * vy - uy * vx)
def xy2bc(xy, tol=1.e-4):
    '''Converts 2D Cartesian coordinates to barycentric.

    Arguments:

        `xy`: A length-2 sequence containing the x and y value.

        `tol` (float): Each coordinate is clipped to the interval
        [tol, 1 - tol] before the result is renormalized.

    Returns:

        A length-3 ndarray of barycentric coordinates that sums to one.
    '''
    # Each coordinate is the fraction of the reference triangle's area
    # covered by the sub-triangle opposite the corresponding corner.
    coords = np.array([tri_area(xy, p) for p in _pairs]) / _AREA
    clipped = np.clip(coords, tol, 1.0 - tol)
    # Clipping alone breaks the sum-to-one constraint (e.g. a corner maps
    # to (1 - tol, tol, tol), which sums to 1 + tol), so rescale afterwards.
    return clipped / clipped.sum()
class Dirichlet(object):
    '''Dirichlet distribution with a fixed parameter vector `alpha`.'''

    def __init__(self, alpha):
        '''Creates Dirichlet distribution with parameter `alpha`.

        Arguments:

            `alpha`: A sequence of positive parameters.

        Raises:

            ValueError: If any element of `alpha` is not positive.
        '''
        from math import exp, lgamma
        self._alpha = np.array(alpha)
        if np.any(self._alpha <= 0):
            raise ValueError('All elements of `alpha` must be positive.')
        # Normalizing constant gamma(sum(alpha)) / prod(gamma(alpha)).
        # It is evaluated in log space because math.gamma overflows for
        # arguments above roughly 171 even when the ratio is representable.
        log_coef = lgamma(np.sum(self._alpha)) - \
            sum(lgamma(a) for a in self._alpha)
        self._coef = exp(log_coef)

    def pdf(self, x):
        '''Returns pdf value for `x`.

        Arguments:

            `x`: A sequence whose entries are paired, in order, with the
            entries of `alpha`. Entries may be scalars (one point) or
            equal-length arrays (several points evaluated element-wise).
        '''
        terms = [xx ** (aa - 1) for (xx, aa) in zip(x, self._alpha)]
        return self._coef * np.prod(terms, axis=0)

    def sample(self, N):
        '''Generates a random sample of size `N`.'''
        return np.random.dirichlet(self._alpha, N)
def draw_pdf_contours(dist, border=False, nlevels=200, subdiv=8, **kwargs):
    '''Draws pdf contours over an equilateral triangle (2-simplex).

    Arguments:

        `dist`: A distribution instance with a `pdf` method.

        `border` (bool): If True, the simplex border is drawn.

        `nlevels` (int): Number of contours to draw.

        `subdiv` (int): Number of recursive mesh subdivisions to create.

        kwargs: Keyword args passed on to `plt.tricontourf`. If `cmap` is
        not given, it defaults to 'jet'.
    '''
    # Treat the colormap as an overridable default; passing it explicitly
    # alongside **kwargs would raise TypeError when the caller sets `cmap`.
    kwargs.setdefault('cmap', 'jet')

    refiner = tri.UniformTriRefiner(_triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    # Evaluate the density at every mesh node, in barycentric coordinates.
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    # Truthiness test so that values such as numpy.True_ are honoured too.
    if border:
        plt.triplot(_triangle, linewidth=1)
def plot_points(X, barycentric=True, border=True, **kwargs):
    '''Plots a set of points in the simplex.

    Arguments:

        `X` (ndarray): An Nx2 array (if in Cartesian coords) or Nx3 array
        (if in barycentric coords) of points to plot, one point per row.

        `barycentric` (bool): Indicates if `X` is in barycentric coords.

        `border` (bool): If True, the simplex border is drawn.

        kwargs: Keyword args passed on to `plt.plot`. If no marker size is
        given (`ms` or `markersize`), it defaults to 1.
    '''
    X = np.asarray(X)
    if barycentric:
        # Each row of barycentric weights mixes the triangle corners.
        X = X.dot(_corners)
    # Marker size is an overridable default; hard-coding ms=1 next to
    # **kwargs would raise TypeError when the caller supplies `ms`.
    if 'markersize' not in kwargs:
        kwargs.setdefault('ms', 1)
    plt.plot(X[:, 0], X[:, 1], 'k.', **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    # Truthiness test so that values such as numpy.True_ are honoured too.
    if border:
        plt.triplot(_triangle, linewidth=1)
if __name__ == '__main__':
    fig = plt.figure(figsize=(8, 6))
    alphas = [[0.999, 0.999, 0.999],
              [5, 5, 5],
              [2, 5, 15]]
    ncols = len(alphas)
    # One column per parameter vector: pdf contours in the top row and
    # 5000 random draws from the same distribution in the bottom row.
    for col, alpha in enumerate(alphas, 1):
        plt.subplot(2, ncols, col)
        dist = Dirichlet(alpha)
        draw_pdf_contours(dist)
        plt.title(r'$\alpha$ = (%.3f, %.3f, %.3f)' % tuple(alpha),
                  fontdict={'fontsize': 8})
        plt.subplot(2, ncols, col + ncols)
        plot_points(dist.sample(5000))
    plt.savefig('dirichlet_plots.png')
    print('Wrote plots to "dirichlet_plots.png".')
Very strange. I have matplotlib 1.3.1 and I get this error:
Traceback (most recent call last):
File "simplex_plots.py", line 106, in <module>
draw_pdf_contours(dist)
File "simplex_plots.py", line 60, in draw_pdf_contours
refiner = tri.trirefine.UniformTriRefiner(_triangle)
AttributeError: 'module' object has no attribute 'trirefine'
This is a late reply because - for some reason - I never received notification of your comment. Did you edit the code before running? I ask because the line referenced in your error (line 60) doesn't match what is in my code. Your line is
refiner = tri.trirefine.UniformTriRefiner(_triangle)
which has an extra trirefine submodule referenced. The line in my code is just
refiner = tri.UniformTriRefiner(_triangle)
Hi,
For those using Python 3. You should do:
from functools import reduce
Thank you for creating the script and helping me build more intuition for the Dirichlet Distribution :-)
It came to my attention that the function xy2bc was incorrect, which resulted in varying inaccuracy over the simplex. While it didn't appear to make a difference for the tolerance used, I've updated this gist with a corrected implementation that uses fractional triangle areas to compute the barycentric coordinates. I also made some minor edits to account for python and matplotlib API changes since the original post.
Really useful - thanks!
Your code is very nice, showing how to implement Dirichlet straight from the formula. I've also tried experimenting with it by calling the scipy.stats.dirichlet library instead. It worked well, but we needed to change the tolerance of the xy2bc function from 1e-4 to 1e-9. Otherwise the following check in the library code (_multivariate.py) won't let us run it:
if (np.abs(np.sum(x, 0) - 1.0) > 10e-10).any():
raise ValueError("The input vector 'x' must lie within the normal "
"simplex. but np.sum(x, 0) = %s." % np.sum(x, 0))

Requires
matplotlib (v1.3 or greater). To generate the plot shown above, simply run: