• 开始
    • 为什么要按 进行缩放?
    • 多头注意力机制
  • 首页
  • 文章
  • 笔记
  • 书架
  • 作者
🇺🇸 en 🇫🇷 fr 🇮🇳 ml

Nathaniel Thomas

因果自注意力机制的工作原理

2024年11月13日

开始

因果自注意力机制是自2017年以来推动人工智能进步的核心机制。在本文中,我将逐步解析其计算过程,希望能更好地理解其工作原理。

SelfAttention(Q,K,V)=softmax(mask(d​QKT​))V

从高层次来看,这个函数将一个序列转换为另一个序列。序列是一个由词嵌入组成的列表,形状为 L×d 的张量,其中 L 是输入序列的长度, d 是嵌入维度。矩阵的每一行对应一个输入词元,表示为一个 d 维向量。

那么,为什么 SelfAttention 有三个输入呢?这是因为在 Transformer 架构中,输入序列通过三个不同的 d×d 线性层进行投影。如果 X 是输入序列,

Q=XWQ​,K=XWK​,V=XWV​

其中 WQ​,WK​,WV​ 是 d×d 的矩阵。因此, Q,K,V 只是同一输入序列的不同表示。

让我们一步步计算 SelfAttention。首先,我们计算 QKT,这是一个 L×d 与 d×L 的点积,结果是一个 L×L 的输出。这有什么作用呢?

QKT=​q1​q2​⋮qL​​​[k1T​​k2T​​⋯​kLT​​]=​q1​k1T​q2​k1T​⋮qL​k1T​​q1​k2T​q2​k2T​⋮qL​k2T​​⋯⋯⋱⋯​q1​kLT​q2​kLT​⋮qL​kLT​​​​

qi​kjT​ 的结果是一个标量( 1×d 点积 d×1),它是 qi​ 和 kj​ 之间的向量点积。如果我们记得公式

a⋅b=∥a∥∥b∥cosθ

我们可以看到,当 a 和 b 之间的夹角 θ 接近 0º 时,点积为正;当夹角为 180º 或它们指向相反方向时,点积为负。我们可以将点积解释为相似性度量,其中正值表示向量相似,负值表示相反。

因此,最终的 L×L 矩阵填充了每对 q 和 k 词元之间的相似性分数。结果除以 d​ 是为了防止嵌入维度较大时方差爆炸。详见附录。

下一步是应用 mask 函数,它将输入矩阵中不在下三角部分的所有值设置为 −∞。

mask(d​1​QKT)=d​1​​q1​k1T​q2​k1T​q3​k1T​⋮qL​k1T​​−∞q2​k2T​q3​k2T​⋮qL​k2T​​−∞−∞q3​k3T​⋮qL​k3T​​⋯⋯⋯⋱⋯​−∞−∞−∞⋮qL​kLT​​​

接下来,我们对这个矩阵应用 softmax,它将矩阵中的每一行值转换为概率分布。该函数定义为从 RL→RL 的映射,其中第 i 个输出元素由下式给出:

softmax(x)i​=∑j=1L​exj​exi​​对于 i=1,2,…,L

这里需要注意两点:

  1. 所有输出元素的和为 1,这是概率分布的预期。
  2. 如果输入元素 xi​ 为 −∞,则 softmax(x)i​=0。

在对掩码后的相似性分数应用 softmax 函数后,我们得到:

S=softmax(mask(d​1​QKT))=​S1,1​S2,1​S3,1​⋮SL,1​​0S2,2​S3,2​⋮SL,2​​00S3,3​⋮SL,3​​⋯⋯⋯⋱⋯​000⋮SL,L​​​

其中条目 Si,j​ 定义为:

Si,j​=∑k=1L​emask(d​QKT​)i,k​emask(d​QKT​)i,j​​

生成的矩阵 S 具有长度为 L 的概率分布行。最后一步是通过这些概率分布映射我们的值矩阵 V,从而得到新的序列。

SelfAttention(Q,K,V)​=SV=​S1,1​S2,1​S3,1​⋮SL,1​​0S2,2​S3,2​⋮SL,2​​00S3,3​⋮SL,3​​⋯⋯⋯⋱⋯​000⋮SL,L​​​​V1​V2​V3​⋮VL​​​=​S1,1​V1​S2,1​V1​+S2,2​V2​S3,1​V1​+S3,2​V2​+S3,3​V3​⋮SL,1​V1​+SL,2​V2​+⋯+SL,L​VL​​​​​

注意, Si,j​ 是一个标量,而 Vk​ 是一个 1×d 的嵌入向量。从视觉上看,我们观察到 SelfAttention 选择性地组合 Value 词元,权重由查询和键之间的注意力程度(即较大的内积)生成的概率分布决定。我们还看到,由于之前应用的因果掩码,索引 i 处的输出词元的权重仅依赖于索引 ≤i 的输入词元。这是基于因果假设,即输出词元 Oi​ 不依赖于未来的词元,这在训练自回归(即下一个词元预测)模型时是必需的。

希望这篇文章对你有帮助!

## 附录

为什么要按 d​ 进行缩放?

我们这样做是为了防止方差随着 d 的增加而爆炸。

假设 qi​,ki​∼N(μ=0,σ2=1) 且独立同分布。我们来计算未缩放的 s=q⋅k 的均值和方差。

均值显然为零:

E[s]=E[i=1∑d​qi​ki​]=i=1∑d​E[qi​ki​]=i=1∑d​E[qi​]E[ki​]=0

方差为:

Var(s)=E[s2]−(E[s])2=E[s2]=d

因为

E[s2]=E[i=1∑d​j=1∑d​qi​ki​qj​kj​]=i=1∑d​j=1∑d​E[qi​ki​qj​kj​]

当 i=j 时,该项为 0(因为 qi​,qj​ 和 ki​,kj​ 是独立同分布的)。当 i=j 时,

i=1∑d​E[qi2​ki2​]=i=1∑d​E[qi2​]E[ki2​]=i=1∑d​1⋅1=d

因为 E[qi2​]=E[ki2​]=σ2=1。

因此,如果我们按 1/d​ 进行缩放,新的方差为

Var(d​s​)=d1​Var(s)=1

正如我们所期望的。

多头注意力机制

大多数现代系统使用多头注意力机制,它在多个“头”上并行计算 SelfAttention。我们通常令 dk​=dv​=dmodel​/H,其中 H是头的数量。

Qh​Kh​Vh​​=XWhQ​=XWhK​=XWhV​​WhQ​∈Rdmodel​×dk​WhK​∈Rdmodel​×dk​WhV​∈Rdmodel​×dv​​​ headh​=SelfAttention(Qh​,Kh​,Vh​)=softmax(mask(dk​​Qh​KhT​​))Vh​ MultiHead(Q,K,V)​=Concat(head1​,head2​,…,headH​)​

←
专家级2048游戏机器人
局部近似
→

back to top