第三章:编码注意力机制

代码仓库: rasbt/LLMs-from-scratch

本章探讨自注意力机制的基本原理及其在自然语言处理中的实现,从简单注意力逐步推进到多头注意力。我们将通过Python代码一步步实现这些概念。


1. 通过自注意力关注输入的不同部分

自注意力机制允许模型根据输入序列中各部分的关联性动态调整关注焦点。以下是一个简易实现的步骤分解。

1.1 简单注意力机制

简单注意力机制通过三个步骤计算上下文向量:输入嵌入到注意力得分,注意力得分到注意力权重,再到上下文向量。

步骤 1:输入嵌入 -> ω (注意力得分)

  • 目标:通过点积计算查询(Query)与每个输入token的相关性得分(ω)。
  • 实现
    • 输入是一个嵌入矩阵inputs,形状为(num_tokens, d_in)
    • 对每个token,计算其与查询向量的点积。
  • 方法
    • 手动实现:使用for循环,计算torch.dot(inputs[i], query)
    • 示例代码(假设query已定义):
      omega = torch.zeros(num_tokens)
      for i in range(num_tokens):
          omega[i] = torch.dot(inputs[i], query)
      

步骤 2:ω (注意力得分) -> α (注意力权重)

  • 目标:将注意力得分归一化,使其和为1,得到注意力权重(α)。
  • 目的:防止数值过大,提高数值稳定性。
  • 实现
    • 手动计算:alpha = omega / omega.sum()
    • 推荐方法:使用torch.softmax(omega, dim=0)自动归一化。
  • 代码
    attn_weights = torch.softmax(omega, dim=0)
    

步骤 3:α (注意力权重) -> z (上下文向量)

  • 目标:根据注意力权重对输入token进行加权求和,生成上下文向量(z)。
  • 实现
    • 手动实现:使用for循环计算加权和。
    • 示例代码:
      context_vec = torch.zeros(d_in)
      for i in range(num_tokens):
          context_vec += attn_weights[i] * inputs[i]
      

1.2 为所有输入token计算注意力权重

为了提高效率,我们可以一次性计算所有token的注意力权重,避免逐个计算。

实现方法

  • 矩阵乘法

    • 输入inputs形状为(num_tokens, d_in)
    • 注意力得分:attn_scores = inputs @ inputs.T,形状为(num_tokens, num_tokens)
    • 归一化:attn_weights = torch.softmax(attn_scores, dim=1)
    • 上下文向量:all_context_vecs = attn_weights @ inputs,形状为(num_tokens, d_in)
  • 代码

    attn_scores = inputs @ inputs.T
    attn_weights = torch.softmax(attn_scores, dim=1)
    all_context_vecs = attn_weights @ inputs
    
  • 解释

    • inputs @ inputs.T计算所有token对之间的点积。
    • dim=1按行归一化,确保每行和为1。
    • all_context_vecs为每个token提供上下文感知表示。

2. 实现带可训练权重的自注意力

与简单注意力不同,带可训练权重的自注意力引入了可学习的查询(Q)、键(K)和值(V)矩阵,使上下文向量的维度与输入嵌入解耦。

2.1 SelfAttention_v1:基础实现

  • 核心思想:使用可训练权重矩阵W_qW_kW_v生成查询、键和值。
  • 代码
    import torch
    import torch.nn as nn
    
    class SelfAttention_v1(nn.Module):
        def __init__(self, d_in, d_out):
            super().__init__()
            self.W_query = nn.Parameter(torch.rand(d_in, d_out))
            self.W_key = nn.Parameter(torch.rand(d_in, d_out))
            self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    
        def forward(self, x):
            keys = x @ self.W_key        # (num_tokens, d_out)
            queries = x @ self.W_query   # (num_tokens, d_out)
            values = x @ self.W_value    # (num_tokens, d_out)
    
            attn_scores = queries @ keys.T  # (num_tokens, num_tokens)
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            context_vec = attn_weights @ values  # (num_tokens, d_out)
            return context_vec
    
    torch.manual_seed(123)
    d_in, d_out = 3, 2
    inputs = torch.rand(6, d_in)  # 6个token
    sa_v1 = SelfAttention_v1(d_in, d_out)
    print(sa_v1(inputs))
    
  • 说明
    • keys.shape[-1]**0.5是对注意力得分的缩放(Scaled Dot-Product Attention),用于稳定梯度。
    • 输出形状为(num_tokens, d_out)

