CONTENTS

交互式MNIST数字识别器

在画布上绘制数字,观察AI如何猜测它是什么!

Try drawing a digit on the canvas!

1v1 Least Squares (n/a ms)

Fully Connected Network (n/a ms)
Convolutional Network (n/a ms)

本文探讨了三种以不同方式处理MNIST数据集的基础模型,并比较了它们的特性、优势与不足。您可以通过上方交互界面体验每个模型,并在条形图中查看它们的输出结果。

任务

对于不熟悉机器学习的人来说,将图像转换为数字可能看起来是一项艰巨的任务。然而,如果我们以如下方式思考这个问题,事情就会变得简单:

一张灰度图本质上就是一个像素亮度的网格,这些亮度值是实数。也就是说,每张图像都是集合 中的一个元素,其中 分别是图像的宽度和高度。因此,如果我们能找到一个从 映射到 的函数 ,我们就能解决这个问题。

为此,我们使用训练图像 和标签 来构建一个模型。

最小二乘法

该方法涉及为从类别 0..9 中选出的每一对不同的 ,创建 45 个从 的线性映射,用于推断图像最可能属于 还是 。我们可以使用线性代数来最小化均方误差(MSE)。首先,我们不直接处理 中的图像,而是将其“展平”到

的权重定义为 ,这是一个长度为 的向量。为了得到模型的输出,我们计算

其中 是一个数字。

我们希望在所有 个样本上最小化 MSE

为此,我们创建一个新矩阵 ,它仅包含属于类别 的图像以及一列用于偏置的 ,以及一个矩阵 ,它同样仅包含 中的标签,但将 替换为 ,将 替换为

现在,我们的问题被简化为

其解由 给出,其中 是矩阵 的伪逆(证明留给读者作为练习 😁)。

一旦我们获得了所有 对(共 45 对)的 ,我们就可以将我们期望的函数 表示为

def f(x):
    score = [0] * 10
    for i, j, f_ij in pair_functions:
        out_ij = f_ij(x)
        if out_ij > 0:
            score[i] += 1
            score[j] -= 1
        else:
            score[j] += 1
            score[i] -= 1
    return argmax(score)

45 个模型中的每一个都为其 “投票”。上面条形图中显示的就是 score 数组。

全连接网络

全连接网络(FCN)是一个比最小二乘模型大得多的模型。我们不再是将标签投影到数据的主子空间上,而是可以直接学习从输入空间到输出空间的映射。

对于单层网络,我们假设 可以近似为

其中 是某个非线性函数。我们可以通过梯度下降学习矩阵 ,使得在局部邻域内误差(分类交叉熵)最小化。在演示中,我们使用一个2层网络,先将图像映射到 ,再将结果映射到 。这表示为

其中我们需要学习矩阵 。在我们的例子中, ,并且

的输出转换为概率分布,如上方的条形图所示。

卷积网络

上述两种模型的一个局限在于,它们无法像人类那样感知视觉特征。例如,手写数字1无论出现在画布的哪个位置,它始终是 。然而,由于LS和FCN模型不具备空间或邻近关系的概念,它们只会简单地指向最可能拥有完全一致像素分布的类别。

为此,我们引入卷积运算。卷积操作接收一张图像和一个卷积核,将卷积核在图像上滑动计算,生成输出图像,该图像的每个像素值是输入图像局部区域与卷积核权重的加权和。

请注意卷积如何编码空间信息——这是普通全连接网络所不具备的。由于相邻像素通常高度相关,我们可以通过最大池化对卷积输出进行下采样,同时保留大部分信息。将图像通过一系列(训练得到的)卷积核处理后,我们得到一组矩阵,这些矩阵表示已学习的空间特征是否存在。最后,我们将这些矩阵展平并输入到FCN中,此时FCN便能够将空间数据映射到类别。

上方展示了该FCN(采用softmax激活函数)的输出结果。

模型对比

注:最后三列为定性指标,且为相对比较。

