Coding Attention Mechanisms
Code repository: rasbt/LLMs-from-scratch
This chapter explores the fundamentals of self-attention mechanisms and their implementation in natural language processing, progressing from basic attention to multi-head attention. We will implement these concepts step-by-step using Python code.
1. Attending to Different Parts of the Input with Self-Attention
Self-attention allows a model to dynamically focus on different parts of an input sequence based on their relevance. Below is a simple implementation broken into key steps.
1.1 Simple Attention Mechanism
The simple attention mechanism computes a context vector in three steps: input embedding to attention scores, attention scores to weights, and weights to context vector.
Step 1: Input Embedding -> ω (Attention Scores)
- Goal: Compute relevance scores (ω) between a query and each input token using a dot product.
- Implementation: The input is an embedding matrix `inputs` with shape `(num_tokens, d_in)`. For each token, compute its dot product with a query vector.
- Method (manual): Use a `for` loop with `torch.dot(inputs[i], query)`.
- Example code (assuming `query` is defined):

```python
omega = torch.zeros(num_tokens)
for i in range(num_tokens):
    omega[i] = torch.dot(inputs[i], query)
```
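For concreteness, here is a minimal, self-contained sketch of this step. The random `inputs` matrix and the choice of the second token as the query are illustrative assumptions, not values from the text:

```python
import torch

torch.manual_seed(123)
num_tokens, d_in = 6, 3
inputs = torch.rand(num_tokens, d_in)  # example embedding matrix: 6 tokens, 3-dim embeddings
query = inputs[1]                      # assume the second token acts as the query

# Attention scores: one dot product per input token
omega = torch.zeros(num_tokens)
for i in range(num_tokens):
    omega[i] = torch.dot(inputs[i], query)

print(omega)  # higher score = more relevant to the query token
```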
Step 2: ω (Attention Scores) -> α (Attention Weights)
- Goal: Normalize the attention scores so they sum to 1, yielding attention weights (α).
- Purpose: Keeps the values in a comparable range and improves numerical stability during training.
- Implementation:
  - Manual: `alpha = omega / omega.sum()`.
  - Preferred: Use `torch.softmax(omega, dim=0)`, which handles extreme values more gracefully.
- Code:

```python
attn_weights = torch.softmax(omega, dim=0)
```
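Continuing the sketch from Step 1, a quick comparison of the manual normalization and softmax (both yield weights that sum to 1):

```python
# Manual normalization (works, but handles extreme values less gracefully)
alpha_manual = omega / omega.sum()

# Preferred: softmax normalization
attn_weights = torch.softmax(omega, dim=0)

print(alpha_manual.sum())  # sums to 1 (up to floating-point error)
print(attn_weights.sum())  # sums to 1 (up to floating-point error)
```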
Step 3: α (Attention Weights) -> z (Context Vector)
- Goal: Compute a weighted sum of the input tokens using the attention weights to produce the context vector (z).
- Implementation (manual): Use a `for` loop to multiply and sum, as shown below.
- Example code:

```python
context_vec = torch.zeros(d_in)
for i in range(num_tokens):
    context_vec += attn_weights[i] * inputs[i]
```
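The loop can be cross-checked against an equivalent one-line matrix-vector product, which previews the vectorized form used in the next subsection (a sketch using the variables from the running example):

```python
context_vec = torch.zeros(d_in)
for i in range(num_tokens):
    context_vec += attn_weights[i] * inputs[i]

# Equivalent vectorized form: a weighted sum of the rows of `inputs`
print(torch.allclose(context_vec, attn_weights @ inputs))  # True
```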
1.2 Computing Attention Weights for All Input Tokens
To improve efficiency, we compute attention weights for all tokens simultaneously using matrix operations.
Implementation
- Matrix Multiplication:
  - Input `inputs` has shape `(num_tokens, d_in)`.
  - Attention scores: `attn_scores = inputs @ inputs.T`, shape `(num_tokens, num_tokens)`.
  - Normalize: `attn_weights = torch.softmax(attn_scores, dim=1)`.
  - Context vectors: `all_context_vecs = attn_weights @ inputs`, shape `(num_tokens, d_in)`.
- Code:

```python
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=1)
all_context_vecs = attn_weights @ inputs
```

- Explanation:
  - `inputs @ inputs.T` computes dot products between all token pairs.
  - `dim=1` normalizes across rows, ensuring each row sums to 1.
  - `all_context_vecs` provides context-aware representations for all tokens.
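As a sanity check (a sketch reusing the example `inputs` from the Step 1 sketch), the vectorized version reproduces the single-query result, and each row of weights sums to 1:

```python
attn_scores = inputs @ inputs.T                # (6, 6) pairwise dot products
attn_weights_all = torch.softmax(attn_scores, dim=1)
all_context_vecs = attn_weights_all @ inputs   # (6, 3) context vectors

print(attn_weights_all.sum(dim=1))                       # each row sums to 1
print(torch.allclose(all_context_vecs[1], context_vec))  # matches the query-token result above
```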
2. Implementing Self-Attention with Trainable Weights
Unlike simple attention, self-attention with trainable weights introduces learnable query (Q), key (K), and value (V) matrices, decoupling the context vector’s dimensionality from the input embeddings.
2.1 SelfAttention_v1: Basic Implementation
- Core Idea: Use trainable weight matrices `W_q`, `W_k`, and `W_v` to generate queries, keys, and values.
- Code:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key        # (num_tokens, d_out)
        queries = x @ self.W_query   # (num_tokens, d_out)
        values = x @ self.W_value    # (num_tokens, d_out)

        attn_scores = queries @ keys.T  # (num_tokens, num_tokens)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values  # (num_tokens, d_out)
        return context_vec

torch.manual_seed(123)
d_in, d_out = 3, 2
inputs = torch.rand(6, d_in)  # 6 tokens
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))
```

- Notes:
  - `keys.shape[-1]**0.5` scales the scores (scaled dot-product attention) to stabilize gradients.
  - Output shape is `(num_tokens, d_out)`.
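In equation form, `SelfAttention_v1` computes standard scaled dot-product attention: z = softmax(Q·Kᵀ / √d_out)·V, with Q = X·W_q, K = X·W_k, and V = X·W_v, where X is the input embedding matrix.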
2.2 SelfAttention_v2: Using nn.Linear
- Improvement: Replace the manual matrix multiplications with `nn.Linear` layers, which add flexibility (e.g., an optional bias) and come with a more effective default weight initialization than `torch.rand`.
- Code:

```python
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
```
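Because `nn.Linear(d_in, d_out)` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`, the two versions are mathematically equivalent when the weights match. A small sanity-check sketch (not code from the text, just a consequence of how `nn.Linear` works):

```python
# Transfer the transposed nn.Linear weights into the nn.Parameter version
sa_v1.W_query = nn.Parameter(sa_v2.W_query.weight.T.detach().clone())
sa_v1.W_key = nn.Parameter(sa_v2.W_key.weight.T.detach().clone())
sa_v1.W_value = nn.Parameter(sa_v2.W_value.weight.T.detach().clone())

print(torch.allclose(sa_v1(inputs), sa_v2(inputs)))  # expected: True
```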
3. Hiding Future Words with Causal Attention
Causal attention ensures a token only attends to previous tokens, making it suitable for autoregressive tasks (e.g., language generation).
3.1 Applying a Causal Attention Mask
- Method: Apply an upper-triangular mask to the attention scores, setting future-token scores to `-inf`, which become 0 after softmax.
- Code (here, `queries`, `keys`, and `attn_scores` are recomputed from `sa_v2` and `inputs` above):

```python
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T

context_length = inputs.shape[0]  # e.g., 6
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
```

- Explanation:
  - `torch.triu` with `diagonal=1` creates an upper-triangular matrix with ones above the main diagonal.
  - `masked_fill` replaces the masked positions with `-inf`.
  - After softmax, the masked positions become 0, hiding future tokens.
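A quick check of the result (a sketch using `mask` and `attn_weights` from the block above):

```python
print(mask)                      # 1s above the diagonal mark the positions to hide
print(attn_weights)              # those positions are 0 after softmax
print(attn_weights.sum(dim=-1))  # each row still sums to 1
```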
3.2 Masking Additional Attention Weights with Dropout
- Goal: Randomly drop attention weights to prevent overfitting and reduce positional dependency.
- Implementation: Apply `nn.Dropout` to `attn_weights` after they are computed, as sketched below.
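A minimal sketch of this idea (the 50% rate is an illustrative choice; `attn_weights` comes from the masked example in Section 3.1):

```python
torch.manual_seed(123)
dropout = nn.Dropout(0.5)     # drop half of the attention weights at random (illustrative rate)
print(dropout(attn_weights))  # surviving entries are rescaled by 1/(1 - 0.5) = 2
```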
3.3 Implementing a Compact Causal Self-Attention Class
- Complete Implementation:

```python
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # batch dimension support
        keys = self.W_key(x)        # (b, num_tokens, d_out)
        queries = self.W_query(x)   # (b, num_tokens, d_out)
        values = self.W_value(x)    # (b, num_tokens, d_out)

        attn_scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values  # (b, num_tokens, d_out)
        return context_vec

torch.manual_seed(123)
batch = torch.stack((inputs, inputs), dim=0)  # (2, 6, 3)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 2)
```

- Key Points:
  - Supports batched inputs with shape `(batch_size, num_tokens, d_in)`.
  - `transpose(1, 2)` adjusts the key tensor for batched matrix multiplication.
  - `self.mask.bool()[:num_tokens, :num_tokens]` trims the mask for inputs shorter than `context_length`.
  - `register_buffer` stores the mask as non-trainable state that moves with the module (e.g., to the GPU).
4. Extending Single-Head Attention to Multi-Head Attention
Multi-head attention enhances representation power by computing attention in parallel across multiple heads.
4.1 MultiHeadAttentionWrapper
- Implementation:

```python
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

torch.manual_seed(123)
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 4)
```

- Explanation:
  - Each head outputs `(b, num_tokens, d_out)`; with `num_heads=2`, the concatenated output is `(b, num_tokens, d_out * 2)`.
  - Running several heads in parallel lets the model capture diverse patterns in the input.
Summary
- Simple Attention: Computes context vectors via dot products and normalization.
- Self-Attention: Introduces trainable Q, K, V matrices for flexibility.
- Causal Attention: Masks future tokens for autoregressive tasks.
- Multi-Head Attention: Parallelizes attention heads to enhance feature extraction.