Coding Attention Mechanisms

Code repository: rasbt/LLMs-from-scratch

This chapter explores the fundamentals of attention mechanisms for natural language processing, progressing from a simple attention mechanism to self-attention with trainable weights, causal attention, and multi-head attention. We implement these concepts step by step in Python.


1. Attending to Different Parts of the Input with Self-Attention

Self-attention allows a model to dynamically focus on different parts of an input sequence based on their relevance. Below is a simple implementation broken into key steps.

1.1 Simple Attention Mechanism

The simple attention mechanism computes a context vector for a query token in three steps: input embeddings -> attention scores (ω), attention scores -> attention weights (α), and attention weights -> context vector (z).

Step 1: Input Embedding -> ω (Attention Scores)

  • Goal: Compute relevance scores (ω) between a query and each input token using a dot product.
  • Implementation:
    • Input is an embedding matrix inputs with shape (num_tokens, d_in).
    • For each token, compute its dot product with a query vector.
  • Methods:
    • Manual: Use a for loop with torch.dot(inputs[i], query).
    • Example Code (assuming inputs is defined and a query vector has been chosen, e.g. query = inputs[1]):
      num_tokens = inputs.shape[0]
      omega = torch.zeros(num_tokens)   # one attention score per input token
      for i in range(num_tokens):
          omega[i] = torch.dot(inputs[i], query)
      

Step 2: ω (Attention Scores) -> α (Attention Weights)

  • Goal: Normalize the attention scores so their sum is 1, yielding attention weights (α).
  • Purpose: Normalized weights are easier to interpret, and keeping them on a common scale helps maintain numerical stability during training.
  • Implementation:
    • Manual: alpha = omega / omega.sum().
    • Preferred: Use torch.softmax(omega, dim=0) for automatic normalization.
  • Code:
    attn_weights = torch.softmax(omega, dim=0)
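
  • Note: The two schemes are not identical. Dividing by the sum preserves the ratios of the raw scores (and can yield negative weights if a score is negative), whereas softmax exponentiates first, so all weights stay positive and larger scores receive relatively more weight. A minimal sketch with made-up scores:

    import torch

    omega = torch.tensor([0.9, -0.3, 1.5])   # illustrative raw attention scores
    print(omega / omega.sum())               # simple normalization: contains a negative entry
    print(torch.softmax(omega, dim=0))       # softmax: positive weights that sum to 1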
    

Step 3: α (Attention Weights) -> z (Context Vector)

  • Goal: Compute a weighted sum of input tokens using attention weights to produce the context vector (z).
  • Implementation:
    • Manual: Use a for loop to multiply and sum.
    • Example Code:
      context_vec = torch.zeros(d_in)   # d_in: embedding dimension of each token
      for i in range(num_tokens):
          context_vec += attn_weights[i] * inputs[i]
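
  • Putting It Together: A self-contained sketch of the three steps for one query token. The inputs tensor is random stand-in data (6 tokens, embedding dimension 3), and picking the second token as the query is purely illustrative:

    import torch

    torch.manual_seed(123)
    inputs = torch.rand(6, 3)    # 6 tokens, embedding dimension 3 (stand-in data)
    query = inputs[1]            # illustrative choice of query token

    # Step 1: attention scores via dot products
    num_tokens, d_in = inputs.shape
    omega = torch.zeros(num_tokens)
    for i in range(num_tokens):
        omega[i] = torch.dot(inputs[i], query)

    # Step 2: normalize the scores into attention weights
    attn_weights = torch.softmax(omega, dim=0)

    # Step 3: weighted sum of the inputs yields the context vector
    context_vec = torch.zeros(d_in)
    for i in range(num_tokens):
        context_vec += attn_weights[i] * inputs[i]

    print(attn_weights.sum())    # ~1.0
    print(context_vec)           # context vector for the query token, shape (3,)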
      

1.2 Computing Attention Weights for All Input Tokens

To improve efficiency, we compute attention weights for all tokens simultaneously using matrix operations.

Implementation

  • Matrix Multiplication:

    • Input inputs has shape (num_tokens, d_in).
    • Attention scores: attn_scores = inputs @ inputs.T, shape (num_tokens, num_tokens).
    • Normalize: attn_weights = torch.softmax(attn_scores, dim=1).
    • Context vectors: all_context_vecs = attn_weights @ inputs, shape (num_tokens, d_in).
  • Code:

    attn_scores = inputs @ inputs.T
    attn_weights = torch.softmax(attn_scores, dim=1)
    all_context_vecs = attn_weights @ inputs
    
  • Explanation:

    • inputs @ inputs.T computes dot products between all token pairs.
    • dim=1 applies softmax along each row of the score matrix, so every row sums to 1.
    • all_context_vecs provides context-aware representations for all tokens at once (see the check below).
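
  • Check: A quick sketch reusing attn_weights and all_context_vecs from above, plus the context_vec computed in Section 1.1 (assuming the same inputs tensor and query = inputs[1]):

    print(attn_weights.sum(dim=1))                            # every row sums to 1
    print(torch.allclose(all_context_vecs[1], context_vec))   # True: matches the step-by-step result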

2. Implementing Self-Attention with Trainable Weights

Unlike simple attention, self-attention with trainable weights introduces learnable query (Q), key (K), and value (V) matrices, decoupling the context vector’s dimensionality from the input embeddings.

2.1 SelfAttention_v1: Basic Implementation

  • Core Idea: Use trainable weight matrices W_q, W_k, and W_v to generate queries, keys, and values.
  • Code:
    import torch
    import torch.nn as nn
    
    class SelfAttention_v1(nn.Module):
        def __init__(self, d_in, d_out):
            super().__init__()
            self.W_query = nn.Parameter(torch.rand(d_in, d_out))
            self.W_key = nn.Parameter(torch.rand(d_in, d_out))
            self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    
        def forward(self, x):
            keys = x @ self.W_key        # (num_tokens, d_out)
            queries = x @ self.W_query   # (num_tokens, d_out)
            values = x @ self.W_value    # (num_tokens, d_out)
    
            attn_scores = queries @ keys.T  # (num_tokens, num_tokens)
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            context_vec = attn_weights @ values  # (num_tokens, d_out)
            return context_vec
    
    torch.manual_seed(123)
    d_in, d_out = 3, 2
    inputs = torch.rand(6, d_in)  # 6 tokens
    sa_v1 = SelfAttention_v1(d_in, d_out)
    print(sa_v1(inputs))
    
  • Notes:
    • keys.shape[-1]**0.5 scales the scores (scaled dot-product attention); without scaling, large dot products push softmax toward near one-hot outputs with vanishingly small gradients (see the sketch after these notes).
    • Output shape is (num_tokens, d_out).
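
  • Why Scaling Helps (sketch): With random vectors, the variance of a dot product grows with the key dimension, so unscaled scores saturate the softmax; dividing by sqrt(d_k) keeps the distribution smooth. The dimension 512 below is an arbitrary, illustrative choice:

    torch.manual_seed(123)
    d_k = 512                            # illustrative key dimension
    q = torch.randn(d_k)
    K = torch.randn(8, d_k)              # 8 random key vectors
    scores = K @ q                       # each score has variance ~d_k
    print(torch.softmax(scores, dim=0))             # nearly one-hot
    print(torch.softmax(scores / d_k**0.5, dim=0))  # noticeably smoother distribution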

2.2 SelfAttention_v2: Using nn.Linear

  • Improvement: Replace the manually created nn.Parameter matrices with nn.Linear layers, which offer an optional bias and a more favorable default weight initialization than torch.rand (an equivalence sketch follows the code).
  • Code:
    class SelfAttention_v2(nn.Module):
        def __init__(self, d_in, d_out, qkv_bias=False):
            super().__init__()
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
        def forward(self, x):
            keys = self.W_key(x)
            queries = self.W_query(x)
            values = self.W_value(x)
    
            attn_scores = queries @ keys.T
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            context_vec = attn_weights @ values
            return context_vec
    
    torch.manual_seed(789)
    sa_v2 = SelfAttention_v2(d_in, d_out)
    print(sa_v2(inputs))
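
  • Equivalence Note (sketch): nn.Linear with bias=False computes the same product as the manual matrices, except that it stores its weight with shape (d_out, d_in) and transposes internally. Reusing sa_v2 and inputs from above:

    keys_linear = sa_v2.W_key(inputs)              # nn.Linear computes x @ W.T
    keys_manual = inputs @ sa_v2.W_key.weight.T    # the same product written out by hand
    print(torch.allclose(keys_linear, keys_manual))  # True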
    

3. Hiding Future Words with Causal Attention

Causal attention ensures a token only attends to previous tokens, making it suitable for autoregressive tasks (e.g., language generation).

3.1 Applying a Causal Attention Mask

  • Method: Apply an upper triangular mask to attention scores, setting future token scores to -inf, which become 0 after softmax.

  • Code:

    # attn_scores (queries @ keys.T) and keys are assumed from the previous section
    context_length = inputs.shape[0]  # e.g., 6
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
    
  • Explanation:

    • torch.triu with diagonal=1 keeps only the elements strictly above the main diagonal, so every token can still attend to itself and to earlier tokens.
    • masked_fill replaces the masked positions with -inf.
    • After softmax, the masked positions become 0 and the remaining weights in each row renormalize to sum to 1, hiding future tokens (see the check below).
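
  • Check (sketch, reusing attn_weights from above):

    print(attn_weights)              # entries above the diagonal are exactly 0
    print(attn_weights.sum(dim=-1))  # each row still sums to 1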

3.2 Masking Additional Attention Weights with Dropout

  • Goal: Randomly zero out some attention weights during training to reduce overfitting and prevent over-reliance on particular attention positions.
  • Implementation: Apply nn.Dropout to attn_weights right after the softmax, as sketched below.
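
  • Example (sketch): The 50% rate below is deliberately high so the effect is visible; much smaller rates such as 0.1 are common in practice. During training, nn.Dropout rescales the surviving weights by 1/(1 - p); in eval mode it is a no-op:

    torch.manual_seed(123)
    dropout = nn.Dropout(0.5)      # illustrative rate; attn_weights comes from Section 3.1
    print(dropout(attn_weights))   # roughly half the weights zeroed, survivors scaled by 2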

3.3 Implementing a Compact Causal Self-Attention Class

  • Complete Implementation:

    class CausalAttention(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
            super().__init__()
            self.d_out = d_out
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.dropout = nn.Dropout(dropout)
            self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
    
        def forward(self, x):
            b, num_tokens, d_in = x.shape  # Batch dimension support
            keys = self.W_key(x)           # (b, num_tokens, d_out)
            queries = self.W_query(x)      # (b, num_tokens, d_out)
            values = self.W_value(x)       # (b, num_tokens, d_out)
    
            attn_scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)
            attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            attn_weights = self.dropout(attn_weights)
            context_vec = attn_weights @ values  # (b, num_tokens, d_out)
            return context_vec
    
    torch.manual_seed(123)
    batch = torch.stack((inputs, inputs), dim=0)  # (2, 6, 3)
    context_length = batch.shape[1]
    ca = CausalAttention(d_in, d_out, context_length, 0.0)
    context_vecs = ca(batch)
    print(context_vecs)
    print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 2)
    
  • Key Points:

    • Supports batched inputs with shape (batch_size, num_tokens, d_in).
    • keys.transpose(1, 2) swaps the token and feature dimensions so the batched matrix multiplication yields (b, num_tokens, num_tokens) attention scores.
    • self.mask.bool()[:num_tokens, :num_tokens] slices the precomputed mask, so the layer also works for sequences shorter than context_length.
    • register_buffer stores the mask as a non-trainable tensor that is saved with the module and moved along with it (e.g., by model.to(device)).

4. Extending Single-Head Attention to Multi-Head Attention

Multi-head attention enhances representation power by computing attention in parallel across multiple heads.

4.1 MultiHeadAttentionWrapper

  • Implementation:

    class MultiHeadAttentionWrapper(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
            super().__init__()
            self.heads = nn.ModuleList(
                [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
                 for _ in range(num_heads)]
            )
    
        def forward(self, x):
            return torch.cat([head(x) for head in self.heads], dim=-1)
    
    torch.manual_seed(123)
    mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
    context_vecs = mha(batch)
    print(context_vecs)
    print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 4)
    
  • Explanation:

    • Each head outputs (b, num_tokens, d_out); with num_heads=2, the concatenated output is (b, num_tokens, d_out * 2).
    • This lets each head capture different patterns in the input (a quick sanity check follows).
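
  • Sanity Check (sketch, reusing mha, batch, and context_vecs from above):

    head_outputs = [head(batch) for head in mha.heads]                      # each head: (2, 6, 2)
    print(torch.allclose(context_vecs, torch.cat(head_outputs, dim=-1)))    # True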

Summary

  1. Simple Attention: Computes context vectors via dot products and normalization.
  2. Self-Attention: Introduces trainable Q, K, V matrices for flexibility.
  3. Causal Attention: Masks future tokens for autoregressive tasks.
  4. Multi-Head Attention: Parallelizes attention heads to enhance feature extraction.