• 目标
  • 期望最大化
  • 交互式演示
  • 首页
  • 文章
  • 笔记
  • 书架
  • 作者
🇺🇸 en 🇫🇷 fr 🇮🇳 ml

Nathaniel Thomas

交互式高斯混合模型

2024年12月6日

目标

假设我们有一个特征数据集,但没有标签。如果我们知道(或猜测)数据集中有 K 个类别,我们可以将数据集建模为 K 个类别条件高斯分布的加权平均。这就是高斯混合模型所做的。

我们假设模型由 θ={πk​,μk​,σk2​}k=1K​ 参数化,其中 πk​ 决定了模型中第 k 个高斯的权重。

由于我们的数据集 D={xi​}i=1N​ 是独立同分布的,其对数似然为

L(θ)​=i=1∑N​log(p(xi​))=i=1∑N​log(k=1∑K​N(xi​∣μk​,σk2​)πk​)​

期望最大化

为了找到使数据似然最大化的 μk​,σk2​,我们将使用以下过程:

  1. 计算 θ={πk​,μk​,σk2​}k=1K​ 的初始猜测值

  2. 计算 xi​ 属于类别 k 的似然。我们将其记为 rik​ 或 xi​ 对 k 的责任

rik​=∑j=1K​πj​⋅N(xi​∣μj​,σj2​)πk​⋅N(xi​∣μk​,σk2​)​
  1. 我们更新
    • 权重 πk​ 为高斯分布 k 的平均责任
    • 均值 μk​ 为数据点的加权平均,权重为所有 i 的 rik​
    • 方差 σk2​ 为数据点相对于新 μk​ 的加权方差,权重同样为 rik​
πk​μk​σk2​​=N∑i=1N​rik​​=∑i=1N​rik​∑i=1N​rik​⋅xi​​=∑i=1N​rik​∑i=1N​rik​⋅(xi​−μk​)2​​

注意这个过程与核回归的相似之处! 在这种情况下,核函数是 rik​,它定义了可能属于类别 k 的特征 xi​ 的邻域。

步骤 2 和 3 重复进行,直到权重收敛。

交互式演示

以下是EM算法的交互式演示。数据由 K个高斯分布生成,其均值、方差和权重是随机选择的。然后,使用 K个高斯分布的GM模型对数据进行拟合。

你可以反复点击开始。建议在桌面端使用。

Iteration: 0
$k$ True $(\mu_k, \sigma_k^2)$ Est $(\hat \mu_k, \hat \sigma_k^2)$ True $\pi_k$ Est $\hat{\pi}_k$
代码

以下是JavaScript中的核心算法代码。由于我省略了绘图代码和HTML/CSS,这段代码不会直接重现上面的图表。你可以使用检查元素查看完整内容。

function randn() {
  let u = 0, v = 0;
  while(u === 0) u = Math.random();
  while(v === 0) v = Math.random();
  return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
}

function gaussianPDF(x, mean, variance) {
  const std = Math.sqrt(variance);
  const coeff = 1.0 / (std * Math.sqrt(2 * Math.PI));
  const exponent = -0.5 * Math.pow((x - mean)/std, 2);
  return coeff * Math.exp(exponent);
}

function generateSeparatedMeans(C) {
  let candidate = [];
  for (let i = 0; i < C; i++) {
    candidate.push(Math.random());
  }
  candidate.sort((a,b) => a - b);
  let means = candidate.map(x => -5 + x*10);
  for (let i = 1; i < C; i++) {
    if (means[i] - means[i-1] < 0.5) {
      means[i] = means[i-1] + 0.5;
    }
  }
  return means;
}

function generateData(C, N=1000) {
  let means = generateSeparatedMeans(C);
  let variances = [];
  let weights = [];
  for (let i = 0; i < C; i++) {
    variances.push(0.5 + 1.5*Math.random());
    weights.push(1.0/C);
  }

  let data = [];
  for (let i = 0; i < N; i++) {
    const comp = Math.floor(Math.random() * C);
    const x = means[comp] + Math.sqrt(variances[comp])*randn();
    data.push(x);
  }
  return {data, means, variances, weights};
}

function decentInitialGuess(C, data) {
  const N = data.length;
  let means = [];
  let variances = [];
  let weights = [];
  for (let c = 0; c < C; c++) {
    means.push(data[Math.floor(Math.random()*N)]);
    variances.push(1.0);
    weights.push(1.0/C);
  }
  return {means, variances, weights};
}

function emGMM(data, C, maxIter=100, tol=1e-4) {
  const N = data.length;
  let init = decentInitialGuess(C, data);
  let means = init.means.slice();
  let variances = init.variances.slice();
  let weights = init.weights.slice();

  let logLikOld = -Infinity;
  let paramsHistory = [];

  for (let iter = 0; iter < maxIter; iter++) {
    let resp = new Array(N).fill(0).map(() => new Array(C).fill(0));
    for (let i = 0; i < N; i++) {
      let total = 0;
      for (let c = 0; c < C; c++) {
        const val = weights[c]*gaussianPDF(data[i], means[c], variances[c]);
        resp[i][c] = val;
        total += val;
      }
      for (let c = 0; c < C; c++) {
        resp[i][c] /= (total + 1e-15);
      }
    }

    for (let c = 0; c < C; c++) {
      let sumResp = 0;
      let sumMean = 0;
      let sumVar = 0;
      for (let i = 0; i < N; i++) {
        sumResp += resp[i][c];
        sumMean += resp[i][c]*data[i];
      }
      const newMean = sumMean / (sumResp + 1e-15);

      for (let i = 0; i < N; i++) {
        let diff = data[i] - newMean;
        sumVar += resp[i][c]*diff*diff;
      }
      const newVar = sumVar/(sumResp + 1e-15);

      means[c] = newMean;
      variances[c] = Math.max(newVar, 1e-6);
      weights[c] = sumResp/N;
    }

    let logLik = 0;
    for (let i = 0; i < N; i++) {
      let p = 0;
      for (let c = 0; c < C; c++) {
        p += weights[c]*gaussianPDF(data[i], means[c], variances[c]);
      }
      logLik += Math.log(p + 1e-15);
    }

    paramsHistory.push({
      means: means.slice(),
      variances: variances.slice(),
      weights: weights.slice()
    });

    if (Math.abs(logLik - logLikOld) < tol) {
      break;
    }
    logLikOld = logLik;
  }
  return paramsHistory;
}

←
局部近似
熵的第一性原理推导
→

back to top