Keywords: supervised learning | semi-supervised learning | unsupervised learning | mixture model | gaussian mixture model | pymc3 | label-switching | identifiability | normal distribution | pymc3 potentials
Contents
- Mixture of 2 Gaussians, the old faithful data
- A tetchy 3 Gaussian Model
- Problems with clusters and sampling
- Full sampling is horrible, even with potentials
We now study learning mixture models with MCMC. We have already done this in the case of the Zero-Inflated Poisson model, and we will stick to Gaussian mixture models for now.
%matplotlib inline
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
pd.set_option('display.width', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.notebook_repr_html', True)
import seaborn as sns
sns.set_style("whitegrid")
sns.set_context("poster")
import pymc3 as pm
import theano.tensor as tt
Mixture of 2 Gaussians, the old faithful data
We start by considering waiting times from the Old-Faithful Geyser at Yellowstone National Park.
ofdata=pd.read_csv("data/oldfaithful.csv")
ofdata.head()
 | eruptions | waiting |
---|---|---|
0 | 3.600 | 79 |
1 | 1.800 | 54 |
2 | 3.333 | 74 |
3 | 2.283 | 62 |
4 | 4.533 | 85 |
sns.distplot(ofdata.waiting);
Visually, there seem to be two components to the waiting time, so let us model this using a mixture of two Gaussians. Remember that this is an unsupervised model: all we are doing is modelling $p(x)$, under the assumption that there are two clusters and a hidden variable $z$ that indexes them.
Notice that these Gaussians seem well separated. The separation of the Gaussians affects how well your sampler performs.
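Concretely, with mixing weight $p_1$ (and $p_2 = 1 - p_1$), the density we are modelling is

$$p(x) = \sum_{k=1}^{2} p_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2).$$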
with pm.Model() as ofmodel:
    p1 = pm.Uniform('p', 0, 1)
    p2 = 1 - p1
    p = tt.stack([p1, p2])
    assignment = pm.Categorical("assignment", p,
                                shape=ofdata.shape[0])
    sds = pm.Uniform("sds", 0, 40, shape=2)
    centers = pm.Normal("centers",
                        mu=np.array([50, 80]),
                        sd=np.array([20, 20]),
                        shape=2)
    # and to combine it with the observations:
    observations = pm.Normal("obs", mu=centers[assignment], sd=sds[assignment], observed=ofdata.waiting)
with ofmodel:
    #step1 = pm.Metropolis(vars=[p, sds, centers])
    #step2 = pm.CategoricalGibbsMetropolis(vars=[assignment])
    #oftrace_full = pm.sample(10000, step=[step1, step2])
    oftrace_full = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [centers, sds, p]
>BinaryGibbsMetropolis: [assignment]
Sampling 2 chains: 100%|██████████| 21000/21000 [04:43<00:00, 73.96draws/s]
pm.model_to_graphviz(ofmodel)
oftrace = oftrace_full[2000::5]
pm.traceplot(oftrace);
pm.summary(oftrace)
 | mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat |
---|---|---|---|---|---|---|---|
assignment__0 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__1 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__2 | 0.993125 | 0.082630 | 0.001504 | 1.000000 | 1.000000 | 3248.753445 | 0.999745 |
assignment__3 | 0.050937 | 0.219870 | 0.004231 | 0.000000 | 1.000000 | 3173.388070 | 1.000029 |
assignment__4 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__5 | 0.000313 | 0.017675 | 0.000311 | 0.000000 | 0.000000 | NaN | 1.000000 |
assignment__6 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__7 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__8 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__9 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__10 | 0.000625 | 0.024992 | 0.000437 | 0.000000 | 0.000000 | 3208.032629 | 0.999687 |
assignment__11 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__12 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__13 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__14 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__15 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__16 | 0.046875 | 0.211371 | 0.003711 | 0.000000 | 0.000000 | 3205.553814 | 0.999696 |
assignment__17 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__18 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__19 | 0.999375 | 0.024992 | 0.000437 | 1.000000 | 1.000000 | 3208.032629 | 0.999687 |
assignment__20 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__21 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__22 | 0.999062 | 0.030604 | 0.000533 | 1.000000 | 1.000000 | 3209.379383 | 0.999792 |
assignment__23 | 0.817500 | 0.386256 | 0.006151 | 0.000000 | 1.000000 | 3120.023664 | 0.999729 |
assignment__24 | 0.993125 | 0.082630 | 0.001629 | 1.000000 | 1.000000 | 3235.543269 | 1.001749 |
assignment__25 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__26 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__27 | 0.997188 | 0.052958 | 0.001177 | 1.000000 | 1.000000 | 3222.014229 | 0.999722 |
assignment__28 | 0.998437 | 0.039498 | 0.000681 | 1.000000 | 1.000000 | 3213.694327 | 0.999750 |
assignment__29 | 0.998750 | 0.035333 | 0.000612 | 1.000000 | 1.000000 | 3212.072926 | 0.999687 |
... | ... | ... | ... | ... | ... | ... | ... |
assignment__247 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__248 | 0.576562 | 0.494103 | 0.008319 | 0.000000 | 1.000000 | 2947.001300 | 1.000571 |
assignment__249 | 0.995938 | 0.063608 | 0.001140 | 1.000000 | 1.000000 | 2794.058393 | 0.999712 |
assignment__250 | 0.000625 | 0.024992 | 0.000437 | 0.000000 | 0.000000 | NaN | 1.000313 |
assignment__251 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__252 | 0.983750 | 0.126436 | 0.002048 | 1.000000 | 1.000000 | 3312.296435 | 0.999907 |
assignment__253 | 0.983125 | 0.128803 | 0.001897 | 1.000000 | 1.000000 | 3192.025887 | 0.999899 |
assignment__254 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__255 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__256 | 0.949375 | 0.219231 | 0.004141 | 0.000000 | 1.000000 | 2517.301618 | 1.000086 |
assignment__257 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__258 | 0.000625 | 0.024992 | 0.000437 | 0.000000 | 0.000000 | NaN | 1.000313 |
assignment__259 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__260 | 0.999687 | 0.017675 | 0.000311 | 1.000000 | 1.000000 | NaN | 1.000000 |
assignment__261 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__262 | 0.003438 | 0.058529 | 0.000978 | 0.000000 | 0.000000 | 3224.662207 | 0.999944 |
assignment__263 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__264 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__265 | 0.013750 | 0.116451 | 0.002088 | 0.000000 | 0.000000 | 3296.006530 | 0.999716 |
assignment__266 | 0.995938 | 0.063608 | 0.001140 | 1.000000 | 1.000000 | 2823.618545 | 1.000870 |
assignment__267 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__268 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__269 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | NaN | NaN |
assignment__270 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | NaN | NaN |
assignment__271 | 0.993125 | 0.082630 | 0.001295 | 1.000000 | 1.000000 | 3247.614839 | 0.999916 |
centers__0 | 54.627527 | 0.733327 | 0.011941 | 53.263593 | 56.069912 | 2857.383302 | 1.000063 |
centers__1 | 80.087834 | 0.523151 | 0.008877 | 79.104464 | 81.134251 | 3013.404106 | 0.999852 |
p | 0.361858 | 0.030777 | 0.000567 | 0.302382 | 0.419931 | 2753.669629 | 1.000464 |
sds__0 | 6.018165 | 0.585472 | 0.010211 | 4.961674 | 7.231991 | 2942.626870 | 0.999759 |
sds__1 | 5.947338 | 0.416616 | 0.008702 | 5.151874 | 6.773165 | 2566.297885 | 0.999898 |
277 rows × 7 columns
pm.autocorrplot(oftrace, varnames=['p', 'centers', 'sds']);
oftrace['centers'].mean(axis=0)
array([54.62752749, 80.0878344 ])
We can visualize the two clusters, each suitably scaled by its category-membership probability, using the posterior means. Note that this misses any smearing that would go into making the posterior predictive.
from scipy.stats import norm
x = np.linspace(20, 120, 500)
# pick consistent colors for the two clusters
colors = ["#348ABD", "#A60628"] if oftrace['centers'][-1, 0] > oftrace['centers'][-1, 1] \
else ["#A60628", "#348ABD"]
posterior_center_means = oftrace['centers'].mean(axis=0)
posterior_std_means = oftrace['sds'].mean(axis=0)
posterior_p_mean = oftrace["p"].mean()
plt.hist(ofdata.waiting, bins=20, histtype="step", normed=True, color="k",
lw=2, label="histogram of data")
y = posterior_p_mean * norm.pdf(x, loc=posterior_center_means[0],
scale=posterior_std_means[0])
plt.plot(x, y, label="Cluster 0 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[1], alpha=0.3)
y = (1 - posterior_p_mean) * norm.pdf(x, loc=posterior_center_means[1],
scale=posterior_std_means[1])
plt.plot(x, y, label="Cluster 1 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[0], alpha=0.3)
plt.legend(loc="upper left")
plt.title("Visualizing Clusters using posterior-mean parameters");
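To capture that smearing, one can instead average the mixture density over the posterior draws. A minimal sketch, reusing `x`, `norm`, and the `oftrace` arrays from above:

ppd = np.zeros_like(x)
# average the two-component density over posterior draws, so posterior
# uncertainty "smears" the curve instead of using a single plug-in estimate
for p_s, c_s, s_s in zip(oftrace['p'], oftrace['centers'], oftrace['sds']):
    ppd += p_s * norm.pdf(x, loc=c_s[0], scale=s_s[0]) \
         + (1 - p_s) * norm.pdf(x, loc=c_s[1], scale=s_s[1])
ppd /= len(oftrace['p'])
plt.hist(ofdata.waiting, bins=20, histtype="step", density=True, color="k", lw=2)
plt.plot(x, ppd, lw=3, label="posterior-averaged mixture density")
plt.legend();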
A tetchy 3 Gaussian Model
Let us set up our data. Our analysis here follows that of https://colindcarroll.com/2018/07/20/why-im-excited-about-pymc3-v3.5.0/, and we have chosen 3 Gaussians reasonably close to each other to show the problems that arise!
mu_true = np.array([-2, 0, 2])
sigma_true = np.array([1, 1, 1])
lambda_true = np.array([1/3, 1/3, 1/3])
n = 100
from scipy.stats import multinomial
# Simulate the cluster assignments z according to the mixing proportions lambda_true
z = multinomial.rvs(1, lambda_true, size=n)
# and draw each observation from the Gaussian of its assigned cluster
data = np.array([np.random.normal(mu_true[i.astype('bool')][0], sigma_true[i.astype('bool')][0]) for i in z])
sns.distplot(data, bins=50);
np.savetxt("data/3gv2.dat", data)
with pm.Model() as mof:
    #p = pm.Dirichlet('p', a=np.array([1., 1., 1.]), shape=3)
    p = [1/3, 1/3, 1/3]
    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3)
    #sds = pm.HalfCauchy('sds', 5, shape=3)
    sds = np.array([1., 1., 1.])
    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])
    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)
with mof:
    tripletrace_full = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [01:39<00:00, 210.02draws/s]
The estimated number of effective samples is smaller than 200 for some parameters.
trace_mof=tripletrace_full[3000::7]
#pm.traceplot(trace_mof, varnames=["means", "p", "sds"]);
pm.traceplot(trace_mof, varnames=["means"], combined=True);
pm.autocorrplot(trace_mof, varnames=['means']);
Problems with clusters and sampling
Some of the traces seem ok, but the autocorrelation is quite bad. And there is label-switching. This is because there are major problems with using MCMC for clustering.
AND THIS IS WITHOUT MODELING $p$ OR $\sigma$. It gets much worse otherwise! (It would help if the Gaussians were widely separated.)
The problems are, firstly, the lack of parameter identifiability (the so-called label-switching problem) and, secondly, the multimodality of the posteriors.
We have seen non-identifiability before. Switching labels on the means and the $z$'s, for example, does not change the likelihood. The consequence is that cluster parameters cannot be compared across chains: what is one cluster's parameter in one chain may well belong to the other cluster in the second chain. Even within a single chain, the indices might swap, leading to a telltale back-and-forth in the traces, especially for long chains or data that is not cleanly separated.
Also, the (joint) posteriors can be highly multimodal. One form of multimodality is the non-identifiability, but even aside from identifiability issues the posteriors can be highly multimodal.
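A quick numerical check of this label-switching invariance, using the simulated `data` and the true parameters from above (a small illustrative snippet, not part of the models being fit):

# permuting the component labels leaves the mixture log-likelihood unchanged,
# which is exactly why the labels are not identified
from scipy.stats import norm
def mix_loglike(x, mus, weights):
    # log p(x) = sum_i log( sum_k w_k N(x_i | mu_k, 1) )
    comp = weights * norm.pdf(x[:, None], loc=mus, scale=1.0)
    return np.log(comp.sum(axis=1)).sum()
perm = np.array([2, 0, 1])
print(mix_loglike(data, mu_true, lambda_true))
print(mix_loglike(data, mu_true[perm], lambda_true[perm]))  # identical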
To quote the Stan manual:
Bayesian inference fails in cases of high multimodality because there is no way to visit all of the modes in the posterior in appropriate proportions and thus no way to evaluate integrals involved in posterior predictive inference. In light of these two problems, the advice often given in fitting clustering models is to try many different initializations and select the sample with the highest overall probability. It is also popular to use optimization-based point estimators such as expectation maximization or variational Bayes, which can be much more efficient than sampling-based approaches.
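The optimization-based route the quote mentions can be sketched with scikit-learn's EM-based GaussianMixture (assuming scikit-learn is available; it is not otherwise used in this notebook):

# a hedged sketch of the EM alternative: fit a 3-component mixture by
# maximum likelihood and read off point estimates of centers and weights
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=3, n_init=10, random_state=0)
gmm.fit(data.reshape(-1, 1))
print(gmm.means_.ravel())   # point estimates of the cluster centers
print(gmm.weights_)         # point estimates of the mixing proportions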
Some mitigation via ordering in pymc3
But this is not a panacea. Sampling is still very hard.
import theano.tensor as tt
import pymc3.distributions.transforms as tr
with pm.Model() as mof2:
    p = [1/3, 1/3, 1/3]
    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3,
                      transform=tr.ordered,
                      testval=np.array([-1, 0, 1]))
    # measurement error
    #sds = pm.Uniform('sds', lower=0, upper=20, shape=3)
    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])
    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)
with mof2:
    tripletrace_full2 = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [03:42<00:00, 94.44draws/s]
The acceptance probability does not match the target. It is 0.9255675538672136, but should be close to 0.8. Try to increase the number of tuning steps.
The number of effective samples is smaller than 10% for some parameters.
trace_mof2 = tripletrace_full2[3000::5]
pm.traceplot(trace_mof2, varnames=["means"], combined=True);
pm.autocorrplot(trace_mof2, varnames=["means"]);
Full sampling is horrible, even with potentials
Now let's put a Dirichlet prior (and a strongly centering one at that) on the mixing probabilities.
from scipy.stats import dirichlet
ds = dirichlet(alpha=[10,10,10]).rvs(1000)
"""
Visualize points on the 3-simplex (eg, the parameters of a
3-dimensional multinomial distributions) as a scatter plot
contained within a 2D triangle.
David Andrzejewski (david.andrzej@gmail.com)
"""
import numpy as NP
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA
def plotSimplex(points, fig=None,
                vertexlabels=['1','2','3'],
                **kwargs):
    """
    Plot Nx3 points array on the 3-simplex
    (with optionally labeled vertices)

    kwargs will be passed along directly to matplotlib.pyplot.scatter

    Returns Figure, caller must .show()
    """
    if(fig == None):
        fig = P.figure()
    # Draw the triangle
    l1 = L.Line2D([0, 0.5, 1.0, 0], # xcoords
                  [0, NP.sqrt(3) / 2, 0, 0], # ycoords
                  color='k')
    fig.gca().add_line(l1)
    fig.gca().xaxis.set_major_locator(MT.NullLocator())
    fig.gca().yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels
    fig.gca().text(-0.05, -0.05, vertexlabels[0])
    fig.gca().text(1.05, -0.05, vertexlabels[1])
    fig.gca().text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2])
    # Project and draw the actual points
    projected = projectSimplex(points)
    P.scatter(projected[:,0], projected[:,1], **kwargs)
    # Leave some buffer around the triangle for vertex labels
    fig.gca().set_xlim(-0.2, 1.2)
    fig.gca().set_ylim(-0.2, 1.2)
    return fig
def projectSimplex(points):
    """
    Project probabilities on the 3-simplex to a 2D triangle

    N points are given as N x 3 array
    """
    # Convert points one at a time
    tripts = NP.zeros((points.shape[0], 2))
    for idx in range(points.shape[0]):
        # Init to triangle centroid
        x = 1.0 / 2
        y = 1.0 / (2 * NP.sqrt(3))
        # Vector 1 - bisect out of lower left vertex
        p1 = points[idx, 0]
        x = x - (1.0 / NP.sqrt(3)) * p1 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p1 * NP.sin(NP.pi / 6)
        # Vector 2 - bisect out of lower right vertex
        p2 = points[idx, 1]
        x = x + (1.0 / NP.sqrt(3)) * p2 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p2 * NP.sin(NP.pi / 6)
        # Vector 3 - bisect out of top vertex
        p3 = points[idx, 2]
        y = y + (1.0 / NP.sqrt(3) * p3)
        tripts[idx, :] = (x, y)
    return tripts
plotSimplex(ds, s=20);
The idea behind a Potential
A Potential is a term that is not part of the likelihood, but enforces a constraint by adding $-\infty$ to the log-probability (i.e. setting the probability to 0) when the constraint is violated. We use it here to ensure each cluster gets some membership, and we order the means to remove the non-identifiability problem. See below for how it is used.
The sampler below has a lot of problems.
with pm.Model() as mofb:
    p = pm.Dirichlet('p', a=np.array([10., 10., 10.]), shape=3)
    # ensure all clusters have some points
    p_min_potential = pm.Potential('p_min_potential', tt.switch(tt.min(p) < .1, -np.inf, 0))
    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3, transform=tr.ordered,
                      testval=np.array([-1, 0, 1]))
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])
    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)
with mofb:
    tripletrace_fullb = pm.sample(10000, nuts_kwargs=dict(target_accept=0.95))
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means, p]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [06:13<00:00, 56.23draws/s]
There were 10 divergences after tuning. Increase `target_accept` or reparameterize.
There were 7 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.
trace_mofb = tripletrace_fullb[3000::5]
pm.traceplot(trace_mofb, varnames=["means", "p"], combined=True);
pm.summary(trace_mofb)
 | mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat |
---|---|---|---|---|---|---|---|
category__0 | 1.406071 | 0.575480 | 0.014954 | 1.000000 | 2.000000 | 1362.583316 | 1.000115 |
category__1 | 0.228571 | 0.421610 | 0.010938 | 0.000000 | 1.000000 | 1306.422985 | 1.000472 |
category__2 | 0.395357 | 0.496897 | 0.011826 | 0.000000 | 1.000000 | 1674.615993 | 1.000276 |
category__3 | 0.108929 | 0.311550 | 0.008695 | 0.000000 | 1.000000 | 1307.190266 | 1.001074 |
category__4 | 0.824643 | 0.552559 | 0.011552 | 0.000000 | 2.000000 | 2328.351050 | 0.999864 |
category__5 | 0.280714 | 0.452516 | 0.011192 | 0.000000 | 1.000000 | 1662.033357 | 1.000363 |
category__6 | 1.948929 | 0.221760 | 0.004981 | 1.000000 | 2.000000 | 1702.198664 | 1.002138 |
category__7 | 1.712143 | 0.473583 | 0.012564 | 1.000000 | 2.000000 | 1256.089935 | 0.999754 |
category__8 | 0.900357 | 0.569837 | 0.011499 | 0.000000 | 2.000000 | 2245.021084 | 0.999675 |
category__9 | 1.739286 | 0.454218 | 0.012043 | 1.000000 | 2.000000 | 1421.562668 | 1.000357 |
category__10 | 0.432500 | 0.506121 | 0.013448 | 0.000000 | 1.000000 | 1749.523067 | 0.999755 |
category__11 | 1.758214 | 0.445337 | 0.011746 | 1.000000 | 2.000000 | 1291.113963 | 1.000946 |
category__12 | 1.916071 | 0.282386 | 0.006534 | 1.000000 | 2.000000 | 1765.304573 | 1.000003 |
category__13 | 1.866429 | 0.342284 | 0.009551 | 1.000000 | 2.000000 | 1477.039629 | 0.999995 |
category__14 | 0.225000 | 0.419289 | 0.010990 | 0.000000 | 1.000000 | 1472.910821 | 1.001048 |
category__15 | 0.459286 | 0.513864 | 0.011694 | 0.000000 | 1.000000 | 1619.209679 | 1.000851 |
category__16 | 0.649643 | 0.528238 | 0.011004 | 0.000000 | 1.000000 | 2333.805817 | 0.999647 |
category__17 | 0.440000 | 0.511273 | 0.012340 | 0.000000 | 1.000000 | 1803.682268 | 1.000142 |
category__18 | 1.989286 | 0.102954 | 0.002396 | 2.000000 | 2.000000 | 2055.892030 | 0.999691 |
category__19 | 1.957143 | 0.202535 | 0.004574 | 2.000000 | 2.000000 | 1996.770513 | 1.001435 |
category__20 | 0.445714 | 0.510514 | 0.013029 | 0.000000 | 1.000000 | 1827.687813 | 0.999651 |
category__21 | 1.989643 | 0.101242 | 0.002039 | 2.000000 | 2.000000 | 2505.025728 | 0.999954 |
category__22 | 1.930000 | 0.255147 | 0.006387 | 1.000000 | 2.000000 | 1414.514214 | 1.000427 |
category__23 | 1.332143 | 0.587704 | 0.014147 | 0.000000 | 2.000000 | 1797.023673 | 1.001351 |
category__24 | 1.733571 | 0.457963 | 0.012479 | 1.000000 | 2.000000 | 1184.444494 | 1.000431 |
category__25 | 1.515714 | 0.556555 | 0.015364 | 1.000000 | 2.000000 | 1331.826974 | 0.999807 |
category__26 | 0.425000 | 0.500089 | 0.012622 | 0.000000 | 1.000000 | 1455.116136 | 0.999694 |
category__27 | 1.979643 | 0.141219 | 0.002913 | 2.000000 | 2.000000 | 2378.339776 | 1.003031 |
category__28 | 1.722143 | 0.468214 | 0.012763 | 1.000000 | 2.000000 | 1469.880006 | 1.000669 |
category__29 | 0.873214 | 0.554202 | 0.011719 | 0.000000 | 2.000000 | 2572.986079 | 1.000211 |
... | ... | ... | ... | ... | ... | ... | ... |
category__76 | 0.473929 | 0.516892 | 0.012308 | 0.000000 | 1.000000 | 1568.551705 | 1.000044 |
category__77 | 1.940714 | 0.237666 | 0.005468 | 1.000000 | 2.000000 | 2057.523309 | 1.000546 |
category__78 | 1.552857 | 0.545821 | 0.014423 | 1.000000 | 2.000000 | 1278.184527 | 1.001185 |
category__79 | 1.197143 | 0.603317 | 0.012732 | 0.000000 | 2.000000 | 2419.356924 | 0.999812 |
category__80 | 0.119643 | 0.325642 | 0.008502 | 0.000000 | 1.000000 | 1457.209338 | 0.999644 |
category__81 | 0.671429 | 0.535762 | 0.011043 | 0.000000 | 1.000000 | 2416.678137 | 0.999991 |
category__82 | 1.471786 | 0.560539 | 0.014711 | 1.000000 | 2.000000 | 1408.654254 | 0.999643 |
category__83 | 0.687857 | 0.540894 | 0.011061 | 0.000000 | 1.000000 | 2184.065312 | 1.000089 |
category__84 | 1.832500 | 0.379117 | 0.009391 | 1.000000 | 2.000000 | 1442.318862 | 0.999651 |
category__85 | 0.238929 | 0.428101 | 0.010072 | 0.000000 | 1.000000 | 1580.222999 | 0.999649 |
category__86 | 1.001786 | 0.574764 | 0.011643 | 0.000000 | 2.000000 | 2431.781980 | 0.999884 |
category__87 | 0.714286 | 0.546884 | 0.011661 | 0.000000 | 1.000000 | 2298.065679 | 0.999726 |
category__88 | 1.627143 | 0.519869 | 0.014501 | 1.000000 | 2.000000 | 1153.871436 | 0.999673 |
category__89 | 1.979286 | 0.142426 | 0.002728 | 2.000000 | 2.000000 | 2325.559161 | 0.999668 |
category__90 | 1.987143 | 0.112658 | 0.002501 | 2.000000 | 2.000000 | 1955.167689 | 1.000005 |
category__91 | 1.411429 | 0.573845 | 0.012387 | 1.000000 | 2.000000 | 2121.534066 | 0.999742 |
category__92 | 0.287857 | 0.454339 | 0.011087 | 0.000000 | 1.000000 | 1656.472820 | 1.000951 |
category__93 | 0.120714 | 0.325795 | 0.009032 | 0.000000 | 1.000000 | 1414.023539 | 0.999648 |
category__94 | 0.817857 | 0.560327 | 0.011872 | 0.000000 | 2.000000 | 1951.110342 | 0.999649 |
category__95 | 1.003571 | 0.596407 | 0.012872 | 0.000000 | 2.000000 | 2002.102565 | 0.999679 |
category__96 | 0.076786 | 0.266251 | 0.007683 | 0.000000 | 1.000000 | 1365.135665 | 0.999645 |
category__97 | 1.140000 | 0.584661 | 0.010782 | 0.000000 | 2.000000 | 2331.457873 | 0.999697 |
category__98 | 1.993571 | 0.079920 | 0.001462 | 2.000000 | 2.000000 | 2601.475216 | 1.000921 |
category__99 | 0.199643 | 0.400624 | 0.010216 | 0.000000 | 1.000000 | 1420.458431 | 1.000063 |
p__0 | 0.360226 | 0.081588 | 0.003317 | 0.204306 | 0.519089 | 519.487415 | 1.002060 |
p__1 | 0.325225 | 0.071761 | 0.001919 | 0.172866 | 0.459570 | 1294.128430 | 0.999978 |
p__2 | 0.314550 | 0.065550 | 0.002548 | 0.189302 | 0.444908 | 587.062612 | 1.001335 |
means__0 | -1.886959 | 0.263363 | 0.007660 | -2.388613 | -1.365366 | 1295.486095 | 1.002200 |
means__1 | -0.407361 | 0.565056 | 0.029250 | -1.451171 | 0.683029 | 286.606094 | 1.002899 |
means__2 | 1.948775 | 0.303411 | 0.009837 | 1.418676 | 2.599857 | 791.034984 | 1.001906 |
106 rows × 7 columns
Making Problems go away
A lot of these problems will go away if identifiability improves via well-separated Gaussians. But that changes the data. If we want any further improvement on this data, we will have to stop sampling so many discrete categoricals, and for that we will need a marginalization trick.
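As a preview of that trick, here is a hedged sketch (not run above): `pm.NormalMixture` sums the categorical assignments out of the likelihood, leaving only continuous parameters for NUTS. The unit standard deviations and ordered means mirror the models above.

with pm.Model() as mof_marginal:
    p = pm.Dirichlet('p', a=np.array([10., 10., 10.]), shape=3)
    means = pm.Normal('means', mu=0, sd=10, shape=3,
                      transform=tr.ordered,
                      testval=np.array([-1, 0, 1]))
    # the discrete assignments are marginalized out inside the mixture likelihood,
    # so no Categorical variable (and no Gibbs step) is needed
    obs = pm.NormalMixture('obs', w=p, mu=means, sd=1., observed=data)
    # trace_marginal = pm.sample(10000)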