Optimizing GMM parameters using EM
A Gaussian Mixture Model (GMM) models data as a finite mixture of Gaussians. It is often used for clustering: the number of Gaussians $K$ is then the number of clusters and is chosen beforehand. Compared to K-means, GMM allows clusters with uneven variance and density.
Parameters are usually estimated with the Expectation-Maximization (EM) algorithm, whose aim is to iteratively increase the likelihood of the dataset. GMM is also a perfect model to understand how EM works.
Fig. 1. Clustering a 2D dataset into 3 clusters using GMM. Left: Dataset before clustering. Right: Clustered data with position of fitted mean for each Gaussian.
We begin by describing GMM and listing the parameters of the model. We continue by applying EM to GMM and deriving the update formulas. Then, the algorithm to cluster data is given. We finally illustrate the clustering process on a simple example. This post does not detail how EM works; if you are interested in the theoretical considerations, please read this post.
What is GMM?
Let $x = (x_i)_{i \in \{1, \dots, n\}}$ be a dataset of $\mathbb{R}^d$.
Let $i$ be any element of $\{1, \dots, n\}$. We assume that $x_i \in \mathbb{R}^d$ has been sampled from a random variable $X_i$, which follows a probability distribution with a certain density. The density of $X_i$ at any $x \in \mathbb{R}^d$ is written as follows:
$$p(X_i = x).$$
In addition, we assume that $x_i$ is labeled with a certain $z^{(\mathrm{true})}_i \in \{1, \dots, K\}$, where $K$ is a fixed integer. Those labels exist (and are fixed), but we only observe $x_i$, without explicit knowledge of the corresponding label $z^{(\mathrm{true})}_i$. The underlying random variable modeling the label is noted $Z_i$, and the probability of being labeled $k \in \{1, \dots, K\}$ is written as follows:
$$P(Z_i = k).$$
We say that $Z_i$ is a latent variable. Using the law of total probability, we can reveal the latent variable (in the formula, $x \in \mathbb{R}^d$):
$$p(X_i = x) = \sum_{k=1}^{K} p(X_i = x \mid Z_i = k) \times P(Z_i = k).$$
GMM assumes that three hypotheses hold:
- The couples $(X_i, Z_i)_i$ are independent over $i$,
- Each record belongs to cluster $Z_i = k$ with probability $\pi_k$ (with $\pi_k > 0$),
- Each conditional variable $(X_i \mid Z_i = k)$ follows a Gaussian distribution with mean $m_k$ and covariance matrix $\Sigma_k$.
We let $f_{(m, \Sigma)}$ denote the density function of a Gaussian with parameters $m$ and $\Sigma$ on $\mathbb{R}^d$. Using hypotheses 2 and 3, the last equation is rewritten as follows (for all $i$, $x$):
$$p(X_i = x) = \sum_{k=1}^{K} f_{(m_k, \Sigma_k)}(x) \times \pi_k.$$
The unknown (fixed) parameters of the model are grouped together into:
$$\theta^{(\mathrm{true})} := (\pi^{(\mathrm{true})}_k, m^{(\mathrm{true})}_k, \Sigma^{(\mathrm{true})}_k)_{k \in \{1, \dots, K\}}.$$
The chosen strategy to estimate $\theta^{(\mathrm{true})}$ is to maximize the log-likelihood of the observed data $x := (x_1, \dots, x_n)$, defined from the probability density of observing $x$ given $\theta$:
$$\log L(\theta; x) := \log p_{\theta}((X_1, \dots, X_n) = (x_1, \dots, x_n)).$$
Using the three hypotheses of GMM, we obtain:
$$\begin{aligned}
\log L(\theta; x) &= \log \prod_{i=1}^{n} p_{\theta}(X_i = x_i) = \sum_{i=1}^{n} \log p_{\theta}(X_i = x_i) \\
&= \sum_{i=1}^{n} \log \left[ \sum_{k=1}^{K} p_{\theta}(X_i = x_i \mid Z_i = k)\, P_{\theta}(Z_i = k) \right] \\
&= \sum_{i=1}^{n} \log \left[ \sum_{k=1}^{K} f_{(m_k, \Sigma_k)}(x_i) \times \pi_k \right].
\end{aligned}$$
However, this log-likelihood function is non-convex (as a function of $\theta$) and direct optimization is intractable (see this post for a discussion). We introduce EM to circumvent this problem (other methods could work, see this post for a discussion).
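As a concrete illustration, here is a minimal sketch of how this log-likelihood can be evaluated numerically, assuming NumPy and SciPy are available and that the parameters are stored as arrays `pis` (shape `(K,)`), `means` (shape `(K, d)`) and `covs` (shape `(K, d, d)`); these names are illustrative, not part of the derivation:

```python
import numpy as np
from scipy.stats import multivariate_normal

def gmm_log_likelihood(x, pis, means, covs):
    """Log-likelihood of the dataset x (shape (n, d)) under a GMM."""
    K = len(pis)
    # weighted[i, k] = f_{(m_k, Sigma_k)}(x_i) * pi_k
    weighted = np.column_stack([
        pis[k] * multivariate_normal.pdf(x, mean=means[k], cov=covs[k])
        for k in range(K)
    ])
    # log L(theta; x) = sum_i log sum_k weighted[i, k]
    return np.sum(np.log(weighted.sum(axis=1)))
```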
Applying EM to GMM
We assume that some parameters $\theta^{(t)}$ have been selected (for a certain $t \geq 0$). We would like to update the parameters and obtain $\theta^{(t+1)}$ using the EM algorithm.
We define for all θ:
$$Q(\theta \mid \theta^{(t)}) := \sum_{z} \log p_{\theta}(x, z)\, P_{\theta^{(t)}}(z \mid x).$$
The aim of EM is to maximize the function $Q$ in $\theta$. Please read the section “The EM algorithm” of the EM post to understand why we have selected this function. From the last paragraph of the EM post, we also have:
$$Q(\theta \mid \theta^{(t)}) = \sum_{z} \left[ \log p_{\theta}(x \mid z) + \log p_{\theta}(z) \right] \frac{p_{\theta^{(t)}}(z, x)}{\sum_{z'} p_{\theta^{(t)}}(z', x)}.$$
We recall how the parameters decompose into 3 terms for GMM:
$$\theta := (\pi_k, m_k, \Sigma_k)_{k \in \{1, \dots, K\}}.$$
In the following equalities, we use hypothesis 1 and then hypotheses 2 and 3 of GMM:
$$\begin{aligned}
Q(\theta \mid \theta^{(t)}) &= \sum_{z} \left[ \log p_{\theta}(x \mid z) + \log p_{\theta}(z) \right] \frac{p_{\theta^{(t)}}(z, x)}{\sum_{z'} p_{\theta^{(t)}}(z', x)} \\
&= \sum_{i=1}^{n} \sum_{z_i=1}^{K} \left[ \log p_{\theta}(x_i \mid z_i) + \log p_{\theta}(z_i) \right] \frac{p_{\theta^{(t)}}(z_i, x_i)}{\sum_{z'_i} p_{\theta^{(t)}}(z'_i, x_i)} \\
&= \sum_{i=1}^{n} \sum_{k=1}^{K} \left[ \log p_{\theta}(x_i \mid k) + \log p_{\theta}(k) \right] \frac{p_{\theta^{(t)}}(k, x_i)}{\sum_{k'=1}^{K} p_{\theta^{(t)}}(k', x_i)} \\
&= \sum_{i=1}^{n} \sum_{k=1}^{K} \left[ \log f_{(m_k, \Sigma_k)}(x_i) + \log \pi_k \right] \frac{f_{(m^{(t)}_k, \Sigma^{(t)}_k)}(x_i)\, \pi^{(t)}_k}{\sum_{k'=1}^{K} f_{(m^{(t)}_{k'}, \Sigma^{(t)}_{k'})}(x_i)\, \pi^{(t)}_{k'}}.
\end{aligned}$$
We define:
$$T^{(t)}_{k,i} := P_{\theta^{(t)}}(Z_i = k \mid X_i = x_i) = \frac{f_{(m^{(t)}_k, \Sigma^{(t)}_k)}(x_i)\, \pi^{(t)}_k}{\sum_{k'=1}^{K} f_{(m^{(t)}_{k'}, \Sigma^{(t)}_{k'})}(x_i)\, \pi^{(t)}_{k'}}.$$
We use the explicit formula for the Gaussian density (for all $x \in \mathbb{R}^d$):
$$f_{(m, \Sigma)}(x) = \frac{1}{(2\pi)^{d/2} \sqrt{\det \Sigma}} \exp\left( -\frac{1}{2} (x - m)^T \Sigma^{-1} (x - m) \right)$$
and obtain:
$$\begin{aligned}
Q(\theta \mid \theta^{(t)}) &= \sum_{i=1}^{n} \sum_{k=1}^{K} \left[ \log f_{(m_k, \Sigma_k)}(x_i) + \log \pi_k \right] T^{(t)}_{k,i} \\
&= \sum_{i=1}^{n} \sum_{k=1}^{K} \left[ -\frac{d}{2} \log 2\pi - \frac{1}{2} \log \det \Sigma_k - \frac{1}{2} (x_i - m_k)^T \Sigma_k^{-1} (x_i - m_k) + \log \pi_k \right] T^{(t)}_{k,i}.
\end{aligned}$$
From this expression, we can separate the maximization over each couple $(m_k, \Sigma_k)$ (for $k \in \{1, \dots, K\}$) from the maximization over the set $(\pi_k)_k$.
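In code, the responsibilities $T^{(t)}_{k,i}$ of the E-step can be computed with a sketch like this one (same hypothetical array layout as in the previous snippet):

```python
import numpy as np
from scipy.stats import multivariate_normal

def responsibilities(x, pis, means, covs):
    """E-step: T[k, i] = P(Z_i = k | X_i = x_i) under the current parameters."""
    K = len(pis)
    # weighted[k, i] = f_{(m_k, Sigma_k)}(x_i) * pi_k
    weighted = np.stack([
        pis[k] * multivariate_normal.pdf(x, mean=means[k], cov=covs[k])
        for k in range(K)
    ])
    # Normalize over k so that each column sums to 1
    return weighted / weighted.sum(axis=0, keepdims=True)
```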
For the mean $m_k$
From the previous expression, we can perform the maximization for each fixed $k$. Some terms have no dependence on $m_k \in \mathbb{R}^d$, so we only need to maximize:
$$A(m_k) := -\frac{1}{2} \sum_{i=1}^{n} \left[ (x_i - m_k)^T \Sigma_k^{-1} (x_i - m_k) \right] T^{(t)}_{k,i}.$$
We take the gradient with respect to $m_k$ (see formula (86) of the matrix cookbook and this post to remember how the gradient is computed in this case):
$$\nabla_{m_k} A(m_k) = -\frac{1}{2} \sum_{i=1}^{n} \left[ -2 \Sigma_k^{-1} (x_i - m_k) \right] T^{(t)}_{k,i} = \sum_{i=1}^{n} \Sigma_k^{-1} (x_i - m_k)\, T^{(t)}_{k,i}.$$
We have $\nabla_{m_k} A(m_k) = 0$ if and only if $\sum_{i=1}^{n} (x_i - m_k)\, T^{(t)}_{k,i} = 0$, from which we deduce:
$$m_k = \frac{\sum_{i=1}^{n} x_i\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}.$$
Furthermore, the Hessian matrix of $A(m_k)$ is given by:
$$-\left( \sum_{i=1}^{n} T^{(t)}_{k,i} \right) \Sigma_k^{-1},$$
which is negative-definite.
Conclusion: We select $m^{(t+1)}_k := \frac{\sum_{i=1}^{n} x_i\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}$ and $A(\cdot)$ is maximized at $m^{(t+1)}_k$.
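A minimal sketch of this update, assuming the responsibilities are stored in an array `T` of shape `(K, n)` as in the E-step snippet above:

```python
import numpy as np

def update_means(x, T):
    """M-step for the means: m_k = sum_i T[k, i] x_i / sum_i T[k, i].
    x has shape (n, d), T has shape (K, n); the result has shape (K, d)."""
    return (T @ x) / T.sum(axis=1, keepdims=True)
```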
For the variance-covariance matrix $\Sigma_k$
From the previous expression, we can also perform the maximization of $\Sigma_k$ for each fixed $k$.
First, it is easier to differentiate with respect to $\Sigma_k^{-1}$, so we let $\Lambda_k := \Sigma_k^{-1}$ and maximize:
$$B(\Lambda_k) := -\frac{1}{2} \sum_{i=1}^{n} T^{(t)}_{k,i} \log \det \Lambda_k^{-1} - \frac{1}{2} \sum_{i=1}^{n} \left[ (x_i - m^{(t+1)}_k)^T \Lambda_k (x_i - m^{(t+1)}_k) \right] T^{(t)}_{k,i}.$$
Then, we want to differentiate with respect to the matrix $\Lambda_k$ of shape $d \times d$. This is somewhat ambiguous: we can either see the matrix as a vector of length $d^2$, or as a vector of length $d(d+1)/2$ (because $\Lambda_k$ is symmetric and many coefficients are identical). Choosing one or the other way to differentiate changes the formula for $\nabla_{\Lambda_k} B(\Lambda_k)$, but both are valid and lead to the same maximizing matrix $\Lambda^{(t+1)}_k$.
Let’s do the simplest computations (seeing $\Lambda_k$ as a vector of length $d^2$).
Using formula (57) of the matrix cookbook and the fact that $\Lambda_k$ is symmetric and positive-definite:
$$\nabla_{\Lambda_k} \log \det \Lambda_k^{-1} = -\nabla_{\Lambda_k} \log \det \Lambda_k = -\Lambda_k^{-1}.$$
Using formula (70) of the matrix cookbook:
$$\nabla_{\Lambda_k} z^T \Lambda_k z = z z^T.$$
We select $z := x_i - m^{(t+1)}_k$ and obtain:
$$\nabla_{\Lambda_k} B(\Lambda_k) = -\frac{1}{2} \sum_{i=1}^{n} T^{(t)}_{k,i} \left( -\Lambda_k^{-1} \right) - \frac{1}{2} \sum_{i=1}^{n} \left[ (x_i - m^{(t+1)}_k)(x_i - m^{(t+1)}_k)^T \right] T^{(t)}_{k,i}.$$
We have $\nabla_{\Lambda_k} B(\Lambda_k) = 0$ if and only if:
$$\Lambda_k^{-1} \left( \sum_{i=1}^{n} T^{(t)}_{k,i} \right) = \sum_{i=1}^{n} (x_i - m^{(t+1)}_k)(x_i - m^{(t+1)}_k)^T\, T^{(t)}_{k,i},$$
and so:
$$\Sigma_k = \frac{\sum_{i=1}^{n} (x_i - m^{(t+1)}_k)(x_i - m^{(t+1)}_k)^T\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}.$$
This matrix is positive-definite.
Furthermore, the Hessian matrix of $B(\Lambda_k)$ is given by:
$$\nabla^2_{\Lambda_k} B(\Lambda_k) = -\frac{1}{2} \sum_{i=1}^{n} T^{(t)}_{k,i}\, \Lambda_k^{-2},$$
which is negative-definite.
Conclusion: We select $\Sigma^{(t+1)}_k := \frac{\sum_{i=1}^{n} (x_i - m^{(t+1)}_k)(x_i - m^{(t+1)}_k)^T\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}$ and $B(\cdot)$ is maximized at $\Sigma^{(t+1)}_k$.
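A sketch of the corresponding update, with the same hypothetical array layout as before (`T` of shape `(K, n)`, `new_means` of shape `(K, d)`):

```python
import numpy as np

def update_covariances(x, T, new_means):
    """M-step for the covariances:
    Sigma_k = sum_i T[k, i] (x_i - m_k)(x_i - m_k)^T / sum_i T[k, i]."""
    K, _ = T.shape
    d = x.shape[1]
    covs = np.empty((K, d, d))
    for k in range(K):
        centered = x - new_means[k]                                # shape (n, d)
        covs[k] = (T[k][:, None] * centered).T @ centered / T[k].sum()
    return covs
```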
For the probabilities $(\pi_k)_k$
The probabilities $(\pi_k)_k$ are considered all together since there is a constraint: the sum of the $\pi_k$ over $k$ must be 1. We remove this constraint using $\pi_K = 1 - (\pi_1 + \dots + \pi_{K-1})$. From the previous expression, we only need to maximize:
$$C\big((\pi_k)_{k \in \{1, \dots, K-1\}}\big) := \sum_{i=1}^{n} \sum_{k=1}^{K} T^{(t)}_{k,i} \log \pi_k.$$
We let $S_k := \sum_{i=1}^{n} T^{(t)}_{k,i}$ and rewrite:
$$C\big((\pi_k)_{k \in \{1, \dots, K-1\}}\big) = \sum_{k=1}^{K-1} S_k \log \pi_k + S_K \log\big(1 - (\pi_1 + \dots + \pi_{K-1})\big).$$
For all $k \in \{1, \dots, K-1\}$:
$$\nabla_{\pi_k} C\big((\pi_k)_{k \in \{1, \dots, K-1\}}\big) = \frac{S_k}{\pi_k} - \frac{S_K}{\pi_K},$$
and $\nabla_{\pi_k} C\big((\pi_k)_{k \in \{1, \dots, K-1\}}\big) = 0$ if and only if $\pi_k = \frac{S_k}{S_K} \pi_K$.
Summing over all $k \in \{1, \dots, K-1\}$, we obtain:
$$1 - \pi_K = \sum_{k=1}^{K-1} \frac{S_k}{S_K} \pi_K, \quad \text{and so:} \quad \pi_K = \frac{S_K}{\sum_{k=1}^{K} S_k}.$$
It follows, for all $k \in \{1, \dots, K\}$:
$$\pi_k = \frac{S_k}{\sum_{k=1}^{K} S_k} = \frac{\sum_{i=1}^{n} T^{(t)}_{k,i}}{\sum_{k'=1}^{K} \sum_{i=1}^{n} T^{(t)}_{k',i}}.$$
Furthermore, the Hessian matrix of $C\big((\pi_k)_{k \in \{1, \dots, K-1\}}\big)$ has entries
$$\frac{\partial^2 C}{\partial \pi_k \partial \pi_{k'}} = -\delta_{k,k'} \frac{S_k}{\pi_k^2} - \frac{S_K}{\pi_K^2},$$
which form a negative-definite matrix.
Conclusion: We select $\pi^{(t+1)}_k := \frac{\sum_{i=1}^{n} T^{(t)}_{k,i}}{\sum_{k'=1}^{K} \sum_{i=1}^{n} T^{(t)}_{k',i}}$ for all $k$ and $C(\cdot)$ is maximized at $(\pi^{(t+1)}_k)_k$.
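The corresponding sketch, with the same hypothetical `T` array of shape `(K, n)` as before (the denominator is simply $n$, since each column of `T` sums to 1):

```python
import numpy as np

def update_proportions(T):
    """M-step for the proportions: pi_k = sum_i T[k, i] / sum_{k', i} T[k', i]."""
    return T.sum(axis=1) / T.sum()
```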
Algorithm to cluster data
Let $x = (x_i)_{i \in \{1, \dots, n\}}$ be a dataset of $\mathbb{R}^d$ and $K$ an integer.
Step 0
We define the initial parameters as follows (a code sketch is given after the list). For $k \in \{1, \dots, K\}$:
- $\pi^{(0)}_k = 1/K$,
- $\Sigma^{(0)}_k$ the identity matrix of size $d \times d$, and
- $(m^{(0)}_k)_{k \in \{1, \dots, K\}}$ some initial positions obtained with K-means.
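A minimal initialization sketch, assuming scikit-learn is available for the K-means step (function and variable names are illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans

def init_parameters(x, K, random_state=0):
    """Step 0: uniform proportions, identity covariances, K-means centroids as means."""
    n, d = x.shape
    pis = np.full(K, 1.0 / K)                # pi_k^(0) = 1/K
    covs = np.tile(np.eye(d), (K, 1, 1))     # Sigma_k^(0) = identity of size d x d
    means = KMeans(n_clusters=K, n_init=10,
                   random_state=random_state).fit(x).cluster_centers_
    return pis, means, covs
```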
Step t to t+1
Let $f_{(m, \Sigma)}$ be the density function of a Gaussian with parameters $m$ and $\Sigma$ on $\mathbb{R}^d$.
Define, for all $k$, $i$:
$$T^{(t)}_{k,i} := \frac{f_{(m^{(t)}_k, \Sigma^{(t)}_k)}(x_i)\, \pi^{(t)}_k}{\sum_{k'=1}^{K} f_{(m^{(t)}_{k'}, \Sigma^{(t)}_{k'})}(x_i)\, \pi^{(t)}_{k'}}.$$
Then define, for all $k$:
$$m^{(t+1)}_k := \frac{\sum_{i=1}^{n} x_i\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}, \qquad
\Sigma^{(t+1)}_k := \frac{\sum_{i=1}^{n} (x_i - m^{(t+1)}_k)(x_i - m^{(t+1)}_k)^T\, T^{(t)}_{k,i}}{\sum_{i=1}^{n} T^{(t)}_{k,i}}, \qquad
\pi^{(t+1)}_k := \frac{\sum_{i=1}^{n} T^{(t)}_{k,i}}{\sum_{k'=1}^{K} \sum_{i=1}^{n} T^{(t)}_{k',i}}.$$
We repeat this step until convergence (see this article for theoretical results on convergence). In general, convergence to a local maximum is not a problem, but it is possible to build some pathological cases.
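Putting the pieces together, one EM iteration and a simple stopping criterion on the log-likelihood can be sketched as follows (reusing the hypothetical helper functions from the previous snippets; the tolerance and the maximum number of iterations are arbitrary choices):

```python
import numpy as np

def fit_gmm(x, K, max_iter=200, tol=1e-6):
    """Run EM until the log-likelihood improvement falls below tol."""
    pis, means, covs = init_parameters(x, K)               # Step 0
    previous = -np.inf
    for _ in range(max_iter):
        T = responsibilities(x, pis, means, covs)          # E-step
        means = update_means(x, T)                         # M-step
        covs = update_covariances(x, T, means)
        pis = update_proportions(T)
        current = gmm_log_likelihood(x, pis, means, covs)
        if current - previous < tol:                       # convergence check
            break
        previous = current
    return pis, means, covs
```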
Clustering
Given the estimated parameters $\theta^{(\infty)} = (m^{(\infty)}_k, \Sigma^{(\infty)}_k, \pi^{(\infty)}_k)_k$, we compute the density for $x_i$ to be labeled $k$ (for all $x_i$, $k$):
$$p_{\theta^{(\infty)}}(X_i = x_i, Z_i = k) = f_{(m^{(\infty)}_k, \Sigma^{(\infty)}_k)}(x_i) \times \pi^{(\infty)}_k.$$
The hard label for $x_i$ is estimated by taking $\operatorname{argmax}_k\, p_{\theta^{(\infty)}}(X_i = x_i, Z_i = k)$.
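A sketch of this hard assignment, using the same conventions as above:

```python
import numpy as np
from scipy.stats import multivariate_normal

def hard_labels(x, pis, means, covs):
    """Assign each x_i to argmax_k f_{(m_k, Sigma_k)}(x_i) * pi_k."""
    K = len(pis)
    joint = np.stack([
        pis[k] * multivariate_normal.pdf(x, mean=means[k], cov=covs[k])
        for k in range(K)
    ])                                  # shape (K, n)
    return joint.argmax(axis=0)
```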
Illustration of the clustering process
We propose to cluster a two-dimensional dataset into 3 clusters using GMM. The dataset is plotted in Fig. 2 (a). We initialize the parameters $\theta^{(0)}$ with K-means (the related clustering is shown in Fig. 2 (b)). We update the parameters (see Fig. 2 (c) for the clustering related to $\theta^{(1)}$) until convergence to $\theta^{(\infty)}$ (see Fig. 2 (d)).
Fig. 2. Clustering dataset into 3 clusters using GMM. From left to right: (a) Dataset before clustering; (b) Initialization with K-means; (c) Step 1; (d) GMM clustering after convergence. On each figure from (b) to (d), one color represents one cluster, and mean position of each cluster is represented with a cross.
We summarize the evolution of the parameters across the steps. Cluster 1 is the green one on the left, cluster 2 is the orange one on the top, and cluster 3 is the purple one on the right.
In this example, the mean positions of the clusters do not move much between the K-means and GMM clusterings:
| $t$ | $m^{(t)}_1$ | $m^{(t)}_2$ | $m^{(t)}_3$ |
|---|---|---|---|
| $0$ | $[-0.95, -2.94]$ | $[1.65, 2.93]$ | $[2.97, -2.03]$ |
| $1$ | $[-0.94, -2.93]$ | $[1.66, 2.93]$ | $[2.96, -2.03]$ |
| $\infty$ | $[-0.88, -2.01]$ | $[1.94, 3.00]$ | $[2.98, -2.03]$ |
GMM successfully accounts for uneven variance in each cluster. For example, the variance along the second axis of cluster 1 has increased a lot (from 1 to 8.28), contrary to the second axis of cluster 2 (from 1 to 0.04):
| $t$ | $\Sigma^{(t)}_1$ | $\Sigma^{(t)}_2$ | $\Sigma^{(t)}_3$ |
|---|---|---|---|
| $0$ | $\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$ | $\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$ | $\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$ |
| $1$ | $\begin{pmatrix} 0.91 & -0.14 \\ -0.14 & 4.90 \end{pmatrix}$ | $\begin{pmatrix} 1.64 & 0.13 \\ 0.13 & 0.32 \end{pmatrix}$ | $\begin{pmatrix} 0.30 & -0.01 \\ -0.01 & 0.18 \end{pmatrix}$ |
| $\infty$ | $\begin{pmatrix} 0.97 & -0.09 \\ -0.09 & 8.28 \end{pmatrix}$ | $\begin{pmatrix} 0.93 & -0.02 \\ -0.02 & 0.04 \end{pmatrix}$ | $\begin{pmatrix} 0.26 & -0.01 \\ -0.01 & 0.16 \end{pmatrix}$ |
GMM successfully accounts for uneven density in each cluster. For example, the estimated proportion of elements in cluster 3 has increased from $1/K$ to 0.55:
| $t$ | $\pi^{(t)}_1$ | $\pi^{(t)}_2$ | $\pi^{(t)}_3$ |
|---|---|---|---|
| $0$ | 0.33 | 0.33 | 0.33 |
| $1$ | 0.15 | 0.30 | 0.55 |
| $\infty$ | 0.18 | 0.27 | 0.55 |
The log-likelihood of the dataset has increased from −4144.924 to −2966.941 after convergence. In this case, the EM algorithm has reached the maximum likelihood estimate.
The evolution of the log-likelihood with the number of steps is shown in Fig. 3.
Fig. 3. Evolution of the log-likelihood with the number of steps until convergence of the EM algorithm.
Note 1. The dataset has been simulated as a mixture of Gaussians, where the true means of the clusters are $[-1, -2]$, $[2, 3]$, $[3, -2]$; the true variance-covariance matrices are $\begin{pmatrix} 1 & 0 \\ 0 & 9 \end{pmatrix}$, $\begin{pmatrix} 1 & 0 \\ 0 & 0.04 \end{pmatrix}$, $\begin{pmatrix} 0.25 & 0 \\ 0 & 0.16 \end{pmatrix}$; and the true proportions are 0.18, 0.27 and 0.55. The log-likelihood of the dataset using the true parameters is −2979.822.
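For readers who want to reproduce a similar experiment, a dataset of this kind can be simulated with the sketch below (the sample size and random seed are arbitrary choices, not the ones used for the figures):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
true_pis = np.array([0.18, 0.27, 0.55])
true_means = np.array([[-1.0, -2.0], [2.0, 3.0], [3.0, -2.0]])
true_covs = np.array([[[1.0, 0.0], [0.0, 9.0]],
                      [[1.0, 0.0], [0.0, 0.04]],
                      [[0.25, 0.0], [0.0, 0.16]]])

# Draw a latent label for each point, then sample from the matching Gaussian
labels = rng.choice(3, size=n, p=true_pis)
x = np.stack([rng.multivariate_normal(true_means[z], true_covs[z]) for z in labels])
```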
Note 2. Since we are using K-means for initialization, it may be useful to normalize the data before using GMM clustering.
References
- The Matrix Cookbook. I’ve just discovered it, and it is really useful for reference,
- English Wikipedia about EM. Wikipedia gives concise formulas.