2.2 SelfAttention_v2:使用nn.Linear

  • 改进:用nn.Linear替换手动矩阵乘法,支持偏置选项。
  • 代码
    class SelfAttention_v2(nn.Module):
        def __init__(self, d_in, d_out, qkv_bias=False):
            super().__init__()
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
        def forward(self, x):
            keys = self.W_key(x)
            queries = self.W_query(x)
            values = self.W_value(x)
    
            attn_scores = queries @ keys.T
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            context_vec = attn_weights @ values
            return context_vec
    
    torch.manual_seed(789)
    sa_v2 = SelfAttention_v2(d_in, d_out)
    print(sa_v2(inputs))
    

3. 使用因果注意力隐藏未来信息

因果注意力确保当前token只能关注之前的token,适用于自回归任务(如语言生成)。

3.1 应用因果注意力掩码

  • 方法:在注意力得分上应用上三角掩码,将未来token的得分设为-inf,经softmax后变为0。

  • 代码

    context_length = inputs.shape[0]  # 例如6
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
    
  • 解释

    • torch.triu生成上三角矩阵,diagonal=1表示主对角线以上为1。
    • masked_fill将掩码为1的位置替换为-inf
    • Softmax后,-inf变为0,屏蔽未来信息。

3.2 使用Dropout掩码额外注意力权重

  • 目标:通过随机丢弃注意力权重,防止过拟合和对特定位置的依赖。
  • 实现:在attn_weights后添加nn.Dropout

3.3 实现紧凑的因果自注意力类

  • 完整实现

    class CausalAttention(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
            super().__init__()
            self.d_out = d_out
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.dropout = nn.Dropout(dropout)
            self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
    
        def forward(self, x):
            b, num_tokens, d_in = x.shape  # 支持批量输入
            keys = self.W_key(x)           # (b, num_tokens, d_out)
            queries = self.W_query(x)      # (b, num_tokens, d_out)
            values = self.W_value(x)       # (b, num_tokens, d_out)
    
            attn_scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)
            attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            attn_weights = self.dropout(attn_weights)
            context_vec = attn_weights @ values  # (b, num_tokens, d_out)
            return context_vec
    
    torch.manual_seed(123)
    batch = torch.stack((inputs, inputs), dim=0)  # (2, 6, 3)
    context_length = batch.shape[1]
    ca = CausalAttention(d_in, d_out, context_length, 0.0)
    context_vecs = ca(batch)
    print(context_vecs)
    print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 2)
    
  • 关键点

    • 支持批量输入,x.shape = (batch_size, num_tokens, d_in)
    • transpose(1, 2)适应批量矩阵乘法。
    • self.mask[:num_tokens, :num_tokens]动态调整掩码大小。

4. 将单头注意力扩展到多头注意力

多头注意力通过并行计算多个注意力头,增强模型捕捉不同子空间信息的能力。

4.1 MultiHeadAttentionWrapper

  • 实现

    class MultiHeadAttentionWrapper(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
            super().__init__()
            self.heads = nn.ModuleList(
                [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
                 for _ in range(num_heads)]
            )
    
        def forward(self, x):
            return torch.cat([head(x) for head in self.heads], dim=-1)
    
    torch.manual_seed(123)
    mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
    context_vecs = mha(batch)
    print(context_vecs)
    print("context_vecs.shape:", context_vecs.shape)  # (2, 6, 4)
    
  • 解释

    • 每个头输出(b, num_tokens, d_out)num_heads=2时拼接后为(b, num_tokens, d_out * 2)
    • 提高了模型的表达能力。

总结

  1. 简单注意力:通过点积和归一化计算上下文向量。
  2. 自注意力:引入可训练的Q、K、V矩阵。
  3. 因果注意力:通过掩码屏蔽未来信息,支持自回归任务。
  4. 多头注意力:并行多个注意力头,增强特征提取能力。