模型 参数量 训练时间 推理时间 准确度
最小二乘法
全连接网络 良好
卷积网络 非常高 优秀

观察结果:

  • 最小二乘法模型速度极快,但泛化能力较弱
  • 卷积网络的参数存储效率非常高
  • 相对于卷积网络的推理时间,最小二乘法和全连接网络都非常快

练习

观察模型对这些输入的响应:

  • 空白画布
  • 中心位置的一个 1
  • 最左侧的一个 1
  • 最右侧的一个 1
  • 中心带有一条线/点的 0
  • 顶部略有断开的 9
  • 轻微旋转的数字
  • 非常细的数字
  • 非常粗的数字

你能找到两个仅相差1个像素、但被分类到不同类别的输入吗?

实现细节

所有三个模型均在您的浏览器中通过纯JavaScript运行;未使用任何框架或软件包。

画布

这块 的画布由一个数字数组支持,数组中包含所显示的透明度值。每次任何像素被更新时,整个画布都会重新绘制。另一个值得注意的细节是我使用的亮度衰减函数:

const plateau = 0.3;
// dist 是到中心距离的平方
const alpha = Math.min(1 - dist / r2 + plateau, 1);
pixels[yc * 28 + xc] = Math.max(pixels[yc * 28 + xc], alpha);

我最初尝试使用 1-dist/r2 作为衰减函数,但它使中心区域褪色过多。因此我添加了 plateau 变量来将函数整体上移,同时用 Math.min 将其限制在 1 以内,确保透明度不会超过 1。这使得笔触看起来更自然。

最小二乘法

这些权重来自我在ECE 174课程中与Piya Pal教授合作完成的一个项目。推理过程仅需45次点积运算和评分。

function evalLSModel(digit, weights) {
    const scores = new Array(10).fill(0);
    for (const pairConfig of weights) {
        const [i, j, w] = pairConfig;
        // 向量点积
        const result = vdot(digit, w);
        if (result > 0) {
            scores[i] += 1;
            scores[j] -= 1;
        } else {
            scores[j] += 1;
            scores[i] -= 1;
        }
    }
    return scores;
}

### 全连接网络

全连接网络推理的主要工作是矩阵点积我以标准方式实现了它

