
Nathaniel Thomas

The Mechanics of Causal Self Attention

November 13, 2024

Begin

Causal self-attention is the mechanism underpinning most of the advances in AI since 2017. In this article, I will walk through the computation step by step to build a better intuition of how it works.

$$\text{SelfAttention}(Q, K, V) = \text{softmax}\left(\text{mask}\left(\frac{QK^T}{\sqrt{d}}\right)\right)V$$

At a high level, this function takes one sequence and transforms it into another. A sequence is a list of token embeddings, a tensor of shape L×d, where L is the input sequence length and d is the embedding dimension. Each row of this matrix corresponds to one input token, which is represented as a d-dimensional vector.

So why, then, are there 3 inputs to SelfAttention? This is because, in the Transformer architecture, the input sequence is projected by 3 different d×d linear layers. If X is the input sequence,

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

where $W_Q, W_K, W_V$ are d×d. So, Q, K, V are simply different representations of the same input sequence.
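
As a concrete illustration, here is a minimal NumPy sketch of these projections. The shapes and random weights are made up for the example; in a real Transformer the projection matrices are learned parameters.

```python
import numpy as np

L, d = 4, 8                      # sequence length, embedding dimension
rng = np.random.default_rng(0)

X = rng.normal(size=(L, d))      # input sequence: L token embeddings of dimension d

# Randomly initialized projection weights (stand-ins for learned parameters)
W_Q = rng.normal(size=(d, d))
W_K = rng.normal(size=(d, d))
W_V = rng.normal(size=(d, d))

Q = X @ W_Q   # (L, d)
K = X @ W_K   # (L, d)
V = X @ W_V   # (L, d)
```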

Let’s compute SelfAttention step by step. First, we compute $QK^T$, which is an L×d by d×L matrix product, resulting in an L×L output. What does this do?

$$QK^T = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_L \end{bmatrix}
\begin{bmatrix} k_1^T & k_2^T & \cdots & k_L^T \end{bmatrix}
= \begin{bmatrix}
q_1 k_1^T & q_1 k_2^T & \cdots & q_1 k_L^T \\
q_2 k_1^T & q_2 k_2^T & \cdots & q_2 k_L^T \\
\vdots & \vdots & \ddots & \vdots \\
q_L k_1^T & q_L k_2^T & \cdots & q_L k_L^T
\end{bmatrix}$$

The result of $q_i k_j^T$ is a scalar (1×d times d×1), and it is the vector dot product between $q_i$ and $k_j$. If we remember the formula

$$a \cdot b = \|a\| \, \|b\| \cos\theta$$

we see that the dot product is positive when θ, the angle between a and b, is close to 0°, and negative when the angle is close to 180°, i.e. when the vectors point in opposite directions. We can interpret the dot product as a similarity metric, where positive values indicate similar vectors and negative values indicate dissimilar ones.
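
As a quick numeric illustration (with made-up vectors):

```python
import numpy as np

a = np.array([1.0, 2.0, 0.5])
b = 2.0 * a          # points in the same direction as a (angle 0°)
c = -a               # points in the opposite direction (angle 180°)

print(a @ b)   # positive: similar vectors
print(a @ c)   # negative: vectors pointing in opposite directions
```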

So our final L×L matrix is filled with similarity scores between every pair of q and k tokens. The result is divided by $\sqrt{d}$ to prevent the variance from exploding for large embedding dimensions. See the Appendix for details.
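
Continuing the NumPy sketch from above, the scaled similarity scores are a single matrix product:

```python
# Scaled pairwise similarity scores: scores[i, j] == np.dot(Q[i], K[j]) / np.sqrt(d)
scores = Q @ K.T / np.sqrt(d)    # (L, L)
```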

The next step is to apply the mask function, which sets all values that are not in the lower-triangular section of the input matrix to $-\infty$.

$$\text{mask}\left(\frac{1}{\sqrt{d}} QK^T\right) = \frac{1}{\sqrt{d}}
\begin{bmatrix}
q_1 k_1^T & -\infty & -\infty & \cdots & -\infty \\
q_2 k_1^T & q_2 k_2^T & -\infty & \cdots & -\infty \\
q_3 k_1^T & q_3 k_2^T & q_3 k_3^T & \cdots & -\infty \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
q_L k_1^T & q_L k_2^T & q_L k_3^T & \cdots & q_L k_L^T
\end{bmatrix}$$
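
In code, this is typically done by filling the entries strictly above the diagonal with −∞ (continuing the sketch above):

```python
# Causal mask: True for positions j > i, which must not be attended to
causal_mask = np.triu(np.ones((L, L), dtype=bool), k=1)
masked_scores = np.where(causal_mask, -np.inf, scores)   # (L, L)
```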

To this, we apply softmax, which converts each row of values in the matrix into a probability distribution. The function is defined as a mapping from $\mathbb{R}^L \to \mathbb{R}^L$, where the $i$-th output element is given by

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{L} e^{x_j}} \qquad \text{for } i = 1, 2, \ldots, L$$

Two things to note here:

  1. The sum of all output elements is 1, as is expected for a probability distribution.
  2. If an input element $x_i$ is $-\infty$, then $\text{softmax}(x)_i = 0$.
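
Both properties are easy to verify numerically. Here is a minimal sketch of a row-wise softmax and a small test vector:

```python
def softmax(x):
    # Subtract the row-wise max for numerical stability; e^{-inf} evaluates to 0,
    # so masked entries receive exactly zero probability.
    z = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

row = np.array([1.0, 2.0, -np.inf])
p = softmax(row)
print(p)          # the -inf entry maps to probability 0
print(p.sum())    # 1.0, as expected for a probability distribution
```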

After applying the softmax function to the masked similarity scores, we obtain:

$$S = \text{softmax}\left(\text{mask}\left(\frac{1}{\sqrt{d}} QK^T\right)\right) =
\begin{bmatrix}
S_{1,1} & 0 & 0 & \cdots & 0 \\
S_{2,1} & S_{2,2} & 0 & \cdots & 0 \\
S_{3,1} & S_{3,2} & S_{3,3} & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
S_{L,1} & S_{L,2} & S_{L,3} & \cdots & S_{L,L}
\end{bmatrix}$$

where the entries $S_{i,j}$ are defined as:

$$S_{i,j} = \frac{e^{\,\text{mask}\left(\frac{QK^T}{\sqrt{d}}\right)_{i,j}}}{\sum_{k=1}^{L} e^{\,\text{mask}\left(\frac{QK^T}{\sqrt{d}}\right)_{i,k}}}$$

The resulting matrix S has rows that are probability distributions of length L. The final step is to weight our value matrix V by these probability distributions to give us our new sequence.

$$\begin{aligned}
\text{SelfAttention}(Q, K, V) &= SV \\
&= \begin{bmatrix}
S_{1,1} & 0 & 0 & \cdots & 0 \\
S_{2,1} & S_{2,2} & 0 & \cdots & 0 \\
S_{3,1} & S_{3,2} & S_{3,3} & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
S_{L,1} & S_{L,2} & S_{L,3} & \cdots & S_{L,L}
\end{bmatrix}
\begin{bmatrix}
V_1 \\ V_2 \\ V_3 \\ \vdots \\ V_L
\end{bmatrix} \\
&= \begin{bmatrix}
S_{1,1} V_1 \\
S_{2,1} V_1 + S_{2,2} V_2 \\
S_{3,1} V_1 + S_{3,2} V_2 + S_{3,3} V_3 \\
\vdots \\
S_{L,1} V_1 + S_{L,2} V_2 + \cdots + S_{L,L} V_L
\end{bmatrix}
\end{aligned}$$

Note that $S_{i,j}$ is a scalar, and $V_k$ is a 1×d embedding vector. Visually, we observe that SelfAttention is selectively combining value tokens, weighted by a probability distribution generated by how well the queries and keys attend to each other, i.e. have a large inner product. We also see that the weight of an output token at index i depends only on the input tokens with index $\leq i$, due to the causal mask we applied earlier. This reflects the causal assumption that an output token $O_i$ does not depend on future tokens, which is required when training autoregressive (i.e. next-token prediction) models.
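
Putting the pieces together, here is a minimal end-to-end sketch of causal self-attention in NumPy, reusing the softmax helper defined above. It follows the equations in this article directly and is not optimized for real workloads:

```python
def causal_self_attention(Q, K, V):
    """softmax(mask(Q K^T / sqrt(d))) V, following the equations above."""
    L, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                        # (L, L) similarity scores
    causal = np.triu(np.ones((L, L), dtype=bool), k=1)   # True strictly above the diagonal
    scores = np.where(causal, -np.inf, scores)           # hide future positions
    S = softmax(scores)                                  # each row is a probability distribution
    return S @ V                                         # (L, d) weighted mix of value tokens

out = causal_self_attention(Q, K, V)   # the transformed sequence, shape (L, d)
```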

Hopefully you found this helpful!

Appendix

Why Scale by $\sqrt{d}$?

We do this to keep the variance from exploding as d increases.

Assume that $q_i, k_i \sim \mathcal{N}(\mu = 0, \sigma^2 = 1)$ and are i.i.d. Let’s compute the mean and variance of the unscaled score $s = q \cdot k$.

The mean is trivially zero:

$$\mathbb{E}[s] = \mathbb{E}\left[\sum_{i=1}^{d} q_i k_i\right] = \sum_{i=1}^{d} \mathbb{E}[q_i k_i] = \sum_{i=1}^{d} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0$$

And the variance is:

$$\text{Var}(s) = \mathbb{E}[s^2] - (\mathbb{E}[s])^2 = \mathbb{E}[s^2] = d$$

because

$$\mathbb{E}[s^2] = \mathbb{E}\left[\sum_{i=1}^{d} \sum_{j=1}^{d} q_i k_i q_j k_j\right] = \sum_{i=1}^{d} \sum_{j=1}^{d} \mathbb{E}[q_i k_i q_j k_j]$$

which is 0 for $i \neq j$ (since $q_i, q_j$ and $k_i, k_j$ are i.i.d.). For $i = j$,

$$\sum_{i=1}^{d} \mathbb{E}[q_i^2 k_i^2] = \sum_{i=1}^{d} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = \sum_{i=1}^{d} 1 \cdot 1 = d$$

since $\mathbb{E}[q_i^2] = \mathbb{E}[k_i^2] = \sigma^2 = 1$.

So if we scale by $1/\sqrt{d}$, our new variance is

$$\text{Var}\left(\frac{s}{\sqrt{d}}\right) = \frac{1}{d}\,\text{Var}(s) = 1$$

as desired.
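
A quick Monte Carlo check (under the same unit-Gaussian assumption on q and k) confirms this scaling:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_samples = 512, 20_000

q = rng.normal(size=(n_samples, d))   # unit-Gaussian queries
k = rng.normal(size=(n_samples, d))   # unit-Gaussian keys

s = np.sum(q * k, axis=1)             # unscaled dot products
print(s.var())                        # ≈ d (≈ 512 here)
print((s / np.sqrt(d)).var())         # ≈ 1 after scaling by 1/sqrt(d)
```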

Multi-Head Attention

Most modern systems use multi-head attention, which computes SelfAttention in parallel over several “heads”. We usually let $d_k = d_v = d_{\text{model}}/H$, where H is the number of heads.

$$\begin{aligned}
Q_h &= XW_h^Q, \qquad & W_h^Q &\in \mathbb{R}^{d_{\text{model}} \times d_k} \\
K_h &= XW_h^K, \qquad & W_h^K &\in \mathbb{R}^{d_{\text{model}} \times d_k} \\
V_h &= XW_h^V, \qquad & W_h^V &\in \mathbb{R}^{d_{\text{model}} \times d_v}
\end{aligned}$$

$$\text{head}_h = \text{SelfAttention}(Q_h, K_h, V_h) = \text{softmax}\left(\text{mask}\left(\frac{Q_h K_h^T}{\sqrt{d_k}}\right)\right)V_h$$

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_H)$$
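
As a final sketch, here is one way to write multi-head attention in NumPy, reusing the causal_self_attention function from earlier. The per-head weights and sizes below are made-up placeholders, and this implements exactly the Concat form written above (real implementations typically also apply a learned output projection after the concatenation):

```python
def multi_head_attention(X, W_Q, W_K, W_V):
    # Concat(head_1, ..., head_H): each head runs causal self-attention
    # over its own (L, d_k) / (L, d_v) projections of X.
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):        # one weight triple per head
        Q_h, K_h, V_h = X @ Wq, X @ Wk, X @ Wv
        heads.append(causal_self_attention(Q_h, K_h, V_h))
    return np.concatenate(heads, axis=-1)        # (L, H * d_v) = (L, d_model)

# Example with hypothetical sizes: d_model = 8, H = 2 heads, d_k = d_v = 4
d_model, H = 8, 2
d_k = d_model // H
rng = np.random.default_rng(0)
X = rng.normal(size=(4, d_model))
W_Q = [rng.normal(size=(d_model, d_k)) for _ in range(H)]
W_K = [rng.normal(size=(d_model, d_k)) for _ in range(H)]
W_V = [rng.normal(size=(d_model, d_k)) for _ in range(H)]

out = multi_head_attention(X, W_Q, W_K, W_V)     # (4, d_model)
```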
