Keywords: supervised learning |  semi-supervised learning |  unsupervised learning |  mixture model |  gaussian mixture model |  pymc3 |  label-switching |  identifiability |  normal distribution |  pymc3 potentials

We now turn to learning mixture models with MCMC. We have already done this in the case of the Zero-Inflated Poisson model; here we stick to Gaussian mixture models.

%matplotlib inline
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
pd.set_option('display.width', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.notebook_repr_html', True)
import seaborn as sns
sns.set_style("whitegrid")
sns.set_context("poster")
import pymc3 as pm
import theano.tensor as tt

Mixture of 2 Gaussians: the Old Faithful data

We start by considering waiting times from the Old Faithful geyser at Yellowstone National Park.

ofdata=pd.read_csv("data/oldfaithful.csv")
ofdata.head()
eruptions waiting
0 3.600 79
1 1.800 54
2 3.333 74
3 2.283 62
4 4.533 85
sns.distplot(ofdata.waiting);

png

Visually, there seem to be two components to the waiting time, so let us model this using a mixture of two Gaussians. Remember that this is an unsupervised model: all we are doing is modelling $p(x)$, with the assumption that there are two clusters and a hidden variable $z$ that indexes them.

Notice that these Gaussians seem well separated. How well separated the Gaussians are strongly affects how your sampler will perform.
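
Concretely, the generative model we are assuming is

$$
z_i \sim \text{Categorical}(p,\, 1-p), \qquad x_i \mid z_i = k \;\sim\; \mathcal{N}(\mu_k, \sigma_k^2),
$$

so that, marginalizing over the hidden assignment $z_i$,

$$
p(x_i) = p\,\mathcal{N}(x_i \mid \mu_0, \sigma_0^2) \;+\; (1-p)\,\mathcal{N}(x_i \mid \mu_1, \sigma_1^2),
$$

where $p$ is the mixing weight (the variable `p` in the model below). This is exactly the structure we encode in pymc3: a Categorical `assignment` per observation, and Normal observations whose mean and standard deviation are indexed by that assignment.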

with pm.Model() as ofmodel:
    p1 = pm.Uniform('p', 0, 1)
    p2 = 1 - p1
    p = tt.stack([p1, p2])
    assignment = pm.Categorical("assignment", p, 
                                shape=ofdata.shape[0])
    sds = pm.Uniform("sds", 0, 40, shape=2)
    centers = pm.Normal("centers", 
                        mu=np.array([50, 80]), 
                        sd=np.array([20, 20]), 
                        shape=2)
    
    # and to combine it with the observations:
    observations = pm.Normal("obs", mu=centers[assignment], sd=sds[assignment], observed=ofdata.waiting)
with ofmodel:
    #step1 = pm.Metropolis(vars=[p, sds, centers])
    #step2 = pm.CategoricalGibbsMetropolis(vars=[assignment])
    #oftrace_full = pm.sample(10000, step=[step1, step2])
    oftrace_full = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [centers, sds, p]
>BinaryGibbsMetropolis: [assignment]
Sampling 2 chains: 100%|██████████| 21000/21000 [04:43<00:00, 73.96draws/s]
pm.model_to_graphviz(ofmodel)

svg

oftrace = oftrace_full[2000::5]
pm.traceplot(oftrace);

png

pm.summary(oftrace)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff Rhat
assignment__0 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__1 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__2 0.993125 0.082630 0.001504 1.000000 1.000000 3248.753445 0.999745
assignment__3 0.050937 0.219870 0.004231 0.000000 1.000000 3173.388070 1.000029
assignment__4 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__5 0.000313 0.017675 0.000311 0.000000 0.000000 NaN 1.000000
assignment__6 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__7 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__8 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__9 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__10 0.000625 0.024992 0.000437 0.000000 0.000000 3208.032629 0.999687
assignment__11 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__12 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__13 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__14 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__15 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__16 0.046875 0.211371 0.003711 0.000000 0.000000 3205.553814 0.999696
assignment__17 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__18 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__19 0.999375 0.024992 0.000437 1.000000 1.000000 3208.032629 0.999687
assignment__20 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__21 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__22 0.999062 0.030604 0.000533 1.000000 1.000000 3209.379383 0.999792
assignment__23 0.817500 0.386256 0.006151 0.000000 1.000000 3120.023664 0.999729
assignment__24 0.993125 0.082630 0.001629 1.000000 1.000000 3235.543269 1.001749
assignment__25 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__26 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__27 0.997188 0.052958 0.001177 1.000000 1.000000 3222.014229 0.999722
assignment__28 0.998437 0.039498 0.000681 1.000000 1.000000 3213.694327 0.999750
assignment__29 0.998750 0.035333 0.000612 1.000000 1.000000 3212.072926 0.999687
... ... ... ... ... ... ... ...
assignment__247 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__248 0.576562 0.494103 0.008319 0.000000 1.000000 2947.001300 1.000571
assignment__249 0.995938 0.063608 0.001140 1.000000 1.000000 2794.058393 0.999712
assignment__250 0.000625 0.024992 0.000437 0.000000 0.000000 NaN 1.000313
assignment__251 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__252 0.983750 0.126436 0.002048 1.000000 1.000000 3312.296435 0.999907
assignment__253 0.983125 0.128803 0.001897 1.000000 1.000000 3192.025887 0.999899
assignment__254 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__255 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__256 0.949375 0.219231 0.004141 0.000000 1.000000 2517.301618 1.000086
assignment__257 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__258 0.000625 0.024992 0.000437 0.000000 0.000000 NaN 1.000313
assignment__259 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__260 0.999687 0.017675 0.000311 1.000000 1.000000 NaN 1.000000
assignment__261 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__262 0.003438 0.058529 0.000978 0.000000 0.000000 3224.662207 0.999944
assignment__263 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__264 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__265 0.013750 0.116451 0.002088 0.000000 0.000000 3296.006530 0.999716
assignment__266 0.995938 0.063608 0.001140 1.000000 1.000000 2823.618545 1.000870
assignment__267 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__268 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__269 1.000000 0.000000 0.000000 1.000000 1.000000 NaN NaN
assignment__270 0.000000 0.000000 0.000000 0.000000 0.000000 NaN NaN
assignment__271 0.993125 0.082630 0.001295 1.000000 1.000000 3247.614839 0.999916
centers__0 54.627527 0.733327 0.011941 53.263593 56.069912 2857.383302 1.000063
centers__1 80.087834 0.523151 0.008877 79.104464 81.134251 3013.404106 0.999852
p 0.361858 0.030777 0.000567 0.302382 0.419931 2753.669629 1.000464
sds__0 6.018165 0.585472 0.010211 4.961674 7.231991 2942.626870 0.999759
sds__1 5.947338 0.416616 0.008702 5.151874 6.773165 2566.297885 0.999898

277 rows × 7 columns

pm.autocorrplot(oftrace, varnames=['p', 'centers', 'sds']);

png

oftrace['centers'].mean(axis=0)
array([54.62752749, 80.0878344 ])

We can visualize the two clusters, suitably scaled by the category-membership probability, by plugging in the posterior means. Note that this misses any smearing that would go into making the posterior predictive; we sketch a posterior-averaged alternative after the plot below.

from scipy.stats import norm
x = np.linspace(20, 120, 500)
# for pretty colors later in the book.
colors = ["#348ABD", "#A60628"] if oftrace['centers'][-1, 0] > oftrace['centers'][-1, 1] \
    else ["#A60628", "#348ABD"]

posterior_center_means = oftrace['centers'].mean(axis=0)
posterior_std_means = oftrace['sds'].mean(axis=0)
posterior_p_mean = oftrace["p"].mean()

plt.hist(ofdata.waiting, bins=20, histtype="step", density=True, color="k",
     lw=2, label="histogram of data")
y = posterior_p_mean * norm.pdf(x, loc=posterior_center_means[0],
                                scale=posterior_std_means[0])
plt.plot(x, y, label="Cluster 0 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[1], alpha=0.3)

y = (1 - posterior_p_mean) * norm.pdf(x, loc=posterior_center_means[1],
                                      scale=posterior_std_means[1])
plt.plot(x, y, label="Cluster 1 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[0], alpha=0.3)

plt.legend(loc="upper left")
plt.title("Visualizing Clusters using posterior-mean parameters");

png
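
To capture that smearing, we could instead average the mixture density over the posterior draws rather than plugging in the posterior-mean parameters. A minimal sketch (not part of the original analysis), reusing `x`, `norm`, and the thinned `oftrace` from above:

```python
# Posterior-averaged density: average the two-component mixture density over
# posterior draws instead of using a single set of posterior-mean parameters.
pp_density = np.zeros_like(x)
for centers_i, sds_i, p_i in zip(oftrace['centers'], oftrace['sds'], oftrace['p']):
    pp_density += p_i * norm.pdf(x, loc=centers_i[0], scale=sds_i[0])
    pp_density += (1 - p_i) * norm.pdf(x, loc=centers_i[1], scale=sds_i[1])
pp_density /= len(oftrace['p'])

plt.hist(ofdata.waiting, bins=20, histtype="step", density=True, color="k", lw=2)
plt.plot(x, pp_density, lw=3, label="posterior-averaged mixture density")
plt.legend();
```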

A tetchy 3-Gaussian Model

Let us set up our data. Our analysis here follows https://colindcarroll.com/2018/07/20/why-im-excited-about-pymc3-v3.5.0/, and we have chosen 3 Gaussians reasonably close to each other to show the problems that arise!

mu_true = np.array([-2, 0, 2])
sigma_true = np.array([1, 1, 1])
lambda_true = np.array([1/3, 1/3, 1/3])
n = 100
from scipy.stats import multinomial
# Simulate one-hot cluster assignments according to the mixing proportions lambda_true
z = multinomial.rvs(1, lambda_true, size=n)
# Draw each observation from the normal distribution of its assigned cluster
data = np.array([np.random.normal(mu_true[i.astype('bool')][0], sigma_true[i.astype('bool')][0]) for i in z])
sns.distplot(data, bins=50);

png

np.savetxt("data/3gv2.dat", data)
with pm.Model() as mof:
    #p = pm.Dirichlet('p', a=np.array([1., 1., 1.]), shape=3)
    p=[1/3, 1/3, 1/3]

    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3)


    #sds = pm.HalfCauchy('sds', 5, shape=3)
    sds = np.array([1., 1., 1.])
    
    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)

with mof:
    tripletrace_full = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [01:39<00:00, 210.02draws/s]
The estimated number of effective samples is smaller than 200 for some parameters.
trace_mof=tripletrace_full[3000::7]
#pm.traceplot(trace_mof, varnames=["means", "p", "sds"]);
pm.traceplot(trace_mof, varnames=["means"], combined=True);

png

pm.autocorrplot(trace_mof, varnames=['means']);

png

Problems with clusters and sampling

Some of the traces seem OK, but the autocorrelation is quite bad, and there is label switching. This is because there are major problems with using MCMC for clustering.

AND THIS IS WITHOUT MODELING $p$ OR $\sigma$: it gets much worse otherwise! (It would help if the Gaussians were widely separated.)

These are, firstly, the lack of parameter identifiability (the so-called label-switching problem) and, secondly, the multimodality of the posteriors.

We have seen non-identifiability before. Switching the labels on the means and the $z$'s, for example, does not change the likelihood. The problem is that cluster parameters then cannot be compared across chains: what is one cluster's parameter in one chain may well belong to the other cluster in the second chain. Even within a single chain, indices may swap, leading to a telltale back-and-forth in the traces for long chains or for data that is not cleanly separated.
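
One common post-hoc mitigation, shown as a minimal sketch below (not something we do in this notebook), is to relabel each posterior draw by sorting on the component means, so that "component 0" refers to the same cluster in every draw and every chain:

```python
import numpy as np

# Hypothetical post-hoc relabeling: within each posterior draw, reorder the
# components by their means and keep the permutation around.
means_samples = trace_mof['means']                 # shape (n_draws, 3)
order = np.argsort(means_samples, axis=1)          # per-draw permutation
relabelled_means = np.take_along_axis(means_samples, order, axis=1)  # numpy >= 1.15
# Any other per-component parameters (sds, p) would need the same permutation.
```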

Also, the (joint) posteriors can be highly multimodal. One source of multimodality is the non-identifiability itself, but even when identifiability is imposed the posteriors remain highly multimodal.

To quote the Stan manual:

Bayesian inference fails in cases of high multimodality because there is no way to visit all of the modes in the posterior in appropriate proportions and thus no way to evaluate integrals involved in posterior predictive inference. In light of these two problems, the advice often given in fitting clustering models is to try many different initializations and select the sample with the highest overall probability. It is also popular to use optimization-based point estimators such as expectation maximization or variational Bayes, which can be much more efficient than sampling-based approaches.

Some mitigation via ordering in pymc3

But this is not a panacea. Sampling is still very hard.

import theano.tensor as tt
import pymc3.distributions.transforms as tr


with pm.Model() as mof2:
    
    p = [1/3, 1/3, 1/3]

    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3,
                  transform=tr.ordered,
                  testval=np.array([-1, 0, 1]))


                                         
    # measurement error
    #sds = pm.Uniform('sds', lower=0, upper=20, shape=3)

    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)

with mof2:
    tripletrace_full2 = pm.sample(10000)
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [03:42<00:00, 94.44draws/s] 
The acceptance probability does not match the target. It is 0.9255675538672136, but should be close to 0.8. Try to increase the number of tuning steps.
The number of effective samples is smaller than 10% for some parameters.
trace_mof2 = tripletrace_full2[3000::5]
pm.traceplot(trace_mof2, varnames=["means"], combined=True);

png

pm.autocorrplot(trace_mof2, varnames=["means"]);

png

Full sampling is horrible, even with potentials

Now let's put a Dirichlet-based prior (and this is a strongly centering Dirichlet) on the probabilities.
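
To see what "strongly centering" means here, compare draws from a flat Dirichlet(1, 1, 1) with draws from Dirichlet(10, 10, 10): both have mean (1/3, 1/3, 1/3), but the larger total concentration pulls draws much more tightly toward the centre of the simplex. A quick check (not in the original notebook):

```python
from scipy.stats import dirichlet

# Standard deviation per component: roughly 0.24 for alpha = (1, 1, 1)
# versus roughly 0.08 for alpha = (10, 10, 10).
for alpha in ([1., 1., 1.], [10., 10., 10.]):
    draws = dirichlet(alpha=alpha).rvs(10000)
    print(alpha, "sd per component:", draws.std(axis=0).round(3))
```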

from scipy.stats import dirichlet
ds = dirichlet(alpha=[10,10,10]).rvs(1000)
"""
Visualize points on the 3-simplex (eg, the parameters of a
3-dimensional multinomial distributions) as a scatter plot 
contained within a 2D triangle.
David Andrzejewski (david.andrzej@gmail.com)
"""
import numpy as NP
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA

def plotSimplex(points, fig=None, 
                vertexlabels=['1','2','3'],
                **kwargs):
    """
    Plot Nx3 points array on the 3-simplex 
    (with optionally labeled vertices) 
    
    kwargs will be passed along directly to matplotlib.pyplot.scatter    
    Returns Figure, caller must .show()
    """
    if fig is None:
        fig = P.figure()
    # Draw the triangle
    l1 = L.Line2D([0, 0.5, 1.0, 0], # xcoords
                  [0, NP.sqrt(3) / 2, 0, 0], # ycoords
                  color='k')
    fig.gca().add_line(l1)
    fig.gca().xaxis.set_major_locator(MT.NullLocator())
    fig.gca().yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels
    fig.gca().text(-0.05, -0.05, vertexlabels[0])
    fig.gca().text(1.05, -0.05, vertexlabels[1])
    fig.gca().text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2])
    # Project and draw the actual points
    projected = projectSimplex(points)
    P.scatter(projected[:,0], projected[:,1], **kwargs)              
    # Leave some buffer around the triangle for vertex labels
    fig.gca().set_xlim(-0.2, 1.2)
    fig.gca().set_ylim(-0.2, 1.2)

    return fig    

def projectSimplex(points):
    """ 
    Project probabilities on the 3-simplex to a 2D triangle
    
    N points are given as N x 3 array
    """
    # Convert points one at a time
    tripts = NP.zeros((points.shape[0],2))
    for idx in range(points.shape[0]):
        # Init to triangle centroid
        x = 1.0 / 2
        y = 1.0 / (2 * NP.sqrt(3))
        # Vector 1 - bisect out of lower left vertex 
        p1 = points[idx, 0]
        x = x - (1.0 / NP.sqrt(3)) * p1 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p1 * NP.sin(NP.pi / 6)
        # Vector 2 - bisect out of lower right vertex  
        p2 = points[idx, 1]  
        x = x + (1.0 / NP.sqrt(3)) * p2 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p2 * NP.sin(NP.pi / 6)        
        # Vector 3 - bisect out of top vertex
        p3 = points[idx, 2]
        y = y + (1.0 / NP.sqrt(3) * p3)
      
        tripts[idx,:] = (x,y)

    return tripts


plotSimplex(ds, s=20);

png

The idea behind a Potential is that it is not part of the likelihood, but enforces a constraint by setting the probability to 0 (log-probability to $-\infty$) whenever the constraint is violated. We use it here to force each cluster to keep some membership, while the ordered transform on the means again addresses the non-identifiability. See below for how it's used.
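
To make the mechanics concrete, here is the same constraint written as a plain-Python function (a hypothetical helper, purely to illustrate what the `tt.switch` expression below evaluates to):

```python
import numpy as np

def p_min_log_penalty(p, floor=0.1):
    # Mirrors tt.switch(tt.min(p) < floor, -np.inf, 0): contributes -inf to the
    # model's log-probability (probability zero) whenever any mixture weight
    # falls below the floor, and contributes nothing otherwise.
    return -np.inf if np.min(p) < floor else 0.0

print(p_min_log_penalty(np.array([0.05, 0.55, 0.40])))   # -inf: such a p is rejected
print(p_min_log_penalty(np.array([0.30, 0.35, 0.35])))   # 0.0: no penalty
```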

The sampler below has a lot of problems.

with pm.Model() as mofb:
    p = pm.Dirichlet('p', a=np.array([10., 10., 10.]), shape=3)
    # ensure all clusters have some points
    p_min_potential = pm.Potential('p_min_potential', tt.switch(tt.min(p) < .1, -np.inf, 0))
    # cluster centers
    means = pm.Normal('means', mu=0, sd=10, shape=3, transform=tr.ordered,
                  testval=np.array([-1, 0, 1]))

    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sd=1., #sds[category],
                       observed=data)


with mofb:
    tripletrace_fullb = pm.sample(10000, nuts_kwargs=dict(target_accept=0.95))
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
>NUTS: [means, p]
>CategoricalGibbsMetropolis: [category]
Sampling 2 chains: 100%|██████████| 21000/21000 [06:13<00:00, 56.23draws/s]
There were 10 divergences after tuning. Increase `target_accept` or reparameterize.
There were 7 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.
trace_mofb = tripletrace_fullb[3000::5]
pm.traceplot(trace_mofb, varnames=["means", "p"], combined=True);

png

pm.summary(trace_mofb)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff Rhat
category__0 1.406071 0.575480 0.014954 1.000000 2.000000 1362.583316 1.000115
category__1 0.228571 0.421610 0.010938 0.000000 1.000000 1306.422985 1.000472
category__2 0.395357 0.496897 0.011826 0.000000 1.000000 1674.615993 1.000276
category__3 0.108929 0.311550 0.008695 0.000000 1.000000 1307.190266 1.001074
category__4 0.824643 0.552559 0.011552 0.000000 2.000000 2328.351050 0.999864
category__5 0.280714 0.452516 0.011192 0.000000 1.000000 1662.033357 1.000363
category__6 1.948929 0.221760 0.004981 1.000000 2.000000 1702.198664 1.002138
category__7 1.712143 0.473583 0.012564 1.000000 2.000000 1256.089935 0.999754
category__8 0.900357 0.569837 0.011499 0.000000 2.000000 2245.021084 0.999675
category__9 1.739286 0.454218 0.012043 1.000000 2.000000 1421.562668 1.000357
category__10 0.432500 0.506121 0.013448 0.000000 1.000000 1749.523067 0.999755
category__11 1.758214 0.445337 0.011746 1.000000 2.000000 1291.113963 1.000946
category__12 1.916071 0.282386 0.006534 1.000000 2.000000 1765.304573 1.000003
category__13 1.866429 0.342284 0.009551 1.000000 2.000000 1477.039629 0.999995
category__14 0.225000 0.419289 0.010990 0.000000 1.000000 1472.910821 1.001048
category__15 0.459286 0.513864 0.011694 0.000000 1.000000 1619.209679 1.000851
category__16 0.649643 0.528238 0.011004 0.000000 1.000000 2333.805817 0.999647
category__17 0.440000 0.511273 0.012340 0.000000 1.000000 1803.682268 1.000142
category__18 1.989286 0.102954 0.002396 2.000000 2.000000 2055.892030 0.999691
category__19 1.957143 0.202535 0.004574 2.000000 2.000000 1996.770513 1.001435
category__20 0.445714 0.510514 0.013029 0.000000 1.000000 1827.687813 0.999651
category__21 1.989643 0.101242 0.002039 2.000000 2.000000 2505.025728 0.999954
category__22 1.930000 0.255147 0.006387 1.000000 2.000000 1414.514214 1.000427
category__23 1.332143 0.587704 0.014147 0.000000 2.000000 1797.023673 1.001351
category__24 1.733571 0.457963 0.012479 1.000000 2.000000 1184.444494 1.000431
category__25 1.515714 0.556555 0.015364 1.000000 2.000000 1331.826974 0.999807
category__26 0.425000 0.500089 0.012622 0.000000 1.000000 1455.116136 0.999694
category__27 1.979643 0.141219 0.002913 2.000000 2.000000 2378.339776 1.003031
category__28 1.722143 0.468214 0.012763 1.000000 2.000000 1469.880006 1.000669
category__29 0.873214 0.554202 0.011719 0.000000 2.000000 2572.986079 1.000211
... ... ... ... ... ... ... ...
category__76 0.473929 0.516892 0.012308 0.000000 1.000000 1568.551705 1.000044
category__77 1.940714 0.237666 0.005468 1.000000 2.000000 2057.523309 1.000546
category__78 1.552857 0.545821 0.014423 1.000000 2.000000 1278.184527 1.001185
category__79 1.197143 0.603317 0.012732 0.000000 2.000000 2419.356924 0.999812
category__80 0.119643 0.325642 0.008502 0.000000 1.000000 1457.209338 0.999644
category__81 0.671429 0.535762 0.011043 0.000000 1.000000 2416.678137 0.999991
category__82 1.471786 0.560539 0.014711 1.000000 2.000000 1408.654254 0.999643
category__83 0.687857 0.540894 0.011061 0.000000 1.000000 2184.065312 1.000089
category__84 1.832500 0.379117 0.009391 1.000000 2.000000 1442.318862 0.999651
category__85 0.238929 0.428101 0.010072 0.000000 1.000000 1580.222999 0.999649
category__86 1.001786 0.574764 0.011643 0.000000 2.000000 2431.781980 0.999884
category__87 0.714286 0.546884 0.011661 0.000000 1.000000 2298.065679 0.999726
category__88 1.627143 0.519869 0.014501 1.000000 2.000000 1153.871436 0.999673
category__89 1.979286 0.142426 0.002728 2.000000 2.000000 2325.559161 0.999668
category__90 1.987143 0.112658 0.002501 2.000000 2.000000 1955.167689 1.000005
category__91 1.411429 0.573845 0.012387 1.000000 2.000000 2121.534066 0.999742
category__92 0.287857 0.454339 0.011087 0.000000 1.000000 1656.472820 1.000951
category__93 0.120714 0.325795 0.009032 0.000000 1.000000 1414.023539 0.999648
category__94 0.817857 0.560327 0.011872 0.000000 2.000000 1951.110342 0.999649
category__95 1.003571 0.596407 0.012872 0.000000 2.000000 2002.102565 0.999679
category__96 0.076786 0.266251 0.007683 0.000000 1.000000 1365.135665 0.999645
category__97 1.140000 0.584661 0.010782 0.000000 2.000000 2331.457873 0.999697
category__98 1.993571 0.079920 0.001462 2.000000 2.000000 2601.475216 1.000921
category__99 0.199643 0.400624 0.010216 0.000000 1.000000 1420.458431 1.000063
p__0 0.360226 0.081588 0.003317 0.204306 0.519089 519.487415 1.002060
p__1 0.325225 0.071761 0.001919 0.172866 0.459570 1294.128430 0.999978
p__2 0.314550 0.065550 0.002548 0.189302 0.444908 587.062612 1.001335
means__0 -1.886959 0.263363 0.007660 -2.388613 -1.365366 1295.486095 1.002200
means__1 -0.407361 0.565056 0.029250 -1.451171 0.683029 286.606094 1.002899
means__2 1.948775 0.303411 0.009837 1.418676 2.599857 791.034984 1.001906

106 rows × 7 columns

Making Problems go away

A lot of these problems go away when identifiability improves through well-separated Gaussians, but that changes the data. If we want any further improvement on this data, we will have to stop sampling so many discrete categorical variables, and for that we will need a marginalization trick.
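
As a preview, here is a minimal sketch of what such a marginalized model could look like in pymc3, using `pm.NormalMixture`, which sums the weighted component densities instead of sampling a per-observation categorical; treat it as a sketch rather than the final model:

```python
with pm.Model() as mof_marginal:
    p = pm.Dirichlet('p', a=np.array([10., 10., 10.]), shape=3)
    means = pm.Normal('means', mu=0, sd=10, shape=3,
                      transform=tr.ordered, testval=np.array([-1, 0, 1]))
    # No per-observation 'category' variable: the discrete assignments are
    # marginalized out inside the mixture likelihood, so NUTS can sample
    # all of the remaining (continuous) parameters.
    obs = pm.NormalMixture('obs', w=p, mu=means, sd=1., observed=data)
```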