Understanding Attention Mechanisms in Transformer Models
Attention is a mechanism that lets a transformer model weigh different parts of the input sequence (and, in the decoder, of the output generated so far) when making each prediction. It is crucial for the performance of transformer models in tasks such as language translation, text summarization and sentiment analysis, where the model needs to understand the relationships between different words or phrases in the input and output sequences.
Why is Attention getting a lot of Attention?
Let’s rewind a bit. The diagram below is a simple representation of an encoder-decoder architecture. Here the encoder (and decoder) can be built from vanilla RNN or LSTM units.
A major drawback of this architecture is that the encoder has to represent the entire input sequence x1, x2, x3, x4 as a single vector c. This can cause information loss, because all of the information must be compressed into c. Moreover, the decoder has to recover everything it needs from this single vector alone, which is a difficult task in itself.
A potential issue with this encoder–decoder approach is that a neural network needs to be able to compress all the necessary information of a source sentence into a fixed-length vector. This may make it difficult for the neural network to cope with long sentences, especially those that are longer than the sentences in the training corpus.
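The bottleneck is easy to see in code. Below is a minimal sketch that assumes a plain (Elman) RNN encoder with random, untrained weights. `rnn_encode`, `w_in` and `w_h` are illustrative names and are not part of any library. No matter how many inputs the encoder reads, it returns a context vector c of the same fixed size, and that vector is all the decoder receives.

```python
import math
import random

def rnn_encode(inputs, w_in, w_h):
    """Fold a sequence of input vectors into one fixed-size context vector c.

    Each step is a plain (Elman) RNN update: h_t = tanh(W_in x_t + W_h h_{t-1}).
    """
    hidden_size = len(w_h)
    h = [0.0] * hidden_size
    for x in inputs:
        # The right-hand side is built entirely from the previous h before reassignment.
        h = [
            math.tanh(
                sum(w_in[i][j] * x[j] for j in range(len(x)))
                + sum(w_h[i][k] * h[k] for k in range(hidden_size))
            )
            for i in range(hidden_size)
        ]
    return h  # c: the only thing the decoder gets to see

random.seed(0)
input_dim, hidden_size = 3, 4
# Random stand-in weights (a trained model would learn these).
w_in = [[random.uniform(-0.5, 0.5) for _ in range(input_dim)] for _ in range(hidden_size)]
w_h = [[random.uniform(-0.5, 0.5) for _ in range(hidden_size)] for _ in range(hidden_size)]

short_seq = [[random.random() for _ in range(input_dim)] for _ in range(4)]   # x1..x4
long_seq = [[random.random() for _ in range(input_dim)] for _ in range(100)]  # 100 inputs

c_short = rnn_encode(short_seq, w_in, w_h)
c_long = rnn_encode(long_seq, w_in, w_h)
print(len(c_short), len(c_long))  # 4 4: the size of c does not grow with the input
```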
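Below is the attention-based transformer architecture, which overcomes this limitation of the vanilla encoder-decoder architecture: instead of relying on a single context vector, the decoder can attend directly to every position of the encoded input.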
Let us concentrate on the attention units in the diagram above. There are two types of attention units:
- Multi-headed attention
- Masked multi-headed attention
First, let us understand the self-attention unit (single-head attention), which is the basic building block of multi-headed attention.
Self-attention: let us take an English-to-Italian translation example. The input to the transformer is “Your cat is a lovely cat”, and the output should be “il tuo gatto è un gatto adorabile”.
Self-attention is a mechanism that captures the relationships between the words in a sentence. How do we calculate these relationships?
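This is given by the scaled dot-product attention formula:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$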
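Here Q, K and V are each matrices of size (number of words in the sentence × 512), which is 6 × 512 in this example. Each row is the 512-dimensional vector that represents one word of the input sentence. d_k is the dimension of the key vectors, which here is 512. Dividing by √d_k keeps the dot products from becoming very large as the vector size grows. Without this scaling, the softmax would be pushed into regions where its gradients are extremely small.
The softmax term is a 6 × 6 attention matrix. Entry (i, j) captures how strongly word i relates to word j, and each row sums to 1. Multiplying this matrix by V gives a 6 × 512 output in which each word's vector is a weighted mix of the vectors of all the words in the sentence.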
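A minimal pure-Python sketch of single-head self-attention follows. It assumes random stand-in vectors in place of trained embeddings and, as in the explanation above, uses the input matrix X directly as Q, K and V with no learned projections. The helper names (`matmul`, `softmax`, `self_attention`) are illustrative.

```python
import math
import random

def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    """Softmax of one row; subtracting the max avoids overflow in exp."""
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head.

    Returns (output, weights). weights is the (n x n) matrix whose entry (i, j)
    says how much word i attends to word j; output is (n x d_v).
    """
    d_k = len(k[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, transpose(k))]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v), weights

words = ["Your", "cat", "is", "a", "lovely", "cat"]
d_model = 512
random.seed(42)
# Stand-in embeddings: one random 512-dim vector per position, NOT trained word vectors.
x = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in words]

# Single-head simplification: Q, K and V are all the 6 x 512 input matrix X.
output, weights = self_attention(x, x, x)
print(len(weights), len(weights[0]))  # 6 6
print(len(output), len(output[0]))    # 6 512
```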
Multi-headed Attention: Multi-headed attention extends the self-attention mechanism by running it multiple times in parallel with different learned linear projections. This allows the model to capture different types of relationships and dependencies at different positions in the input sequence.
Each “head” in multi-headed attention operates independently, learning different relationships and capturing different aspects of the input data. The outputs from all the attention heads are then concatenated and linearly transformed to produce the final attention output.
Linear Projection: The input sequence is linearly transformed into a separate set of query (Q), key (K) and value (V) vectors for each attention head. These linear projections are learnable and serve as the parameters that the model adapts during training.
Scaled Dot-Product Attention: Each attention head independently computes attention scores by taking the dot product between the query vectors and key vectors. It then scales the result by 1/√d_k and applies a softmax function to obtain attention weights. These attention weights determine how much each token attends to the others in the sequence.
Weighted Sum: The attention weights are used to weigh the value vectors, and the weighted sum of these value vectors is computed for each attention head. This step represents the weighted aggregation of information from other tokens based on the learned attention patterns of each head.
Concatenation and Linear Transformation: The output from each attention head is concatenated and passed through a linear transformation to produce the final multi-headed attention output. This step helps the model combine information from different heads and maps it to the desired output dimension.
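The four steps above match the multi-head attention equations of the original Transformer paper (“Attention Is All You Need”). In that paper, each of the h heads has its own projection matrices, and d_k = d_v = d_model / h, which is 512 / 8 = 64:

$$
\begin{aligned}
\mathrm{head}_i &= \mathrm{Attention}\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right) W^{O}
\end{aligned}
$$

Here $W_i^{Q}, W_i^{K} \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^{V} \in \mathbb{R}^{d_{\text{model}} \times d_v}$ and $W^{O} \in \mathbb{R}^{h d_v \times d_{\text{model}}}$.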
As the diagram above shows, the matrices Q, K and V are multiplied by Wq, Wk and Wv respectively to produce Q’, K’ and V’. Wq, Wk and Wv are learnable parameters (weights) that are updated by backpropagation during training.
Masked multi-headed attention:
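Putting the steps together, here is a minimal pure-Python sketch of multi-headed attention. It assumes random stand-in weights for Wq, Wk, Wv and an output projection Wo, and uses small sizes (d_model = 64, 8 heads) so it runs quickly. The original paper uses d_model = 512. Following a common implementation choice, each projection is a single d_model × d_model matrix whose columns are split among the heads. This is equivalent to giving each head its own smaller matrix. All function names are illustrative.

```python
import math
import random

def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Project X into Q', K', V', give each head its own slice of columns,
    run attention per head, concatenate the heads and apply W_o."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # Q', K', V': n x d_model
    d_head = len(w_q[0]) // num_heads
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)  # columns belonging to head h
        heads.append(attention([r[cols] for r in q], [r[cols] for r in k], [r[cols] for r in v]))
    # Concatenate the heads side by side: n x (num_heads * d_head) = n x d_model.
    concat = [[value for head in heads for value in head[i]] for i in range(len(x))]
    return matmul(concat, w_o)

def random_matrix(rows, cols):
    """Random stand-in weights; in a real model these are learned by backpropagation."""
    return [[random.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n_words, d_model, num_heads = 6, 64, 8  # small sizes so the pure-Python sketch runs quickly
x = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(n_words)]
w_q, w_k, w_v, w_o = (random_matrix(d_model, d_model) for _ in range(4))
out = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(len(out), len(out[0]))  # 6 64
```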
This unit is in the decoder section of the transformer.
The masked multi-headed attention mechanism can be expressed as follows:
Given a sequence of input embeddings X, you apply linear transformations to create three sets of vectors: Query (Q), Key (K) and Value (V). Here the inputs are the embeddings of the target sequence, which in this example is ‘il tuo gatto è un gatto adorabile’. For each head, you calculate the attention scores by taking the dot product of Q and K and then dividing by the square root of the dimension of the key vectors. After applying the mask described below, you apply a softmax function to these scores to obtain the attention weights. The attention weights are then used to compute a weighted sum of the Value vectors, producing the output of each head.
Finally, the outputs from all heads are concatenated and linearly transformed to produce the final output for that position in the sequence.
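Masking nature: The model has to be causal. At any given position, it may only see the current and previous words, never the words that come after. In translation, this is what allows the decoder to generate the output one word at a time. To enforce this, all scores above the diagonal are set to −∞ before the softmax. As a result, every upper-diagonal element of the attention-weight matrix becomes 0, as shown below.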
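A minimal sketch of the causal mask follows. It assumes the scaled scores have already been computed; here they are all set to zero so that the effect of the mask is easy to see. `causal_mask` and `masked_attention_weights` are illustrative names. Each row spreads its weight evenly over the words it is allowed to see, and every element above the diagonal comes out as exactly 0.

```python
import math

def causal_mask(n):
    """mask[i][j] is 0.0 where position i may see position j (j <= i), else -inf."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]

def softmax(row):
    top = max(row)  # the diagonal is never masked, so top is always finite
    exps = [math.exp(s - top) for s in row]  # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention_weights(scores):
    """Add the causal mask to already-scaled scores, then apply softmax row by row."""
    mask = causal_mask(len(scores))
    return [softmax([s + m for s, m in zip(row, mask_row)]) for row, mask_row in zip(scores, mask)]

target = ["il", "tuo", "gatto", "è", "un", "gatto", "adorabile"]
n = len(target)
scores = [[0.0] * n for _ in range(n)]  # equal scores, so only the mask shapes the weights
weights = masked_attention_weights(scores)
for word, row in zip(target, weights):
    print(f"{word:>10}", " ".join(f"{w:.2f}" for w in row))
```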
Encoder-decoder attention: The encoder's output supplies the keys (K) and values (V), which are passed as input to the decoder. The queries (Q) come from the output of the decoder's masked multi-headed attention. Together with the K and V from the encoder, they are fed into the decoder's second multi-headed attention unit.
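A minimal sketch of this encoder-decoder (cross) attention follows. It assumes random stand-in matrices for the encoder output and for the output of the decoder's masked attention, and it omits the learned projections Wq, Wk and Wv for brevity. The function name `cross_attention` is illustrative. Note that the attention matrix is no longer square: each of the 7 target words gets a weight for each of the 6 source words. No causal mask is applied here, because the whole source sentence is available.

```python
import math
import random

def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(decoder_queries, encoder_output):
    """Queries come from the decoder; keys and values come from the encoder output.

    Returns (output, weights); weights has shape (target_len x source_len).
    """
    q = decoder_queries
    k = v = encoder_output
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights

random.seed(7)
d_model = 16
source = ["Your", "cat", "is", "a", "lovely", "cat"]
target = ["il", "tuo", "gatto", "è", "un", "gatto", "adorabile"]
# Stand-ins: random vectors in place of real encoder / masked-attention outputs.
encoder_output = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in source]
decoder_queries = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in target]

output, weights = cross_attention(decoder_queries, encoder_output)
print(len(weights), len(weights[0]))  # 7 6
print(len(output), len(output[0]))    # 7 16
```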
Advantages of attention mechanism over RNN architecture:
- Reducing the Vanishing Gradient Problem: Recurrent neural networks (RNNs) pass information, and gradients, through one time step after another, so over long sequences the gradients tend to vanish. Attention connects every position directly to every other position. This gives gradients a short path and makes these models easier to train.
- Parallelization: An RNN must process tokens one at a time, because each hidden state depends on the previous one. Self-attention computes the outputs for all positions at once using matrix multiplications, which maps well onto modern hardware such as GPUs and TPUs. This leads to faster training times and improved scalability.
- Capturing Long-Range Dependencies: In recurrent networks, information about distant elements has to survive many sequential steps, so long-range dependencies are hard to capture. Attention lets any element look at any other element in a single step, which is important for understanding context and meaning in natural language.
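The parallelization and long-range points can be illustrated with a small sketch. It uses a hypothetical helper, `attention_row`, with random stand-in inputs and no learned projections. The self-attention output for position i reads every position directly in one step and never needs the output computed for another position. As a result, the rows can be computed in any order, or all at once. An RNN cannot do this, because its hidden state at step t depends on the state at step t − 1.

```python
import math
import random

def softmax(row):
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_row(i, x):
    """Self-attention output for position i alone (Q = K = V = X, no learned projections).

    It reads only x: position i attends to every position, near or far, in one step,
    and it never needs the output computed for any other position.
    """
    d = len(x[0])
    scores = [sum(a * b for a, b in zip(x[i], x_j)) / math.sqrt(d) for x_j in x]
    weights = softmax(scores)
    return [sum(weights[j] * x[j][c] for j in range(len(x))) for c in range(d)]

random.seed(3)
x = [[random.gauss(0.0, 1.0) for _ in range(8)] for _ in range(6)]

in_order = [attention_row(i, x) for i in range(len(x))]
# Computing the positions in reverse gives exactly the same rows: there is no
# sequential dependency, so hardware can compute all of them simultaneously.
reverse_order = {i: attention_row(i, x) for i in reversed(range(len(x)))}
print(all(in_order[i] == reverse_order[i] for i in range(len(x))))  # True
```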
Despite their successes, transformers have limitations, including high computational requirements, susceptibility to overfitting, and the need for large-scale data for training. Addressing these challenges continues to be a focus of research.
The self-attention mechanism is the core innovation in transformers. It enables models to weigh the importance of different elements in a sequence when making predictions, thereby capturing context effectively. However, because every element attends to every other element, a sequence of n elements requires an n × n matrix of attention scores. The computational cost therefore grows quadratically with sequence length, which can be a limitation for very long sequences.
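The quadratic growth can be shown with a back-of-the-envelope sketch. It assumes 8 heads (as in the original Transformer), 32-bit floats, a single layer, a single sequence, and an implementation that stores the full n × n attention-weight matrix for every head. `attention_score_bytes` is a hypothetical helper. Each doubling of the sequence length quadruples the memory.

```python
def attention_score_bytes(seq_len, num_heads=8, bytes_per_value=4):
    """Bytes for one layer's attention-weight matrices (one seq_len x seq_len matrix
    per head) for a single sequence, assuming float32 (4 bytes per value)."""
    return seq_len * seq_len * num_heads * bytes_per_value

for n in (512, 1024, 2048, 4096):
    mib = attention_score_bytes(n) / 2**20
    print(f"n = {n:5d}: {mib:.0f} MiB")  # 8, 32, 128 and 512 MiB respectively
```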