```javascript
function matrixDot(matrix1, matrix2, rows1, cols1, rows2, cols2) {
    // 检查矩阵是否可以相乘
    if (cols1 !== rows2) {
        console.error("Invalid matrix dimensions for dot product");
        return null;
    }

    // 初始化结果矩阵为零
    const result = new Array(rows1 * cols2).fill(0);

    // 执行点积运算
    for (let i = 0; i < rows1; i++) {
        for (let j = 0; j < cols2; j++) {
            for (let k = 0; k < cols1; k++) {
                result[i * cols2 + j] +=
                    matrix1[i * cols1 + k] * matrix2[k * cols2 + j];
            }
        }
    }

    return result;
}

我将矩阵存储在单个一维 Array 中,以获得更好的缓存局部性和更少的堆分配。根据上述公式,推理过程包括 2 次矩阵点积和 2 次激活函数应用。push(1) 调用是为了计算偏置项。

function evalNN(digit, weights) {
    const digitCopy = [...digit];
    digitCopy.push(1);
    // 第一层参数
    const [w1, [rows1, cols1]] = weights[0];
    const out1 = matrixDot(digitCopy, w1, 1, digitCopy.length, rows1, cols1).map(relu);
    const [w2, [rows2, cols2]] = weights[1];
    out1.push(1);
    const out2 = matrixDot(out1, w2, 1, out1.length, rows2, cols2);
    return softmax(out2);
}

卷积网络

这里的卷积网络相当小。在 PyTorch 中为

nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(1600, 10),
    nn.Softmax(dim=1)
)

对于推理,我们只需将前向传播过程移植到 JavaScript 中。Conv2d(包含输入/输出通道)由以下函数实现:

function conv2d(
    nInChan,
    nOutChan,
    inputData,
    inputHeight,
    inputWidth,
    kernel,
    bias,
) {
    if (inputData.length !== inputHeight * inputWidth * nInChan) {
        console.error("输入尺寸无效");
        return;
    }
    if (kernel.length !== 3 * 3 * nInChan * nOutChan) {
        console.error("卷积核尺寸无效");
        return;
    }

    const kernelHeight = 3;
    const kernelWidth = 3;

    // 计算输出维度
    const outputHeight = inputHeight - kernelHeight + 1;
    const outputWidth = inputWidth - kernelWidth + 1;

    const output = new Array(nOutChan * outputHeight * outputWidth).fill(0);

    for (let i = 0; i < outputHeight; i++) {
        for (let j = 0; j < outputWidth; j++) {
            for (let outChan = 0; outChan < nOutChan; outChan++) {
                let sum = 0;
                // 在所有输入通道上对单个位置应用滤波器
                for (let inChan = 0; inChan < nInChan; inChan++) {
                    for (let row = 0; row < 3; row++) {
                        for (let col = 0; col < 3; col++) {
                            const inI =
                                inChan * (inputHeight * inputWidth) +
                                (i + row) * inputWidth +
                                (j + col);

                            const kI =
                                outChan * (nInChan * 3 * 3) +
                                inChan * (3 * 3) +
                                row * 3 +
                                col;
                            sum += inputData[inI] * kernel[kI];
                        }
                    }
                }
                sum += bias[outChan];
                const outI =
                    outChan * (outputHeight * outputWidth) +
                    i * outputWidth +
                    j;
                output[outI] = sum;
            }
        }
    }
    return output;
}

我知道这很丑。我只是放在这里供参考。注意最大池化函数:

function maxPool2d(nInChannels, inputData, inputHeight, inputWidth) {
    if (inputData.length !== inputHeight * inputWidth * nInChannels) {
        console.error("maxpool2d: 输入高度/宽度无效");
        return;
    }
    const poolSize = 2;
    const stride = 2;
    const outputHeight = Math.floor((inputHeight - poolSize) / stride) + 1;
    const outputWidth = Math.floor((inputWidth - poolSize) / stride) + 1;
    const output = new Array(outputHeight * outputWidth * nInChannels).fill(0);

    for (let chan = 0; chan < nInChannels; chan++) {
        for (let i = 0; i < outputHeight; i++) {
            for (let j = 0; j < outputWidth; j++) {
                let m = 0;
                for (let row = 0; row < poolSize; row++) {
                    for (let col = 0; col < poolSize; col++) {
                        const ind =
                            chan * (inputHeight * inputWidth) +
                            (i * stride + row) * inputWidth +
                            (j * stride + col);
                        m = Math.max(m, inputData[ind]);
                    }
                }
                const outI =
                    chan * (outputHeight * outputWidth) + i * outputWidth + j;
                output[outI] = m;
            }
        }
    }
    return output;
}

是的,我之所以要处理那些令人厌恶的索引计算代码,全都是为了那极速 🔥JavaScript🔥 Web 应用的性能。最后,这是将所有部分整合在一起的函数:

function evalConv(digit, weights) {
    const [
        [f1, fshape1], // 卷积滤波器权重
        [b1, bshape1], // 卷积偏置
        [f2, fshape2],
        [b2, fbshape2],
        [w, wshape],   // 全连接层权重
        [b, bshape],   // 全连接层偏置
    ] = weights;

    const x1 = conv2d(1, 32, digit, 28, 28, f1, b1).map(relu);
    const x2 = maxPool2d(32, x1, 26, 26);
    const x3 = conv2d(32, 64, x2, 13, 13, f2, b2).map(relu);
    const x4 = maxPool2d(64, x3, 11, 11);
    const x5 = matrixDot(w, x4, 10, 1600, 1600, 1);
    const x6 = vsum(x5, b);
    const out = softmax(x6);
    return out;
}

总结

希望大家喜欢试用这款应用。如有任何问题或反馈,欢迎在下方留言。

✦ 本文的构思、研究、撰写和编辑均未使用大语言模型。