The Attention Mechanism is the heart of an LLM; around 70% of the transformer architecture is about the attention mechanism.
Four categories of Attention Mechanism :
- Simplified Attention Mechanism
- Attention Mechanism with trainable weights
- Causal Attention Mechanism
- Multi Head Attention Mechanism
Note :
Once we are familiar with the above attention mechanisms, we will understand not just the GPT model but other LLM architectures like DeepSeek R1, R2, etc. as well. It is extremely important to have commanding knowledge of one LLM framework (in our case GPT) in order to understand the changes in other models.
Background about the Attention Mechanism :
Before discussing the attention mechanism, let's first discuss the problems that motivated this concept.
RNN (Recurrent Neural Network) - We have discussed RNNs before; they introduce a very important concept called MEMORY, which is the hidden state of the RNN. It maintains previous context information.
Each RNN cell receives an input x1, x2, x3 at the respective timestep t1, t2, t3, and each cell maintains a memory called the hidden state, which is passed to the next RNN cell along with that cell's input. The current RNN cell (the 2nd cell in the image above) therefore maintains a hidden state containing the current context as well as the previous context (from the hidden state of the previous cell). In this way, context is carried forward until the last cell.
But the problem with RNNs is that while they remember the entire context for short inputs, for longer sentences they may lose the context. They have the vanishing gradient issue: they give more importance/priority to the latest data and less importance/priority to old data.
Carefully observe below example :
The Cat that was sitting on the mat, which was next to the Dog, jumped!
Can you tell whether 'jumped' refers to the Cat or the Dog? A human can tell immediately, but how does an LLM identify that it relates to the Cat (and not the Dog)?
That's why LLM needs "Attention Mechanism" !
Also, CAT is the 2nd word in the sentence; an RNN might lose this word if the sentence is too long (as an RNN gives less priority to earlier tokens). If the LLM does not maintain the whole context, how will it understand the complete meaning?
FYI, so far we have covered below lifecycle as part of understanding LLM :
RNN --> LSTM --> Attention Mechanism --> what's next ?
Let's see how the attention mechanism is going to help.
Sentence : "Your journey starts with one step" (6 tokens)
We all know that TOKEN EMBEDDINGS represent the semantic meaning of an individual token in a sentence; each token-embedding vector belongs to a single token. If the LLM had a mechanism to pay attention to the other tokens as well (in addition to the current token), it would be able to maintain the meaning of the entire sentence.
Example : Consider 5 students, as shown in the image below, with attributes like height, weight, etc.; these attributes are analogous to token embeddings. How would we know whether person x1 is close to x2, x3, x4, x5? Their attributes must be close, isn't it? So we need to pay attention to all the students.
Similarly, for "Your journey starts with one step" : if we know the SCORES of the other tokens, it becomes easy to predict the next token. That is, a token should pay attention to the other tokens and check its relationship with them.
Even when you type some text in WhatsApp, it recommends words, right? That is because it identifies similarity by comparing the features of other tokens with the token you entered.
Let me put it in a simple way: in addition to paying attention to the current token, we need to pay attention to the other tokens in the sentence as well, to understand the full meaning of the current token. This is called the Attention Mechanism.
For example : when processing the token 'journey' in the sentence "Your journey starts with one step", along with the current token JOURNEY we need to pay attention to the other tokens in the entire sentence, to understand how close all the other tokens are to the current token and so maintain the meaning of the entire sentence. This happens after the input embeddings: one token should pay attention to the other tokens.
As shown in the image above, when we translate a sentence from one language to another (German to English in the image), it is not possible to merely translate word by word. Instead, the translation process requires contextual understanding and grammatical alignment. Observe in the image that some tokens have access to other tokens in order to create meaningful context (2nd part of the image).
If we implemented the above approach without an attention mechanism in an LLM, imagine the output: we would not get the expected result. That's why the attention mechanism is so important.
To address above problem, it is common to use a deep neural network with two submodules, an encoder and a decoder. The job of the encoder is to first read and process the entire text, and the decoder then produces the translated text. Before transformers, Recurrent Neural Networks(RNN) were the most popular encoder-decoder architecture for language processing.
In the image above, observe that when data transits from the Encoder to the Decoder, it carries the hidden states of all 3 RNN cells from the encoder. That's why we don't lose context when we use an attention mechanism.
Capturing data dependencies with Attention Mechanism
Although RNNs work fine for short sentences, they don't work well for longer texts because they don't have direct access to previous words in the input. One major shortcoming of this approach is that the RNN must compress the entire encoded input into a single hidden state before passing it to the decoder.
- Observe that the decoder has access to all the hidden states from the encoder; this is what the attention mechanism known as Bahdanau attention enables.
- Also, in the encoder section, each dot represents the weight of an input word/token with respect to the other input tokens. Based on these weights, the model pays more attention to the tokens with heavier weights than to the others. The weights are random numbers initially, but accurate weights are learned during neural network training. We will see programmatically how this works.
- Without an attention mechanism, the decoder can access the hidden state of ONLY the last RNN cell.
- There is a problem in this design as well: sequential execution, cell after cell, in the RNN. If the input is large and we deploy this model, it processes one word at a time and processing time takes a hit.
- The Bahdanau attention mechanism, introduced in 2014, brought attention to RNN encoder-decoder models.
- The "Attention Is All You Need" paper, published in 2017, introduced self-attention and the Transformer architecture.
1) Simplified Self Attention Mechanism
- Calculate Attention Scores
- Calculate Attention Weights
- Calculate Context vectors
A simple self-attention mechanism without trainable weights
Let's begin by implementing a variant of self-attention that is free from any trainable weights. The goal is to understand a few concepts of self-attention before adding trainable weights.
Input sentence : "Your journey starts with one step"
(Note : we assume the pre-processing is completed, i.e. token embeddings, positional embeddings, and input embeddings have been calculated)
Token1(Your), Token2(journey), Token3(starts), Token4(with), Token5(one), Token6(step), each assumed to be a 3-dimensional vector (the input embeddings are 3D)
import torch
# Assume the tensor below holds the input embeddings for the same sentence
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your    (x^1)
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts  (x^3)
     [0.22, 0.58, 0.33], # with    (x^4)
     [0.77, 0.25, 0.10], # one     (x^5)
     [0.05, 0.80, 0.55]] # step    (x^6)
)
# Below graph represents above input embeddings vector in 3D representation
If you observe carefully, journey & starts are close by in the 3D space above, but one is far away from journey & starts. This is the main intention of the attention mechanism: to carry some meaning with the data.
Token1(Your), with index 0
Token2(journey), with index 1
Token3(starts), with index 2
Token4(with), with index 3
Token5(one), with index 4
Token6(step), with index 5
d_model = 3 (3-dimensional embeddings)
As per the image above, while computing the attention score for the word "journey", we need to determine how much attention to pay to the other words with respect to journey, to capture the meaning of having "journey" in the sentence.
Goal : ideally, we would create the context vector for every word in the sentence, but for simplicity let's calculate the context vector of one word, i.e. journey, which will carry the meaning of how journey relates to the other words.
x_2 = journey
z_2 = context vector of journey
Steps to follow : Attention scores --> Attention Weights --> Context Vector
Step1 : Calculate Attention Scores
x_2 dot (inputs), where x_2 is journey and inputs is "Your journey starts with one step". We apply the dot product because it is a fundamental way to combine two vectors, and it tells us how much two vectors point in the same direction.
The formula for the dot product of two vectors a and b is a · b = |a||b| cos θ (cos 0° = 1, cos 90° = 0).
Important point to understand :
The dot product above shows whether 2 vectors are close to or far from each other. If we know this information, we can decide how much attention to pay to a particular token represented in vector format. That's why we apply a dot product between our word "journey" and all the other vectors.
Code :
import torch
# Assume the tensor below holds the input embeddings for the same sentence
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your    (x^1)
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts  (x^3)
     [0.22, 0.58, 0.33], # with    (x^4)
     [0.77, 0.25, 0.10], # one     (x^5)
     [0.05, 0.80, 0.55]] # step    (x^6)
)

query = inputs[1] # 2nd input token is the query (journey)
attn_scores_2 = torch.empty(inputs.shape[0]) # shape of the input embeddings above is 6 x 3
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since these are 1-dim vectors)
print(attn_scores_2)
Note : query contains the vector values of the word "journey"; enumerate() yields an index for each word in the sentence, and then we apply the dot product between the query and every other word vector in the sentence.
That's how we create the attention scores. Next we need to normalize these attention scores so they are easy to interpret; for example, we should be able to say that "journey" puts 20% of its attention on one word and 50% on another, with the weights over all the words in the sentence summing to 100%.
Step2 : Calculate Attention Weights
- Nothing but normalizing the Attention Scores.
Here's a straightforward method for achieving this normalization step:
Code :
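Since the snippet is only referenced here, below is a minimal sketch of this sum-based normalization; it repeats the inputs tensor from earlier so it runs standalone, and the variable names follow the earlier snippets:

```python
import torch

# Input embeddings for "Your journey starts with one step" (from earlier)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

query = inputs[1]               # "journey"
attn_scores_2 = inputs @ query  # dot product of the query with every token

# Straightforward normalization: divide each score by the sum of all scores
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
```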
Output :
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656]) Sum: tensor(1.0000)
- In practice, it's more common and advisable to use the softmax function for normalization.
- Once the attention weights are calculated, multiply each weight by the corresponding input vector and sum the results; that gives you the context vector.
Finally, after generating the CONTEXT VECTOR, the word journey carries the context of the other words as well, and this is what is carried forward to the next steps.
What we have discussed so far :
Using the input embeddings, calculate attention scores, then calculate attention weights, and finally calculate the context vectors.
We can see the softmax formula used to normalize the data in the image above; softmax also handles outliers in the given data.
- dim=0 means apply softmax along dimension 0 (for this 1-D score vector, across all six scores)
- This is the PyTorch way of normalization
- Note that we apply softmax while calculating the attention weights
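As a sketch (using the score values computed for "journey" above), the softmax normalization can be written both naively from the formula and with PyTorch's built-in:

```python
import torch

# Attention scores for "journey" computed earlier
attn_scores_2 = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

# Naive softmax: exponentiate each score, divide by the sum of the exponentials
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

# PyTorch's numerically stable built-in; dim=0 normalizes across this 1-D tensor
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
```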
Now we can extend this computation to calculate attention weights and context vectors for all inputs. First we add an additional for-loop to compute the dot products for all pairs of inputs.
Important concept to understand before extending this calculation for all inputs :
- We know that to calculate the attention scores of a particular word/token, we multiply that token's input-embedding vector with the vectors of all the other tokens. Correct?
- Now, to simplify these calculations, especially when we want the attention scores for all input tokens at once, we multiply the input-embedding matrix by its own transpose, BECAUSE matrix multiplication pairs each row of matrix 1 with each column of matrix 2; without the transpose, the shapes would not line up and we would end up with inaccurate attention scores.
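This all-pairs computation can be sketched with a single matrix multiplication (repeating the inputs tensor so the snippet runs standalone):

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

# Row i of the result holds the dot products of token i with every token,
# so one matmul against the transpose replaces the nested for-loops
attn_scores = inputs @ inputs.T
print(attn_scores)  # 6 x 6 matrix of pairwise attention scores
```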
Output :
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310], [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865], [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605], [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565], [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935], [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
Then we normalize each row so that the values in each row sum to 1.
Code : (dim=-1 means the last dimension; we are telling softmax to normalize across the columns of each row)
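A sketch of this row-wise normalization, repeating the earlier tensors so the snippet runs standalone:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
attn_scores = inputs @ inputs.T

# dim=-1 is the last dimension: softmax normalizes across the columns,
# so every row of attention weights sums to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
```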
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452], [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581], [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565], [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720], [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295], [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Let's briefly verify that the rows indeed all sum to 1.
Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
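The final step, computing all six context vectors at once, can be sketched as one more matrix multiplication (repeating the earlier tensors so the snippet runs standalone):

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
attn_weights = torch.softmax(inputs @ inputs.T, dim=-1)

# Each context vector is the attention-weighted sum of all input vectors
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)  # 6 x 3: one 3-dim context vector per token
```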
Output :
tensor([[0.4421, 0.5931, 0.5790], [0.4419, 0.6515, 0.5683], [0.4431, 0.6496, 0.5671], [0.4304, 0.6298, 0.5510], [0.4671, 0.5910, 0.5266], [0.4177, 0.6503, 0.5645]])
Self-Attention with Trainable weights :
Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most popular LLMs. This self-attention mechanism is also called scaled dot-product attention.
We need to introduce 3 new trainable weight matrices i.e.
- Weight matrix of Query
- Weight matrix of Key
- Weight matrix of Value
- Consider that we have an input matrix with 6 words/tokens
- It is connected to 3 different matrices: the Query, Key, and Value weight matrices
- Assume the dimensions shown above as well, for easy understanding
As shown in the image above, if we multiply the input matrix by the Query, Key, and Value weight matrices, we get 3 different matrices of dimension 6 x 2. Initially those Query, Key, Value matrices contain random values, but as the iterations of neural network training progress, the model arrives at accurate values (note that these are the weights of connections as in a neural network, hence they go through training; after multiple epochs, accurate Query, Key, Value matrices emerge).
Once our query matrix is ready, we simply multiply each query vector with all the rows of the keys matrix (the projected inputs) to calculate the attention scores.
Here we start computing only one context vector, for input "journey". We will then modify this code to calculate all context vectors.
Code :
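A sketch of this computation along the lines of the notebook, computing the context vector for "journey" only; the seed and the d_in/d_out choice (3 in, 2 out) are assumptions for illustration:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
x_2 = inputs[1]     # "journey"
d_in, d_out = 3, 2  # assumed dimensions

torch.manual_seed(123)  # arbitrary seed; the weights start out random
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_2 = x_2 @ W_query     # query for "journey" only
keys    = inputs @ W_key    # keys for every token
values  = inputs @ W_value  # values for every token

attn_scores_2 = query_2 @ keys.T  # 6 attention scores for "journey"
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)  # scaled softmax
context_vec_2 = attn_weights_2 @ values  # 2-dim context vector
print(context_vec_2)
```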
FYI, the code for the self-attention mechanism with trainable weights is available at the following GitHub location : https://github.com/amathe1/LLMs/blob/main/Attention_Mechanism.ipynb
It has complete documentation for each line of code and is self-explanatory.
Why are we using Query, Key, Value matrices ?
- The first thing to remember is that we are talking about the self-attention mechanism with trainable weights, and these matrices are the trainable weights. The immediate question would then be: why 3 matrices instead of 1?
Please check the implementation and code from below file : https://github.com/amathe1/LLMs/blob/main/Attention_Mechanism.ipynb
- Verify the scaling, and why we apply the square root (SQRT) to gain stability in learning before applying the softmax function
- We mentioned 2 reasons in above code, please verify both.
- WHY DIVIDE BY SQRT
- BUT WHY SQRT
- This is to minimize the spread (variance) of the attention scores before the softmax, which stabilizes the calculation of attention weights
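The variance argument can be checked empirically; the snippet below (dimensions and seed chosen arbitrarily for illustration) shows that the variance of dot products of random vectors grows with the dimension d_k, and that dividing by sqrt(d_k) brings it back to roughly 1:

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(1000, d_k)  # 1000 random query vectors with unit-variance entries
k = torch.randn(1000, d_k)  # 1000 random key vectors

scores = (q * k).sum(dim=-1)      # 1000 dot products
print(scores.var())               # roughly d_k (~512): large spread saturates softmax
print((scores / d_k**0.5).var())  # roughly 1 after scaling by sqrt(d_k)
```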
Finally, context vector :
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
Steps involved in self-attention with trainable weights :
- Input values
- Initialize Wq, Wk, Wv (trainable parameters)
- Calculate Q, K, V
- Q = Input * Wq
- K = Input * Wk
- V = Input * Wv
- Calculate Attention Scores
- Attention Score = Query matrix * K^T (Transpose of Keys matrix)
- Scaled AS = AS / SQRT(dimension_k)
- Attention Weights (AW) = Softmax(Scaled AS)
- Calculate Context Vector = Attention Weights * Value matrix (V)
Below is the compacted self attention code :
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
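To see the class in action, a sketch like the following can be used; the seed is an arbitrary choice, and the class is repeated here so the snippet runs standalone with the inputs tensor from earlier:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):  # same class as above
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

torch.manual_seed(123)  # arbitrary seed for reproducible random weights
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
context_vecs = sa_v1(inputs)
print(context_vecs.shape)  # one 2-dim context vector per token
```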
With this explanation, we are done with the 1st & 2nd types of attention mechanism:
- Self-attention mechanism
- Self-attention mechanism with trainable weights
Causal Attention Mechanism (Masked Attention Mechanism)
The Causal Attention Mechanism builds on the Self-Attention Mechanism with trainable weights, and that's the reason we started our attention-mechanism journey with self-attention without trainable weights, followed by the version with trainable weights. This is the actual attention mechanism used in the Transformer architecture.
It is nothing but hiding the future words while implementing the attention mechanism.
For many LLM tasks, you want the self-attention mechanism to consider only the tokens that appear before the current position when predicting the next token in a sequence. Causal attention (or masked attention) is a specialized form of self-attention: it restricts the model to consider only the previous and current inputs in the sequence when computing attention scores for any given token. This is in contrast to standard self-attention, which allows access to the entire input sequence at once. Please see the image above.
Causal Attention is going to calculate the context vectors of input sequence "Your journey starts with one step" as below.
- Your
- Your journey
- Your journey starts
- Your journey starts with
- Your journey starts with one
- Your journey starts with one step
A context vector will be created for each of the lines above. GPT is a decoder model, which predicts the next token in sequence-to-sequence fashion. The context vector holds the relationships between all the tokens in the sequence, i.e. which token is most likely to come next.
Applying a Causal Attention Mask
Note : One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting matrix.
Steps for Causal Attention (Approach-1) :
- Attention scores (Query * Key^T )
- Attention Weights = Normalize (Attention Scores) = (Scaling + Softmax)
- Masking
- Normalize again
Please see code under "Hiding Future words with causal attention" in the following code at location : https://github.com/amathe1/LLMs/blob/main/Attention_Mechanism.ipynb
Just FYI, the image below shows how the query matrix gets created in code.
Query = self.W_query(x), where x is the input-embedding matrix with the tokens of the whole input sentence
= x @ W^T + bias (we ignore the bias in this case, since qkv_bias=False)
= x @ W^T (transpose of the query weight matrix)
= (6 x 3 matrix) * (2 x 3 matrix) (before transpose)
= (6 x 3 matrix) * (3 x 2 matrix) (after transpose)
= 6 x 2 matrix
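The transpose behaviour of nn.Linear can be verified directly; the seed and the sizes mirror the derivation above, and the random x stands in for the input embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(789)
x = torch.rand(6, 3)                   # 6 tokens, 3-dim input embeddings
W_query = nn.Linear(3, 2, bias=False)  # d_in=3, d_out=2

print(W_query.weight.shape)            # torch.Size([2, 3]): stored transposed
queries = W_query(x)                   # internally computes x @ W^T
print(queries.shape)                   # torch.Size([6, 2])

manual = x @ W_query.weight.T          # the equivalent manual computation
print(torch.allclose(queries, manual))
```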
Code that we are using during Self-Attention :
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
Instance creation :
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
Output :
tensor([[-0.0739, 0.0713],
[-0.0748, 0.0703],
[-0.0749, 0.0702],
[-0.0760, 0.0685],
[-0.0763, 0.0679],
[-0.0754, 0.0693]], grad_fn=<MmBackward0>)
- 1st row represents the context vector of YOUR
- 2nd row represents the context vector of JOURNEY
- 3rd row represents the context vector of STARTS
Code related to Causal Attention :
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your    (x^1)
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts  (x^3)
     [0.22, 0.58, 0.33], # with    (x^4)
     [0.77, 0.25, 0.10], # one     (x^5)
     [0.05, 0.80, 0.55]] # step    (x^6)
)

queries = sa_v2.W_query(inputs) #A
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
O/p :
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
[0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
[0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
[0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
[0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
[0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
grad_fn=<SoftmaxBackward0>)

# Shape of attn_scores is 6 * 6
# Hence context_length = 6
context_length = attn_scores.shape[0]
torch.ones(context_length, context_length)
O/p :
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])

# Masking
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
O/p :
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])
Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:
masked_simple = attn_weights * mask_simple
print(masked_simple)
O/p :
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
[0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
[0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
[0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
grad_fn=<MulBackward0>)
As we can see, the elements above the diagonal are successfully zeroed out
The third step is to renormalize the attention weights to sum up to 1 again in each row.
We can achieve this by dividing each element in each row by the sum in each row:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
O/p :
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
[0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
[0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
[0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
grad_fn=<DivBackward0>)
The result is an attention weight matrix where the attention weights above the diagonal are zeroed out and where the rows sum to 1.
Steps for Causal Attention (Approach-2) :
A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.
The softmax function converts its inputs into a probability distribution.
When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability.
(Mathematically, this is because e^(-∞) approaches 0.)
We can implement this more efficient masking "trick" by creating a mask with 1's above the diagonal and then replacing these 1's with negative infinity (-inf) values:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
O/p :
tensor([[0.2899, -inf, -inf, -inf, -inf, -inf],
[0.4656, 0.1723, -inf, -inf, -inf, -inf],
[0.4594, 0.1703, 0.1731, -inf, -inf, -inf],
[0.2642, 0.1024, 0.1036, 0.0186, -inf, -inf],
[0.2183, 0.0874, 0.0882, 0.0177, 0.0786, -inf],
[0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
grad_fn=<MaskedFillBackward0>)
Now, all we need to do is apply the softmax function to these masked results, and we are done.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
O/p :
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
[0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
[0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
[0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
grad_fn=<SoftmaxBackward0>)
As we can see from the output, the values in each row sum to 1, and no further normalization is necessary.
Masking in Transformers sets scores for future tokens to a large negative value, making their influence in the softmax calculation effectively zero.
The softmax function then recalculates attention weights only among the unmasked tokens.
This process ensures no information leakage from masked tokens, focusing the model solely on the intended data.
That's all; this is what causal attention means.
Note : Additionally we can implement Dropout. It is not mandatory.
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) #A
example = torch.ones(6, 6) #B
print(dropout(example))
O/p :
tensor([[2., 2., 2., 2., 2., 2.],
[0., 2., 0., 0., 0., 0.],
[0., 0., 2., 0., 2., 0.],
[2., 2., 0., 0., 0., 2.],
[2., 0., 0., 0., 0., 2.],
[0., 2., 0., 0., 0., 0.]])
When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero.
To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2.
This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
Now, let's apply dropout to the attention weight matrix itself:
torch.manual_seed(123)
print(dropout(attn_weights))
As we can see above, the resulting attention weight matrix now has additional elements zeroed out and the remaining ones rescaled.
Having gained an understanding of causal attention and dropout masking, we will develop a concise Python class in the following section.
This class is designed to facilitate the efficient application of these two techniques.
Recap of Causal Attention :
- Input values (Input Embeddings - this step will inherit from Self-attention)
- Wq, Wk, Wv matrix (this step will inherit from Self-attention)
- Query = Input @ Wq^T (nn.Linear multiplies by the transpose of the stored weight matrix)
- Key = Input @ Wk^T
- Value = Input @ Wv^T
- Calculate the Attention Score (Query matrix * Transpose of Keys matrix)
- Apply Masking on the top of Attention Score (Fill upper triangle matrix with -infinity)
- Feature of Causal Attention : Only previous, current tokens will be visible
- Calculate Attention Weights = Softmax(Masked Attention scores / Scaling)
- Optional Dropout on Attention Weights (to prevent overfitting problem)
- To prevent overfitting and it is optional
- Calculate Context vector = Dropout(Attention Weights) * Value matrix
As we know, we have considered only one input sentence until now, i.e. "Your journey starts with one step".
Assume we have 2 sentences :
Sentence/Batch 1 : "Your journey starts with one step" (6 tokens)
Sentence/Batch 2 : "Your journey starts with one step" (6 tokens)
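In code, such a batch can be built by stacking the same inputs tensor twice along a new batch dimension:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

# Stack two copies of the sentence along a new leading batch dimension
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)  # torch.Size([2, 6, 3])
```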
O/p :
torch.Size([2, 6, 3])
- 2 represents no of batches
- 6 represents number of tokens in each batch
- 3 represents each token is represented in a 3 dimensional space/vector
Topic in Github code : Implementing a compact causal attention class
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout) # New
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
def forward(self, x):
b, num_tokens, d_in = x.shape # New batch dimension b
keys = self.W_key(x) # x multiplied by the transpose of the key weight matrix
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
attn_scores.masked_fill_( # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights) # New
context_vec = attn_weights @ values
return context_vec
The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here.
For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training the LLM in future chapters.
This means we don't need to manually ensure these tensors are on the same device as your model parameters, avoiding device mismatch errors.
We can use the CausalAttention class as follows, similar to SelfAttention previously:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
O/p :
context_vecs.shape: torch.Size([2, 6, 2])
As we can see, the resulting context vector is a 3D tensor where each token is now represented by a 2D embedding:
Conclusion :
We are done with the causal attention mechanism. Next, we will talk about multi-head causal attention.
Multi-headed attention
The number of attention heads depends on the model size/variant. A commonly referenced heuristic for base models is:
Attention Heads = Hidden Dimension / 64
More heads allow the model to learn different relationships in parallel, for example:
- One head - grammar
- Another head - long-distance dependencies
- Another - entity relationships
- Another - positional context
Ex : The cat sitting on the mat is hungry
For the above example, different heads may focus on:
- subject-verb agreement
- the cat-hungry relationship
- position in the sentence
- semantic meaning
Each head operates independently. This can be achieved by stacking multiple causal-attention modules.
All the steps are the same as in causal attention:
- Input values (Input Embeddings - this step will inherit from Self-attention)
- Wq, Wk, Wv matrix (this step will inherit from Self-attention)
- Query = Input @ Wq^T (nn.Linear multiplies by the transpose of the stored weight matrix)
- Key = Input @ Wk^T
- Value = Input @ Wv^T
- Calculate the Attention Score (Query matrix * Transpose of Keys matrix)
- Apply Masking on the top of Attention Score (Fill upper triangle matrix with -infinity)
- Feature of Causal Attention : Only previous, current tokens will be visible
- Calculate Attention Weights = Softmax(Masked Attention scores / Scaling)
- Apply optional dropout on the attention weights (to prevent overfitting)
- Calculate Context vector = Dropout(Attention Weights) * Value matrix
The only difference is that with multiple heads we have multiple query, key, and value matrices, so we get multiple context vectors, which are concatenated at the end.
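The masking and softmax steps above can be sketched on a toy 3x3 score matrix (the numbers are made up purely for illustration; d_k is an assumed key dimension for the scaling factor):

```python
import torch

# Toy 3x3 attention-score matrix (3 tokens attending to 3 tokens).
attn_scores = torch.tensor([[0.9, 0.4, 0.7],
                            [0.2, 0.8, 0.5],
                            [0.6, 0.3, 0.9]])
d_k = 4  # assumed key dimension, used only for scaling

# Causal mask: hide future tokens by filling the upper triangle with -inf.
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()
masked = attn_scores.masked_fill(mask, -torch.inf)

# Scale, then softmax row-wise to get attention weights.
attn_weights = torch.softmax(masked / d_k**0.5, dim=-1)
print(attn_weights)            # row 0 attends only to token 0
print(attn_weights.sum(dim=-1))  # each row sums to 1
```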
Approach-1 :
In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs
In code, we can achieve this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of our previously implemented CausalAttention module:
Code :
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
With this approach, we need to create 'n' instances of the CausalAttention class, where 'n' is the number of heads the model uses. We also iterate over the heads with a Python for loop, which runs the heads sequentially and hurts processing speed and performance. Hence we have another approach, Approach-2.
Instead of initializing 'n' separate sets of weight matrices up front, we use one large Wq, Wk, and Wv, and split the resulting Q, K, and V matrices into per-head pieces after the projections are computed (see below image).
If a model uses 96 heads, the input embedding vector would otherwise need to be multiplied with 96 separate query, key, and value matrices. That is a lot of math; instead, we use a single large Wq, Wk, and Wv, perform one multiplication each, and then split the output into 96 per-head matrices.
Approach-2 : Multi-headed attention with weight splits
- Step-1
- Start with your input
- Input is "I Like AI"
- batch_size, num_tokens, d_in = [1, 3, 6]
- batch_size = 1, num_tokens = 3, d_in = 6 (each token in 6 dimension)
- Step-2
- Wq, Wk, Wv = 6 * 6 (d_in, d_out)
- Assuming input, output dimension = 6
- Step-3
- Calculate Q, K, V
- Q = Input @ Wq
- K = Input @ Wk
- V = Input @ Wv
- Input = (1, 3, 6); each of Wq, Wk, Wv = (6, 6); then after multiplication
- Q = (1, 3, 6)
- K = (1, 3, 6)
- V = (1, 3, 6)
- Step-4
- Assuming num_heads = 2, d_out = 6
- head_dim = d_out/num_heads = 6/2 = 3
- So, dimension of head = 3 (for each head)
- Step-5
- In step-3, (1, 3, 6) are batch_size, no of tokens & output dimensions
- In step-4, head_dim = d_out/num_heads, therefore : d_out = head_dim * num_heads
- Replacing d_out with num_heads * head_dim in the shape above gives:
- So, the output is (1, 3, num_heads, head_dim) = (1, 3, 2, 3) for 2 heads
- keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
- values = values.view(b, num_tokens, self.num_heads, self.head_dim)
- queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
- So, (1, 3, 2, 3) is the matrix for each head (4 dimensional space)
- Step-6
- Grouping the data by head (instead of by token as above)
- (1, 3, 2, 3) => (batch, no. of tokens, heads, head_dim), but we need to group the data by head, so swap indices 1 and 2 to get (1, 2, 3, 3)
- keys = keys.transpose(1, 2)
- queries = queries.transpose(1, 2)
- values = values.transpose(1, 2)
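The reshaping in Steps 5 and 6 can be sanity-checked with a small, self-contained sketch (shapes follow the running example: batch 1, 3 tokens, 2 heads, head_dim 3), which also confirms that merging the heads back reproduces the original tensor:

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 3
d_out = num_heads * head_dim  # 6

# A dummy projection output of shape (b, num_tokens, d_out).
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32)
x = x.view(b, num_tokens, d_out)

# Step-5: unroll d_out into (num_heads, head_dim).
# Step-6: swap dims 1 and 2 to group by head.
split = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(split.shape)  # torch.Size([1, 2, 3, 3])

# Reversing the transpose and flattening recovers the original tensor,
# which is exactly what the "combine heads" step relies on.
merged = split.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(x, merged))  # True
```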
- Step-7
- Calculate Attention scores
- AS = Queries @ Keys.transpose(2, 3) = (1, 2, 3, 3) @ (1, 2, 3, 3) = (1, 2, 3, 3)
- Step-8
- Calculate Attention Weights
- Masking(Attention Score) then
- Scaling (Masked Values) then
- Softmax
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
- Step-9
- Apply Dropout on Attention weights
- attn_weights = self.dropout(attn_weights)
- Step-10
- Calculate Context Vector
- Step-11
- Concat Context Vectors
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)  # optional projection
Code :
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
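As an optional, self-contained sanity check (a sketch, not part of the original code), the manual causal masking, scaling, and softmax sequence can be compared against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available in torch >= 2.0), which implements the same computation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
b, num_heads, num_tokens, head_dim = 1, 2, 3, 4
q = torch.randn(b, num_heads, num_tokens, head_dim)
k = torch.randn(b, num_heads, num_tokens, head_dim)
v = torch.randn(b, num_heads, num_tokens, head_dim)

# Manual path: scores -> causal mask -> scale -> softmax -> weighted values.
attn_scores = q @ k.transpose(2, 3)
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores.masked_fill_(mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)
manual = attn_weights @ v

# Built-in path: same math, fused into one call.
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))  # True (up to float tolerance)
```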
Please feel free to download entire Attention Mechanism code from following GitHub location : https://github.com/amathe1/LLMs/blob/main/Attention_Mechanism.ipynb
That's all for this blog. We are done with the attention mechanisms used in LLMs.
Book Reference : The below book is a good reference to learn LLMs
Build a Large Language Model (From Scratch) by Sebastian Raschka
Thank you for reading this blog !
Arun Mathe