
Nathaniel Thomas

Interactive Gaussian Mixture Models

December 6, 2024

Goal

Suppose we have a dataset of features, but no labels. If we know (or guess) that there are $K$ classes in the dataset, we could model the dataset as the weighted average of $K$ class-conditional Gaussians. This is what Gaussian Mixture Models do.

We assume that the model is parameterized by $\theta = \{\pi_k, \mu_k, \sigma_k^2\}_{k=1}^{K}$, where $\pi_k$ determines the weight of the $k$th Gaussian in the model.

Since our dataset $D = \{x_i\}_{i=1}^{N}$ is i.i.d., its log-likelihood is

$$
\mathcal{L}(\theta) = \sum_{i=1}^{N} \log p(x_i) = \sum_{i=1}^{N} \log\left(\sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_i \mid \mu_k, \sigma_k^2)\right)
$$
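
As a rough sketch of this quantity in code (using the gaussianPDF helper defined in the Code section below; logLikelihood is just an illustrative name, not part of the demo):

function logLikelihood(data, weights, means, variances) {
  let logLik = 0;
  for (const x of data) {
    // p(x) is the weighted sum of the K component densities
    let p = 0;
    for (let k = 0; k < weights.length; k++) {
      p += weights[k] * gaussianPDF(x, means[k], variances[k]);
    }
    logLik += Math.log(p);
  }
  return logLik;
}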

Expectation Maximization

To find the parameters $\theta$ that maximize the likelihood of our data, we will use the following process:

  1. Compute initial guesses for $\theta = \{\pi_k, \mu_k, \sigma_k^2\}_{k=1}^{K}$

  2. Compute the probability of $x_i$ belonging to class $k$. We denote this $r_{ik}$, the responsibility of $x_i$ for class $k$:

$$
r_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \sigma_k^2)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2)}
$$
  3. We update
    • weights $\pi_k$ to be the average responsibility for Gaussian $k$
    • means $\mu_k$ to be the average of the datapoints, weighted by $r_{ik}$ for all $i$
    • variances $\sigma_k^2$ to be the variance of the datapoints from the new $\mu_k$, similarly weighted by $r_{ik}$
$$
\begin{aligned}
\pi_k &= \frac{\sum_{i=1}^{N} r_{ik}}{N} \\
\mu_k &= \frac{\sum_{i=1}^{N} r_{ik}\, x_i}{\sum_{i=1}^{N} r_{ik}} \\
\sigma_k^2 &= \frac{\sum_{i=1}^{N} r_{ik}\, (x_i - \mu_k)^2}{\sum_{i=1}^{N} r_{ik}}
\end{aligned}
$$

Note the similarity between this process and kernel regression! In this case, the kernel function is $r_{ik}$, which defines a neighborhood of features $x_i$ that likely belong to class $k$.
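
For comparison (writing $\kappa(x, x_i)$ for a generic kernel and $y_i$ for the regression targets), the Nadaraya-Watson kernel regression estimate is a kernel-weighted average of the targets,

$$
\hat{y}(x) = \frac{\sum_{i=1}^{N} \kappa(x, x_i)\, y_i}{\sum_{i=1}^{N} \kappa(x, x_i)},
$$

which has the same form as the $\mu_k$ update above, with $r_{ik}$ playing the role of the kernel weights.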

Steps 2 and 3 are repeated until the parameters converge (the implementation below checks convergence via the change in log-likelihood).

Interactive Demo

Below is an interactive demo of the EM algorithm. The data is generated from $K$ Gaussians whose means, variances, and weights are randomly selected. Then, a GMM with $K$ components is fit to the data.

You can repeatedly hit Start. Desktop recommended.

[Interactive demo: a plot of the data and the fitted mixture, plus a table comparing the true and estimated $(\mu_k, \sigma_k^2)$ and $\pi_k$ for each component, updated at each iteration.]
Code

Here is the core algorithm code in JavaScript. This won’t directly reproduce the plot above since I omitted the plotting code and HTML/CSS. You can use Inspect Element to see the whole thing.

// Sample from a standard normal via the Box-Muller transform
function randn() {
  let u = 0, v = 0;
  while(u === 0) u = Math.random();
  while(v === 0) v = Math.random();
  return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
}

// Density of a univariate Gaussian with the given mean and variance
function gaussianPDF(x, mean, variance) {
  const std = Math.sqrt(variance);
  const coeff = 1.0 / (std * Math.sqrt(2 * Math.PI));
  const exponent = -0.5 * Math.pow((x - mean)/std, 2);
  return coeff * Math.exp(exponent);
}

// Draw C sorted means in roughly [-5, 5], nudged to be at least 0.5 apart
function generateSeparatedMeans(C) {
  let candidate = [];
  for (let i = 0; i < C; i++) {
    candidate.push(Math.random());
  }
  candidate.sort((a,b) => a - b);
  let means = candidate.map(x => -5 + x*10);
  for (let i = 1; i < C; i++) {
    if (means[i] - means[i-1] < 0.5) {
      means[i] = means[i-1] + 0.5;
    }
  }
  return means;
}

// Sample N points from a C-component mixture with equal weights and random variances
function generateData(C, N=1000) {
  let means = generateSeparatedMeans(C);
  let variances = [];
  let weights = [];
  for (let i = 0; i < C; i++) {
    variances.push(0.5 + 1.5*Math.random());
    weights.push(1.0/C);
  }

  let data = [];
  for (let i = 0; i < N; i++) {
    const comp = Math.floor(Math.random() * C);
    const x = means[comp] + Math.sqrt(variances[comp])*randn();
    data.push(x);
  }
  return {data, means, variances, weights};
}

// Initialize means at random data points, with unit variances and uniform weights
function decentInitialGuess(C, data) {
  const N = data.length;
  let means = [];
  let variances = [];
  let weights = [];
  for (let c = 0; c < C; c++) {
    means.push(data[Math.floor(Math.random()*N)]);
    variances.push(1.0);
    weights.push(1.0/C);
  }
  return {means, variances, weights};
}

// Fit a C-component 1-D GMM with EM; returns the parameter estimates from each iteration
function emGMM(data, C, maxIter=100, tol=1e-4) {
  const N = data.length;
  let init = decentInitialGuess(C, data);
  let means = init.means.slice();
  let variances = init.variances.slice();
  let weights = init.weights.slice();

  let logLikOld = -Infinity;
  let paramsHistory = [];

  for (let iter = 0; iter < maxIter; iter++) {
    // E-step: compute the responsibility of each component c for each point i
    let resp = new Array(N).fill(0).map(() => new Array(C).fill(0));
    for (let i = 0; i < N; i++) {
      let total = 0;
      for (let c = 0; c < C; c++) {
        const val = weights[c]*gaussianPDF(data[i], means[c], variances[c]);
        resp[i][c] = val;
        total += val;
      }
      for (let c = 0; c < C; c++) {
        resp[i][c] /= (total + 1e-15);
      }
    }

    // M-step: update weights, means, and variances from the responsibilities
    for (let c = 0; c < C; c++) {
      let sumResp = 0;
      let sumMean = 0;
      let sumVar = 0;
      for (let i = 0; i < N; i++) {
        sumResp += resp[i][c];
        sumMean += resp[i][c]*data[i];
      }
      const newMean = sumMean / (sumResp + 1e-15);

      for (let i = 0; i < N; i++) {
        let diff = data[i] - newMean;
        sumVar += resp[i][c]*diff*diff;
      }
      const newVar = sumVar/(sumResp + 1e-15);

      means[c] = newMean;
      variances[c] = Math.max(newVar, 1e-6);
      weights[c] = sumResp/N;
    }

    // Log-likelihood under the current parameters, used for the convergence check
    let logLik = 0;
    for (let i = 0; i < N; i++) {
      let p = 0;
      for (let c = 0; c < C; c++) {
        p += weights[c]*gaussianPDF(data[i], means[c], variances[c]);
      }
      logLik += Math.log(p + 1e-15);
    }

    paramsHistory.push({
      means: means.slice(),
      variances: variances.slice(),
      weights: weights.slice()
    });

    if (Math.abs(logLik - logLikOld) < tol) {
      break;
    }
    logLikOld = logLik;
  }
  return paramsHistory;
}
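
A minimal usage sketch of the functions above (the variable names here are mine and only illustrative):

// Generate 1000 points from K = 3 random Gaussians, then fit a 3-component GMM
const K = 3;
const {data, means: trueMeans, variances: trueVariances} = generateData(K);
const history = emGMM(data, K);

// The last entry of the history holds the final estimates
const fit = history[history.length - 1];
console.log('True means:', trueMeans, 'estimated means:', fit.means);
console.log('True variances:', trueVariances, 'estimated variances:', fit.variances);
console.log('Estimated weights:', fit.weights);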
