• 任务
  • 最小二乘法
  • 全连接网络
  • 卷积网络
  • 练习
    • 全连接网络
    • 卷积网络
  • 结论
  • 首页
  • 文章
  • 笔记
  • 书架
  • 作者
🇺🇸 en 🇫🇷 fr 🇮🇳 ml

Nathaniel Thomas

交互式MNIST探索器

在画布上绘制数字,观察AI如何猜测它是什么!

2024年2月20日

Try drawing a digit on the canvas!

1v1 Least Squares (n/a ms)

Fully Connected Network (n/a ms)
Convolutional Network (n/a ms)

在本文中,我们将介绍三种处理MNIST数据集的基本模型,并比较它们的特性、优势和劣势。您可以通过上方的交互界面与每个模型互动,并在条形图中查看它们的输出结果。

任务

对于不熟悉机器学习的人来说,将图像转换为数字可能看起来是一项艰巨的任务。然而,如果我们以下面的方式思考这个问题,事情就会变得简单:

灰度图像只是一个由像素亮度组成的网格,这些亮度是实数值。也就是说,每张图像都是集合 Rw×h 中的某个元素,其中 w 和 h 分别是图像的宽度和高度。因此,如果我们能找到一个函数 f,从 Rw×h 映射到 {0,1,…,8,9},我们就能解决这个问题。

为此,我们使用训练图像 x(n) 和标签 y(n) 构建一个模型。

最小二乘法

该方法涉及为从类别0到9中选出的每一对唯一 (i,j)创建45个从 Rwh→R的线性映射,以推断图像最可能属于 i还是 j。我们可以使用一些线性代数来最小化均方误差(MSE)。首先,我们不再处理 Rw×h中的图像,而是将它们“展平”为 Rwh≡Rn。

将 (i,j)的权重定义为 wij​,这是一个长度为 n+1的向量。为了得到模型的输出,我们计算

y^​ij​=k=1∑n​wij,k​xk​+wij,n+1​

其中 x是一个数字。

我们希望在所有 m个样本上最小化MSE

Lij​​=n1​i=1∑m​(y^​ij(m)​−yij(m)​)2=n1​i=1∑n​([xij⊤​​1​]wij​−yij(n)​)2​

为此,我们创建一个新的矩阵 Xij​,它仅包含属于类别 i或 j的图像以及一列 1用于偏置,以及一个矩阵 yij​,它同样仅包含 {i,j}中的标签,但将 i替换为 −1,将 j替换为 1。

现在,我们的问题简化为

wij​min​∣∣Xij​wij​−yij​∣∣22​

该问题的解由 wij​=Xij†​yij​给出,其中 X†是矩阵 X的伪逆(证明留给读者作为练习😁)。

一旦我们获得了所有 i,j对的 wij​(总共45个),我们就可以表示我们所需的函数 f为

def f(x):
    score = [0] * 10
    for i, j, f_ij in pair_functions:
        out_ij = f_ij(x)
        if out_ij > 0:
            score[i] += 1
            score[j] -= 1
        else:
            score[j] += 1
            score[i] -= 1
    return argmax(score)

每个45个模型都会为其 i或 j“投票”。score数组就是你在上面条形图中看到的内容。

全连接网络

全连接网络(Fully Connected Network,简称FCN)是一个比最小二乘模型大得多的模型。与将标签投影到数据的主子空间不同,我们可以直接学习从输入空间到输出空间的映射。

对于单层网络,我们假设 f可以通过以下方式近似:

f(x)=g(Ax)

其中 g是某个非线性函数。通过梯度下降,可以学习矩阵 A,使得在局部邻域内的误差(分类交叉熵)最小化。在演示中,我们使用了一个两层网络,将图像映射到 R128,然后将结果映射到 R10。这可以表示为:

f(x)=h(B(g(Ax)))

其中我们需要学习矩阵 B∈R10×128和 A∈R128×n。在我们的例子中, g(x)=max(0,x),并且

h(z)i​=∑j=110​ezj​ezi​​

将输出 ∈R10转换为概率分布,如上方的柱状图所示。

卷积网络

上述两种模型的一个局限在于,它们无法像人类那样感知视觉特征。例如,手写的1无论在画布的哪个位置绘制,它都是 1。然而,由于LS和FCN模型没有空间或邻近的概念,它们只会简单地指向最有可能拥有这些确切像素的类别。

