Attention in Transformers: concepts and code in PyTorch.

The main idea behind Transformers and Attention

Transformers fundamentally require three main parts: Word Embedding, Positional Encoding, and Attention, which together produce context-aware embeddings.

  • Word Embedding converts words, bits of words, and symbols, collectively called tokens, into numbers. (We need this because Transformers are a type of neural network that only accepts numbers as input values; see the sketch after this list.)

  • Positional Encoding helps keep track of word order.

  • Attention, in the form of self-attention, works by seeing how similar each word is to all of the words in the sentence, including itself.

  • Context-aware embeddings can help cluster similar sentences and documents.
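As a minimal sketch of how tokens become numbers, the vocabulary size and embedding size below are illustrative assumptions, not values from the notes:

    import torch
    import torch.nn as nn

    vocab_size = 100   # assumed toy vocabulary size
    d_model = 4        # assumed embedding size (real models often use 512 or more)

    # Word Embedding: a lookup table that maps each token id to a learned vector of numbers
    word_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    token_ids = torch.tensor([3, 17, 42])   # three tokens from a hypothetical sentence
    embedded = word_embedding(token_ids)    # shape: (3, d_model)
    print(embedded.shape)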

The Matrix Math for calculating self-attention

While toy examples may use just a few numbers per word, in practice it is common to use 512 or more numbers to represent each word.

Q = Encoded values × Query weights^T
K = Encoded values × Key weights^T
V = Encoded values × Value weights^T
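A minimal PyTorch sketch of those three products, assuming an embedding size d_model and using nn.Linear layers, whose weight matrices play the roles of the Query, Key, and Value weights (the sizes are illustrative):

    import torch
    import torch.nn as nn

    d_model = 4                          # assumed embedding size for illustration
    encodings = torch.rand(3, d_model)   # word embeddings + positional encodings for 3 tokens

    # nn.Linear stores a weight matrix W and computes x @ W.T,
    # matching "Encoded values × Weights^T" above
    W_q = nn.Linear(d_model, d_model, bias=False)
    W_k = nn.Linear(d_model, d_model, bias=False)
    W_v = nn.Linear(d_model, d_model, bias=False)

    Q = W_q(encodings)   # Queries
    K = W_k(encodings)   # Keys
    V = W_v(encodings)   # Values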

Dot products can be used as an unscaled measure of similarity between two things, and this metric is closely related to something called the Cosine Similarity. The big difference is that the Cosine Similarity scales the Dot Product to be between -1 and 1.
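A quick numeric illustration of the difference, using two arbitrary vectors:

    import torch
    import torch.nn.functional as F

    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([2.0, 4.0, 6.0])

    print(torch.dot(a, b))                   # unscaled similarity: 28.0
    print(F.cosine_similarity(a, b, dim=0))  # scaled to between -1 and 1: 1.0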

The scaled similarities are obtained by dividing the dot products by the square root of d_k, where d_k is the dimension of the Keys.

The percentages that come out of the softmax function tell us how much influence each word should have on the final encoding for any given word.
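Putting the last few steps together, here is a minimal sketch of scaled dot-product attention with random Q, K, and V values (the shapes are illustrative):

    import torch
    import torch.nn.functional as F

    d_k = 4                                  # dimension of the Keys (and Queries)
    Q = torch.rand(3, d_k)
    K = torch.rand(3, d_k)
    V = torch.rand(3, d_k)

    sims = Q @ K.T                                   # similarities between every pair of tokens
    scaled_sims = sims / d_k ** 0.5                  # divide by the square root of d_k
    attn_percents = F.softmax(scaled_sims, dim=-1)   # each row sums to 1 (the "percentages")
    attention_scores = attn_percents @ V             # weighted sum of the Values
    print(attention_scores.shape)                    # (3, d_k)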

Coding self-attention in PyTorch

Import torch: tensors are multidimensional arrays optimized for neural networks.

import torch: for tensor operations.
import torch.nn: for the Module and Linear classes, and a bunch of other helper functions.
import torch.nn.functional: to access the softmax() function that we will use when calculating attention.
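In code, those imports look like this:

    import torch                      # tensor operations
    import torch.nn as nn             # Module, Linear, and other helpers
    import torch.nn.functional as F   # softmax() and other functions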

Define a SelfAttention class that inherits from nn.Module, which is the base class for all neural network modules that you make with PyTorch.

Create an __init__() method that defines the Query, Key, and Value weight matrices.
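A minimal single-head sketch of such a class, assuming an embedding size d_model passed to __init__() (the parameter names and the tiny usage example are illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SelfAttention(nn.Module):
        def __init__(self, d_model=2):
            super().__init__()
            # Weight matrices for creating the Queries, Keys, and Values
            self.W_q = nn.Linear(d_model, d_model, bias=False)
            self.W_k = nn.Linear(d_model, d_model, bias=False)
            self.W_v = nn.Linear(d_model, d_model, bias=False)

        def forward(self, token_encodings):
            q = self.W_q(token_encodings)
            k = self.W_k(token_encodings)
            v = self.W_v(token_encodings)

            sims = q @ k.transpose(-2, -1)                  # similarities between all pairs of tokens
            scaled_sims = sims / k.size(-1) ** 0.5          # scale by the square root of d_k
            attn_percents = F.softmax(scaled_sims, dim=-1)  # how much each token attends to the others
            return attn_percents @ v                        # context-aware encodings

    # Usage
    torch.manual_seed(42)
    encodings = torch.rand(3, 2)   # 3 tokens, each represented by 2 numbers
    print(SelfAttention(d_model=2)(encodings))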

Self-attention vs Masked self-attention

Encoder-only transformer: Word embedding, positional encoding, self-attention, context-aware embeddings.

Decoder-only transformer: Word embedding, positional encoding, masked self-attention, generative outputs (e.g., ChatGPT).

Self-attention can look at words before and after the word of interest.

Masked self-attention ignores the words that come after the word of interest.

The matrix math for calculating masked self-attention

We add a new matrix, M for Mask, to the scaled similarities.

(The purpose of the mask is to prevent tokens from including anything that comes after them when calculating attention.)
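A small sketch of what M looks like for three tokens: zeros where attention is allowed and negative infinity above the diagonal, so the masked positions become 0% after the softmax (the sequence length of 3 is arbitrary):

    import torch

    seq_len = 3
    # 0 on and below the diagonal (allowed), -inf above it (masked)
    M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    print(M)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])

    # Masked, scaled similarities = scaled_sims + M, then softmax as before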

Coding masked self-attention in PyTorch

The forward() method is where we actually calculate the masked self-attention values for each token.

The True values correspond to attention values that we want to mask out.
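A sketch of that forward() method using a boolean mask, where True marks the positions to hide (the class name and the d_model default are illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MaskedSelfAttention(nn.Module):
        def __init__(self, d_model=2):
            super().__init__()
            self.W_q = nn.Linear(d_model, d_model, bias=False)
            self.W_k = nn.Linear(d_model, d_model, bias=False)
            self.W_v = nn.Linear(d_model, d_model, bias=False)

        def forward(self, token_encodings, mask):
            q = self.W_q(token_encodings)
            k = self.W_k(token_encodings)
            v = self.W_v(token_encodings)

            scaled_sims = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5
            # Where mask is True, replace the scaled similarity with negative infinity
            # so that, after the softmax, those tokens get 0% influence
            scaled_sims = scaled_sims.masked_fill(mask, float('-inf'))
            return F.softmax(scaled_sims, dim=-1) @ v

    # Usage: True above the diagonal = mask out every token that comes after the current one
    encodings = torch.rand(3, 2)
    mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()
    print(MaskedSelfAttention(d_model=2)(encodings, mask))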

Encoder-decoder attention (cross-attention)

It uses the output from the Encoder to calculate the keys and values.

The first Transformer was based on something called a Seq2Seq, or an Encoder-Decoder model.
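A sketch of how cross-attention differs from self-attention: the Queries come from the Decoder's tokens, while the Keys and Values come from the Encoder's output (the class and variable names are illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class EncoderDecoderAttention(nn.Module):
        def __init__(self, d_model=2):
            super().__init__()
            self.W_q = nn.Linear(d_model, d_model, bias=False)
            self.W_k = nn.Linear(d_model, d_model, bias=False)
            self.W_v = nn.Linear(d_model, d_model, bias=False)

        def forward(self, decoder_encodings, encoder_output):
            q = self.W_q(decoder_encodings)   # Queries from the Decoder
            k = self.W_k(encoder_output)      # Keys from the Encoder's output
            v = self.W_v(encoder_output)      # Values from the Encoder's output

            scaled_sims = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5
            return F.softmax(scaled_sims, dim=-1) @ v

    # Usage: 3 decoder tokens attending to 5 encoder tokens
    decoder_encodings = torch.rand(3, 2)
    encoder_output = torch.rand(5, 2)
    print(EncoderDecoderAttention(d_model=2)(decoder_encodings, encoder_output).shape)  # (3, 2)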

Multi-head attention

Multi-head attention applies several self-attention heads to the same tokens in parallel, each with its own Query, Key, and Value weights, and combines their outputs so the model can capture different kinds of relationships between tokens.

Coding encoder-decoder attention and multi-head attention in PyTorch


    import torch
    import torch.nn as nn

    # Define input parameters
    embed_dim = 64   # Dimension of input embeddings
    num_heads = 8    # Number of attention heads
    seq_len = 10     # Sequence length
    batch_size = 32  # Batch size

    # Create Multihead Attention layer
    mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

    # Random input tensors (batch_size, seq_len, embed_dim).
    # For encoder-decoder (cross-) attention, query would come from the decoder while
    # key and value would come from the encoder's output; here all three are random.
    query = torch.rand(batch_size, seq_len, embed_dim)
    key = torch.rand(batch_size, seq_len, embed_dim)
    value = torch.rand(batch_size, seq_len, embed_dim)

    # Apply multi-head attention (keep the per-head weights instead of averaging them)
    output, attn_weights = mha(query, key, value, average_attn_weights=False)

    print(output.shape)       # Output: (batch_size, seq_len, embed_dim)
    print(attn_weights.shape) # Attention weights: (batch_size, num_heads, seq_len, seq_len)
