A Large Language Model (LLM) is a deep neural network, trained with self-supervised learning on massive amounts of data, built to understand and generate human-like language.
Examples :
- GPT-4
- ChatGPT
- LLaMa
- Gemini
An LLM is a model that :
- Reads text
- Learns patterns in language
- Predicts the next word
- Generates meaningful responses
Concentrate & try to understand the 2 examples mentioned in the above image :
- In the first example, we need to predict the salary of an employee based on age and experience
- In this case, the data is numeric: both inputs X1 (age) and X2 (experience) are numbers
- It doesn't matter even if we swap X1 and X2 in the input layer while building a neural network
- Such networks are called Artificial Neural Networks, where order is NOT important
- In the second example, we need to process a string/sentence, and these are not numeric
- "Hi, my name is Anil kumar" & "Hi, Anil kumar is name my" are not the same !
- Hence order is important in this case; such networks are called Recurrent Neural Networks, where order is important
- They deal with sequential data
- ANN (Artificial Neural Network) = Vehicle
- FFNN (Feed Forward Neural Network) = Car
- RNN (Recurrent Neural Network)= Train
- CNN (Convolutional Neural Network) = Drone
- Transformer (Latest/Current) = Jet
Evolution of sequence models from RNN to Transformer :
RNN (1980) --> LSTM (1997) --> GRU (2014) --> Attention Mechanism (Bahdanau - 2014) --> Self Attention --> Transformer (2017) --> LLM's (GPT, BERT etc.)
- ANN is meant for non-sequential data
- RNN is strictly meant for sequential data
- Consider the input text mentioned in the above diagram
- Hi my name is Anil (5 words) with o/p or sentiment as 1
- I like coding (3 words) with o/p or sentiment as 0
- I like Indian cricket (4 words) with o/p or sentiment as 1
- If I directly input the above data to an ANN, it won't understand anything. Hence we need to handle this type of sequential data using a separate mechanism.
- Total no. of words in all the sentences : 12 (this is the vector size; strictly, a vocabulary contains only unique words, which would give 10 here since 'I' and 'like' repeat)
- Hi - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- my - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- name - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- is - [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
- Anil - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
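The encoding above can be sketched as a tiny helper (the function name `one_hot` is my own, not from the post):

```python
# A minimal one-hot encoding sketch: each word gets a vector of the full
# vocabulary size with a single 1 at its own index.
vocab = ["Hi", "my", "name", "is", "Anil"]  # first 5 entries of the 12-word list

def one_hot(word, vocab, size):
    vec = [0] * size
    vec[vocab.index(word)] = 1
    return vec

print(one_hot("name", vocab, 12))
# [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```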
Now, if we build a neural network on the above input, we need one input neuron per dimension of every word. Remember, each word is a 12-bit vector (as we used one hot encoding), and assume we have 4 hidden neurons. Then what would be the total number of parameters up to the 1st hidden layer ?
While processing the 1st sentence, i.e. "Hi my name is Anil", each word is a 12-dimensional vector.
Total parameters = (12 dimensions * 5 words) * 4 (hidden neurons) + 4 (biases) = (60 * 4) + 4 = 244
244 parameters (weights & biases) just for the 1st sentence, up to the 1st hidden layer!
And these parameter counts are not consistent across sentences. Once we design an architecture, the total number of parameters is fixed; but here, while processing text, the input size changes with every sentence length, so the parameter count keeps changing and training will take a lot of time!
Now think about the total parameters in the entire neural network, for all sentences: that's 'n' parameters where 'n' is a big number. Don't you think the model would take too much time to train, maybe even forever?
That's why we don't use ANN for processing text.
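The parameter growth described above can be checked with a quick sketch (the helper name is my own):

```python
# Parameter count up to the first hidden layer (4 hidden neurons) when one-hot
# word vectors are concatenated: the count changes with sentence length.
def params_to_hidden(n_words, dim=12, hidden=4):
    # (n_words * dim) weights into each hidden neuron, plus one bias per neuron
    return (n_words * dim) * hidden + hidden

print(params_to_hidden(5))  # "Hi my name is Anil" -> 244
print(params_to_hidden(3))  # "I like coding"      -> 148
```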
Welcome to the world of RNN ! 😊
RNN :
A Recurrent Neural Network is a type of neural network designed to handle sequential data, where the order of input matters.
Examples :
- Sentences (NLP)
- Speech Signals
- Time series data (stock prices, weather reports)
- Sensor data
Forward & Backward propagation, Mathematics in RNN:
- Assume below movie review comments & respective sentiments(+/-)
- Movie was good - sentiment 1 (+ve)
- Movie was bad - sentiment 0 (-ve)
- Movie was not good - sentiment 0 (-ve)
- Movie, was, good, bad, not ; vocabulary size is 5 (Vocabulary/Corpus)
- Let's use "One hot encoding" here as well; whenever we use one hot encoding, the maximum size of a vector is the vocabulary size.
- Now, let's do the vector representation using one hot encoding
- Sentence 1 with sentiment 1
- Movie(x11) - [1, 0, 0, 0, 0]
- was(x12) - [0, 1, 0, 0, 0]
- good(x13) - [0, 0, 1, 0, 0]
- Sentence 2 with sentiment 0
- Movie(x21) - [1, 0, 0, 0, 0]
- was(x22) - [0, 1, 0, 0, 0]
- bad(x23) - [0, 0, 0, 1, 0]
- Sentence 3 with sentiment 0
- Movie(x31) - [1, 0, 0, 0, 0]
- was(x32) - [0, 1, 0, 0, 0]
- not(x33) - [0, 0, 0, 0, 1]
- good(x34) - [0, 0, 1, 0, 0]
- Assume the user entered the text: 'Movie' at timestamp t1, 'was' at t2, 'good' at t3
- As per above diagram :
- input x11 (Movie) entered at timestamp t1, with weight w(i) --> O(0) is the previous output, with w(h-1) as its weight --> output is O(1)
- input x12 (was) entered at timestamp t2, with weight w(i) --> O(1) is the previous node output, nothing but the previous word, with weight w(h) --> output is O(2)
- input x13 (good) entered at timestamp t3, with weight w(i) --> O(2) is the previous node output (nothing but the previous 2 words), with weight w(h+1) --> output is O(3)
- x11, x12, x13 are the inputs; nothing but Movie, was, good
- w(i) is the corresponding input weight at timestamps t1, t2, t3 for the 3 words
- O(0), O(1), O(2), O(3) are the node outputs, with corresponding recurrent weights w(h-1), w(h), w(h+1)
- Generally, in an ANN, we pass input data along with randomly initialized weights. Isn't it ?
- But here in RNN, along with the current input and its weight, we have to consider the previous node output along with its weight (the weight represents how much of the previous node output to carry forward)
- In ANN, forward propagation, f = activation function(w1x1+b)
- In RNN, forward propagation, f = activation function(x11wi + O(0)w(h-1) + b)
- Just the same as ANN, but adding the previous node's output & its respective weight
- RNN holds the previous value as well; otherwise, when you type "Movie was good", by the time you enter 'was', it won't remember the previous word 'Movie' ==> This is the whole mantra; read and memorize it n number of times till you remember it forever.
- As per above diagram, we have 3 blocks and lets see the formula for forward propagation of each block :
- FP of Block1, f= activation function(x11wi + O(0)w(h-1) + b) = O(1)
- FP of Block2, f= activation function(x12wi + O(1)w(h) + b) = O(2)
- FP of Block3, f= activation function(x13wi + O(2)w(h+1) + b) = O(3)
- Notice that we are adding the output & corresponding weight matrix of the previous node as well
- O(0)w(h-1), O(1)w(h), O(2)w(h+1) are the previous state values; the network stores these internally as its hidden state, a simple form of memory
- People may still get confused about w(h-1), w(h), w(h+1). These are the weights applied to the hidden output of the node at a given timestamp (note that in a standard RNN, the same recurrent weight matrix is shared across all timestamps). They learn how strongly each dimension of past memory should influence the current state, which means :
- If a weight is large → past information strongly influences the current state
- If a weight is small → past information's influence is weak
- If negative → it may suppress or invert the influence
- So yes, it controls importance, but indirectly through multiplication, not via an explicit priority mechanism.
Forward propagation: h(t) = f(W(x)·x(t) + W(h)·h(t-1) + b)
i.e. FP = tanh(W(i)x(t) + W(h)O(t-1) + b)
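As a hedged sketch of this forward pass, here is the recurrence for "Movie was good" in NumPy, with random (assumed) weights standing in for the diagram's values:

```python
import numpy as np

# Unrolled RNN forward pass: h = tanh(W_x x + W_h h_prev + b), one word per step.
rng = np.random.default_rng(0)
vocab = ["Movie", "was", "good", "bad", "not"]
hidden = 3
W_x = rng.normal(size=(hidden, len(vocab)))   # input-to-hidden weights (assumed)
W_h = rng.normal(size=(hidden, hidden))       # hidden-to-hidden (recurrent) weights
b = np.zeros(hidden)

h = np.zeros(hidden)                          # O(0): initial state
for word in ["Movie", "was", "good"]:
    x = np.eye(len(vocab))[vocab.index(word)] # one-hot vector for this word
    h = np.tanh(W_x @ x + W_h @ h + b)        # same weights reused at every step
print(h.shape)  # (3,)
```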
We can see the weights associated with each node of the RNN in the above diagram. The feedback loop is nothing but the previous state output. To make it simple, it carries the previous state along with the current state information as we process text. Notice that even when you type a message in WhatsApp, you type the first word at one timestamp and the second word at another timestamp. Isn't it ?
RNN is meant for sequence data, where :
- Order matters
- Position matters
- Previous elements influence later elements
Lets see the process of backward propagation :
Assume we have above reviews with corresponding sentiment as shown in the above image.
Vocabulary (unique words in given data) = 3 (cat, mat, rat) where :
- cat - [1, 0, 0]
- mat - [0, 1, 0]
- rat - [0, 0, 1]
- x1 = cat mat rat = [1, 0, 0] [0, 1, 0] [0, 0, 1]
- x2 = rat rat mat = [0, 0, 1] [0, 0, 1] [0, 1, 0]
- x3 = mat mat cat = [0, 1, 0] [0, 1, 0] [1, 0, 0]
As shown in the above diagram, to process the first sentence "cat mat rat", we need to pass one word at a time (at a given timestamp) to a neural network with one hidden layer and 3 hidden nodes. Please find the respective weights and biases in the image.
Below image is the unfolded structure of above diagram.
It is a classification problem, hence we use Sigmoid as the activation function and binary cross-entropy as the loss function. Please see the formula for binary cross-entropy below.
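Binary cross-entropy for a single prediction can be sketched as follows (a minimal illustration, not the post's own code):

```python
import math

# Binary cross-entropy: -(y*log(p) + (1-y)*log(1-p)) for a true label y
# and a predicted probability p.
def bce(y_true, y_pred, eps=1e-12):
    y_pred = min(max(y_pred, eps), 1 - eps)   # clip to avoid log(0)
    return -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1 - y_pred))

print(round(bce(1, 0.9), 4))  # small loss: confident and correct
print(round(bce(1, 0.1), 4))  # large loss: confident but wrong
```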
- Wi, Wh-1, Wo are weights, and they are initialized randomly
- Derivative of Wi :
- Similarly, derivative of Wh :
Imagine if the above values (including the learning rate) are very small; then we will land in the vanishing gradient issue. And if those values are very big, then we will land in the exploding gradient issue. Hence we may run into the above issues, and we lose the earlier context in RNN.
For example, imagine asking ChatGPT a question while it has forgotten what you entered before and remembers only the last word at a given timestamp! That is the issue.
Mathematical Example : "I like AI"
Step1 : Vocabulary = 3 [I, like, AI]
Step2 : Vector conversion (one hot encoding)
- Word "I" ====> One hot encoding [1, 0, 0]
- Word "like" ====> One hot encoding [0, 1, 0]
- Word "AI" ====> One hot encoding [0, 0, 1]
Step3 : Represent the above words in 2-dimensional space; all these are assumed numbers (see below image). For simplicity we consider a 2D vector, but we could use any dimension we want.
- I => Embedding vector = [0.2, 0.5]
- like => Embedding vector = [-0.3, 0.8]
- AI => Embedding vector = [0.6, -0.1]
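An embedding lookup is just a matrix multiply: the one-hot vector picks out a row of the embedding matrix. A small sketch using the Step3 numbers:

```python
import numpy as np

# One-hot @ embedding matrix selects that word's embedding row.
E = np.array([[0.2, 0.5],    # "I"
              [-0.3, 0.8],   # "like"
              [0.6, -0.1]])  # "AI"
one_hot_like = np.array([0, 1, 0])
print(one_hot_like @ E)  # [-0.3  0.8]
```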
Step4 : RNN formula
Consider the above formula, with the below assumed random values.
Step5 : Calculations (Forward Pass)
FP = tanh([0.7, -0.3]), which equals :
Observe that we started with h0 = [0, 0] and now h1 = [0.6044, -0.2913], which means we have added the word 'I' to memory.
Similarly, for other 2 words :
h2 = [0.6652, 0.8472], which is different from h1; the change means the 2nd word has been added to memory. Now the hidden state encodes 'I', 'like'.
For 3rd word, AI :
Summary :
- At time t(0) --> No word --> No embeddings --> Hidden state value is [ 0, 0]
- At time t(1) --> I --> [0.2, 0.5] --> [0.604, -0.291]
- At time t(2) --> Like --> [-0.3, 0.8] --> [0.6652, 0.847]
- At time t(3) --> AI --> [0.6, -0.1] --> [0.682, 0.269]
Internally, RNN maintains the hidden state as above. Along with the current input, it carries the previous sequence of data. But remember, it gives more importance to the current word and less importance to earlier words.
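The summary table can be replayed in a few lines. Since the post's actual weight values live in its images, the weights here are assumed random, so the numbers differ from the table, but the hidden state evolves the same way:

```python
import numpy as np

# Hidden state update h = tanh(W_x e + W_h h_prev + b) for "I like AI",
# using the Step3 embeddings and assumed random weights.
rng = np.random.default_rng(1)
emb = {"I": [0.2, 0.5], "like": [-0.3, 0.8], "AI": [0.6, -0.1]}
W_x = rng.normal(size=(2, 2))
W_h = rng.normal(size=(2, 2))
b = np.zeros(2)

h = np.zeros(2)  # t(0): no words yet
for word in ["I", "like", "AI"]:
    h = np.tanh(W_x @ np.array(emb[word]) + W_h @ h + b)
    print(word, h)  # hidden state now summarizes all words seen so far
```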
Conclusion :
Whatever we have discussed since the start of this blog, programmatically we can simply write as nn.RNN()
That's it, as simple as that. But if we don't understand what's happening in the background conceptually, then it will really feel like Greek and Latin.
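For completeness, here is what that one-liner looks like in PyTorch (the input and hidden sizes are assumed to match the 2D embeddings used earlier):

```python
import torch
import torch.nn as nn

# One call replaces the whole hand-rolled recurrence (tanh is the default).
rnn = nn.RNN(input_size=2, hidden_size=2, batch_first=True)
x = torch.tensor([[[0.2, 0.5], [-0.3, 0.8], [0.6, -0.1]]])  # "I like AI"
out, h_n = rnn(x)  # out: hidden state at every timestep, h_n: final hidden state
print(out.shape)   # torch.Size([1, 3, 2])
print(h_n.shape)   # torch.Size([1, 1, 2])
```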
We just started basics of LLMs :
RNN (1980) --> LSTM (1997) --> GRU (2014) --> Attention Mechanism (Bahdanau - 2014) --> Self Attention --> Transformer (2017) --> LLM's (GPT, BERT etc.)
So far, we have discussed :
- RNN architecture
- ANN Vs RNN
- RNN (Forward & Backward propagation for RNN)
- Drawbacks of RNN
- Mathematics involved in RNN
LSTM (Long Short-Term Memory): LSTM is a special type of Recurrent Neural Network (RNN) designed to handle long-term dependencies in sequence data.
RNN Vs LSTM :
Architecture :
LSTM introduces a memory cell (Cₜ) and three gates that control information flow.
- Previous cell state (LTM)
- Current cell state
- Previous Hidden State (STM)
- Current Hidden State
- Input at timestamp t1
- 3 Gates
- Forget Gate
- Input Gate
- Output Gate
- Xt - Current input
1) Forget Gate : Based on the current input and the previous short-term context, it decides what to remove from LTM
Let's talk about the internal components of the LSTM architecture :
- Pointwise operators are in circular form in the above architecture (see above diagram)
- Talking about the forget gate, X(t) and h(t-1) are its inputs
- We pass these values through a sigmoid (σ) function and the output is f(t)
- We then multiply this output f(t) with C(t-1) element-wise
- Assuming below random values : (* is the pointwise operator)
- C(t-1) = [2, 4, 8]
- If f(t) = [0.5, 0.5, 0.5] then C(t-1) * f(t) = [1, 2, 4] --> Reduced to half
- If f(t) = [0, 0, 0] then C(t-1) * f(t) = [0, 0, 0] --> Removed data from LTM
- If f(t) = [1, 1, 1] then C(t-1) * f(t) = [2, 4, 8] --> Using same data
- Observe that after the pointwise multiplication of the LTM with the forget-gate output, the LTM changes
- That is, the result of C(t-1) * f(t) depends on the value of f(t)
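The three cases above are plain element-wise multiplication:

```python
import numpy as np

# Forget gate in action: element-wise product of old cell state and gate output.
C_prev = np.array([2.0, 4.0, 8.0])
print(C_prev * np.array([0.5, 0.5, 0.5]))  # [1. 2. 4.]  -> reduced to half
print(C_prev * np.array([0.0, 0.0, 0.0]))  # [0. 0. 0.]  -> removed from LTM
print(C_prev * np.array([1.0, 1.0, 1.0]))  # [2. 4. 8.]  -> kept as-is
```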
- In Forget Gate : We have one σ
- In Input Gate : We have σ & tanh
- In Output gate : We have σ
- Range of sigmoid is (0, 1): values close to 0 mean "mostly forget" and values close to 1 mean "mostly keep"; it is a soft, continuous gate, not a hard 0/1 threshold
- Main purpose of the Forget Gate is :
- Based on the current input & previous context, it will remove some context from LTM
- If the o/p of f(t) is 0, then everything will be removed from LTM
- If the o/p of f(t) is 1, then nothing will be removed from LTM
- If f(t) = 0.5, then only half of the context is retained, and we lose half of the previous long-term context information
- The output of the sigmoid gate is in the range (0, 1), same in the Input Gate as well; it is not zero-centered (always +ve)
- Analogy
- Yes/No, Allow/Don't Allow - This guy decides, should I allow data or not
- Tanh : Range of tanh is (-1, 1)
- Allows both +ve and -ve values
- It is zero-centered
- X(t) and h(t-1) are inputs to tanh, and the output of tanh is the candidate cell state, i.e. c(t) or ĉ
- Inputs
- Current Input, x(t)
- Previous Hidden State h(t-1)
- Output
- Remove some content from LTM (from C(t-1))
- x(t) is a 4D vector, i.e. [xi1, xi2, xi3, xi4]
- h(t-1) is a 3D vector, like C(t-1) (the dimensionality of h(t-1) equals that of C(t-1), which equals the number of hidden nodes in the hidden layer)
- Below is the generalized equation for Forget Gate
- 7 i/p neurons (x(t) is 4D + h(t-1) is 3D)
- 3 hidden neurons (see above thumb rule in red color)
W(f) are the weights of h(t-1) & x(t) in the above Sigmoid equation f(t)
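A quick shape check of this equation, with assumed random weights (x(t) is 4D, h(t-1) is 3D, 3 hidden nodes):

```python
import numpy as np

# f(t) = sigmoid(W_f . [h(t-1), x(t)] + b_f): 7 concatenated inputs into
# 3 hidden nodes, so W_f has shape (3, 7).
rng = np.random.default_rng(2)
x_t = rng.normal(size=4)           # current input, 4-D
h_prev = rng.normal(size=3)        # previous hidden state, 3-D
W_f = rng.normal(size=(3, 7))      # assumed random weights
b_f = np.zeros(3)

f_t = 1 / (1 + np.exp(-(W_f @ np.concatenate([h_prev, x_t]) + b_f)))
print(f_t.shape)  # (3,) -- same dimension as the cell state C(t-1)
```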
Input Gate :
- Inputs
- X(t), h(t-1)
Output Gate :
- Inputs
- tanh(C(t)) & o(t); their pointwise product gives the new hidden state h(t)
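Putting the three gates together, one full LSTM cell step can be sketched as follows (assumed random weights, not the post's numbers):

```python
import numpy as np

# One LSTM cell step: forget gate prunes the LTM, input gate writes new
# candidate content, output gate decides what to expose as the hidden state.
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, b):
    z = np.concatenate([h_prev, x_t])
    f_t = sigmoid(W["f"] @ z + b["f"])        # forget gate: what to drop from LTM
    i_t = sigmoid(W["i"] @ z + b["i"])        # input gate: what to add to LTM
    c_hat = np.tanh(W["c"] @ z + b["c"])      # candidate cell state
    o_t = sigmoid(W["o"] @ z + b["o"])        # output gate: what to expose
    C_t = f_t * C_prev + i_t * c_hat          # new long-term memory
    h_t = o_t * np.tanh(C_t)                  # new short-term memory
    return h_t, C_t

rng = np.random.default_rng(3)
dim_x, dim_h = 4, 3
W = {k: rng.normal(size=(dim_h, dim_h + dim_x)) for k in "fico"}
b = {k: np.zeros(dim_h) for k in "fico"}
h, C = lstm_step(rng.normal(size=dim_x), np.zeros(dim_h), np.zeros(dim_h), W, b)
print(h.shape, C.shape)  # (3,) (3,)
```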
The above equations and formulas are only for the first word, 'I'. Similarly, we have to calculate for "like" & "AI".
Final summarized table :
Observe how the cell state and hidden state values change at each timestamp! I know the above calculations need some time to concentrate on and understand, but if you are someone expecting deep knowledge of LLMs, then it is worth spending the time.
Drawbacks of LSTM :
- Sequential processing
- Can't parallelize across sequence length
- Slow training
- Slow inference
- Poor GPU utilization
- Still struggle with very long context
- Though LSTM reduces the vanishing gradient problem :
- Memory is still compressed into a single vector
- Information gets diluted over long sequences
- Example :
- Paragraph with 1000 tokens
- Important information at token 5
- By token 900, signal may weaken
- LSTM ≠ true long term memory
- It's better than vanilla RNN but still not ideal
- Slow training
- Hard to scale
Thank you for reading this blog !
Arun Mathe