这里,我们引入卷积。卷积操作接收一张图像和一个核,将核在图像上滑动,并生成一个输出图像,该图像包含图像像素与核值的加权和。

注意卷积如何编码空间数据,而普通网络则无法做到。由于邻近的像素通常高度相关,我们可以通过最大池化对卷积输出进行下采样,并保留大部分信息。在将图像通过一系列(训练过的)核处理后,我们得到一组矩阵,这些矩阵代表了学习到的空间特征的存在。最后,我们可以将这些矩阵展平并输入到一个FCN中,该FCN现在可以将空间数据映射到类别。

这个FCN(带有softmax激活)的输出如上所示。

## 模型比较

注意:最后三列是定性的,且相互之间具有相对性。

模型 参数数量 训练时间 推理时间 准确性
最小二乘法 35,325 低 快 低
全连接网络 (FCN) 101,760 高 快 良好
卷积网络 (CNN) 34,826 非常高 慢 优秀

观察结果:

  • 最小二乘法模型非常快,但泛化能力较弱
  • CNN 的参数存储效率非常高
  • 相对于 CNN 的推理时间,最小二乘法和 FCN 都非常快

练习

观察模型如何响应以下输入:

  • 空白画布
  • 中心位置的 1
  • 最左侧的 1
  • 最右侧的 1
  • 中心带有一条线/点的 0
  • 顶部略微断开的 9
  • 略微旋转的数字
  • 非常细的数字
  • 非常粗的数字

你能找到两个仅相差 1 像素但映射到不同类别的输入吗?

## 实现细节

所有三个模型都在你的浏览器中以纯 JavaScript 运行;没有使用任何框架或包。

### 画布

这个 28×28 的画布由一个包含显示透明度值的数组支持。每次更新任何像素时,整个画布都会重新绘制。另一个有趣的细节是我使用的亮度衰减函数:

const plateau = 0.3;
// dist 是距离中心点的平方距离
const alpha = Math.min(1 - dist / r2 + plateau, 1);
pixels[yc * 28 + xc] = Math.max(pixels[yc * 28 + xc], alpha);

我最初尝试使用 1-dist/r2 的衰减函数,但它使中心区域过于暗淡。因此,我添加了 plateau 变量,将函数向上平移,但通过 Math.min 将其限制,以确保透明度值不超过 1。这使得笔触看起来更加自然。

### 最小二乘法

我从在ECE 174课程中与Piya Pal教授合作的一个项目中获得了这些权重。推理过程仅仅是45次点积和评分。

function evalLSModel(digit, weights) {
    const scores = new Array(10).fill(0);
    for (const pairConfig of weights) {
        const [i, j, w] = pairConfig;
        // 向量点积
        const result = vdot(digit, w);
        if (result > 0) {
            scores[i] += 1;
            scores[j] -= 1;
        } else {
            scores[j] += 1;
            scores[i] -= 1;
        }
    }
    return scores;
}

全连接网络

全连接网络(FCN)推理的主要工作是矩阵点积,我以标准方式实现了它。

function matrixDot(matrix1, matrix2, rows1, cols1, rows2, cols2) {
    // 检查矩阵是否可以相乘
    if (cols1 !== rows2) {
        console.error("矩阵维度不匹配,无法进行点积运算");
        return null;
    }

    // 初始化结果矩阵为零
    const result = new Array(rows1 * cols2).fill(0);

    // 执行点积运算
    for (let i = 0; i < rows1; i++) {
        for (let j = 0; j < cols2; j++) {
            for (let k = 0; k < cols1; k++) {
                result[i * cols2 + j] +=
                    matrix1[i * cols1 + k] * matrix2[k * cols2 + j];
            }
        }
    }

    return result;
}

为了更好的缓存局部性和减少堆分配,我将矩阵存储在单个一维 Array 中。根据上述公式,推理过程包括 2 次矩阵点积和 2 次激活函数应用。push(1) 调用用于计算偏置。

function evalNN(digit, weights) {
    const digitCopy = [...digit];
    digitCopy.push(1);
    // 第一层参数
    const [w1, [rows1, cols1]] = weights[0];
    const out1 = matrixDot(digitCopy, w1, 1, digitCopy.length, rows1, cols1).map(relu);
    const [w2, [rows2, cols2]] = weights[1];
    out1.push(1);
    const out2 = matrixDot(out1, w2, 1, out1.length, rows2, cols2);
    return softmax(out2);
}

卷积网络

这里的卷积网络非常小。在 Pytorch 中,它是这样的:

nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(1600, 10),
    nn.Softmax(dim=1)
)

对于推理,我们只需要将前向传播部分移植到 JavaScript 中。Conv2d(带有输入/输出通道)由以下代码实现:

function conv2d(
    nInChan,
    nOutChan,
    inputData,
    inputHeight,
    inputWidth,
    kernel,
    bias,
) {
    if (inputData.length !== inputHeight * inputWidth * nInChan) {
        console.error("输入尺寸无效");
        return;
    }
    if (kernel.length !== 3 * 3 * nInChan * nOutChan) {
        console.error("卷积核尺寸无效");
        return;
    }

    const kernelHeight = 3;
    const kernelWidth = 3;

    // 计算输出尺寸
    const outputHeight = inputHeight - kernelHeight + 1;
    const outputWidth = inputWidth - kernelWidth + 1;

    const output = new Array(nOutChan * outputHeight * outputWidth).fill(0);

    for (let i = 0; i < outputHeight; i++) {
        for (let j = 0; j < outputWidth; j++) {
            for (let outChan = 0; outChan < nOutChan; outChan++) {
                let sum = 0;
                // 在所有输入通道上应用滤波器
                for (let inChan = 0; inChan < nInChan; inChan++) {
                    for (let row = 0; row < 3; row++) {
                        for (let col = 0; col < 3; col++) {
                            const inI =
                                inChan * (inputHeight * inputWidth) +
                                (i + row) * inputWidth +
                                (j + col);

                            const kI =
                                outChan * (nInChan * 3 * 3) +
                                inChan * (3 * 3) +
                                row * 3 +
                                col;
                            sum += inputData[inI] * kernel[kI];
                        }
                    }
                }
                sum += bias[outChan];
                const outI =
                    outChan * (outputHeight * outputWidth) +
                    i * outputWidth +
                    j;
                output[outI] = sum;
            }
        }
    }
    return output;
}

我知道这代码很丑。我只是把它放在这里供参考。接下来是 maxpool 的代码:

function maxPool2d(nInChannels, inputData, inputHeight, inputWidth) {
    if (inputData.length !== inputHeight * inputWidth * nInChannels) {
        console.error("maxpool2d: 输入高度/宽度无效");
        return;
    }
    const poolSize = 2;
    const stride = 2;
    const outputHeight = Math.floor((inputHeight - poolSize) / stride) + 1;
    const outputWidth = Math.floor((inputWidth - poolSize) / stride) + 1;
    const output = new Array(outputHeight * outputWidth * nInChannels).fill(0);

    for (let chan = 0; chan < nInChannels; chan++) {
        for (let i = 0; i < outputHeight; i++) {
            for (let j = 0; j < outputWidth; j++) {
                let m = 0;
                for (let row = 0; row < poolSize; row++) {
                    for (let col = 0; col < poolSize; col++) {
                        const ind =
                            chan * (inputHeight * inputWidth) +
                            (i * stride + row) * inputWidth +
                            (j * stride + col);
                        m = Math.max(m, inputData[ind]);
                    }
                }
                const outI =
                    chan * (outputHeight * outputWidth) + i * outputWidth + j;
                output[outI] = m;
            }
        }
    }
    return output;
}

是的,我处理这些令人头疼的索引计算代码的唯一原因是为了那极速 🔥JavaScript🔥 Web 应用的性能。最后,这是将所有部分整合在一起的函数:

function evalConv(digit, weights) {
    const [
        [f1, fshape1], // 卷积滤波器权重
        [b1, bshape1], // 卷积偏置
        [f2, fshape2],
        [b2, fbshape2],
        [w, wshape],   // 全连接层权重
        [b, bshape],   // 全连接层偏置
    ] = weights;

    const x1 = conv2d(1, 32, digit, 28, 28, f1, b1).map(relu);
    const x2 = maxPool2d(32, x1, 26, 26);
    const x3 = conv2d(32, 64, x2, 13, 13, f2, b2).map(relu);
    const x4 = maxPool2d(64, x3, 11, 11);
    const x5 = matrixDot(w, x4, 10, 1600, 1600, 1);
    const x6 = vsum(x5, b);
    const out = softmax(x6);
    return out;
}

结论

希望大家喜欢使用这个应用。如果有任何问题或反馈,欢迎在下方留言。


←
Python 性能优化:不再随意
专家级2048游戏机器人
→

back to top