第三章:编码注意力机制
代码仓库: rasbt/LLMs-from-scratch
本章探讨自注意力机制的基本原理及其在自然语言处理中的实现,从简单注意力逐步推进到多头注意力。我们将通过Python代码一步步实现这些概念。
1. 通过自注意力关注输入的不同部分
自注意力机制允许模型根据输入序列中各部分的关联性动态调整关注焦点。以下是一个简易实现的步骤分解。
1.1 简单注意力机制
简单注意力机制通过三个步骤计算上下文向量:输入嵌入到注意力得分,注意力得分到注意力权重,再到上下文向量。
步骤 1:输入嵌入 -> ω (注意力得分)
- 目标:通过点积计算查询(Query)与每个输入token的相关性得分(ω)。
- 实现:
- 输入是一个嵌入矩阵
inputs
,形状为(num_tokens, d_in)
。 - 对每个token,计算其与查询向量的点积。
- 输入是一个嵌入矩阵
- 方法:
- 手动实现:使用
for
循环,计算torch.dot(inputs[i], query)
。 - 示例代码(假设
query
已定义):omega = torch.zeros(num_tokens) for i in range(num_tokens): omega[i] = torch.dot(inputs[i], query)
- 手动实现:使用
步骤 2:ω (注意力得分) -> α (注意力权重)
- 目标:将注意力得分归一化,使其和为1,得到注意力权重(α)。
- 目的:防止数值过大,提高数值稳定性。
- 实现:
- 手动计算:
alpha = omega / omega.sum()
。 - 推荐方法:使用
torch.softmax(omega, dim=0)
自动归一化。
- 手动计算:
- 代码:
attn_weights = torch.softmax(omega, dim=0)
步骤 3:α (注意力权重) -> z (上下文向量)
- 目标:根据注意力权重对输入token进行加权求和,生成上下文向量(z)。
- 实现:
- 手动实现:使用
for
循环计算加权和。 - 示例代码:
context_vec = torch.zeros(d_in) for i in range(num_tokens): context_vec += attn_weights[i] * inputs[i]
- 手动实现:使用
1.2 为所有输入token计算注意力权重
为了提高效率,我们可以一次性计算所有token的注意力权重,避免逐个计算。
实现方法
-
矩阵乘法:
- 输入
inputs
形状为(num_tokens, d_in)
。 - 注意力得分:
attn_scores = inputs @ inputs.T
,形状为(num_tokens, num_tokens)
。 - 归一化:
attn_weights = torch.softmax(attn_scores, dim=1)
。 - 上下文向量:
all_context_vecs = attn_weights @ inputs
,形状为(num_tokens, d_in)
。
- 输入
-
代码:
attn_scores = inputs @ inputs.T attn_weights = torch.softmax(attn_scores, dim=1) all_context_vecs = attn_weights @ inputs
-
解释:
inputs @ inputs.T
计算所有token对之间的点积。dim=1
按行归一化,确保每行和为1。all_context_vecs
为每个token提供上下文感知表示。
2. 实现带可训练权重的自注意力
与简单注意力不同,带可训练权重的自注意力引入了可学习的查询(Q)、键(K)和值(V)矩阵,使上下文向量的维度与输入嵌入解耦。
2.1 SelfAttention_v1:基础实现
- 核心思想:使用可训练权重矩阵
W_q
、W_k
、W_v
生成查询、键和值。 - 代码:
import torch import torch.nn as nn class SelfAttention_v1(nn.Module): def __init__(self, d_in, d_out): super().__init__() self.W_query = nn.Parameter(torch.rand(d_in, d_out)) self.W_key = nn.Parameter(torch.rand(d_in, d_out)) self.W_value = nn.Parameter(torch.rand(d_in, d_out)) def forward(self, x): keys = x @ self.W_key # (num_tokens, d_out) queries = x @ self.W_query # (num_tokens, d_out) values = x @ self.W_value # (num_tokens, d_out) attn_scores = queries @ keys.T # (num_tokens, num_tokens) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) context_vec = attn_weights @ values # (num_tokens, d_out) return context_vec torch.manual_seed(123) d_in, d_out = 3, 2 inputs = torch.rand(6, d_in) # 6个token sa_v1 = SelfAttention_v1(d_in, d_out) print(sa_v1(inputs))
- 说明:
keys.shape[-1]**0.5
是对注意力得分的缩放(Scaled Dot-Product Attention),用于稳定梯度。- 输出形状为
(num_tokens, d_out)
。
2.2 SelfAttention_v2:使用nn.Linear
- 改进:用
nn.Linear
替换手动矩阵乘法,支持偏置选项。 - 代码:
class SelfAttention_v2(nn.Module): def __init__(self, d_in, d_out, qkv_bias=False): super().__init__() self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) def forward(self, x): keys = self.W_key(x) queries = self.W_query(x) values = self.W_value(x) attn_scores = queries @ keys.T attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) context_vec = attn_weights @ values return context_vec torch.manual_seed(789) sa_v2 = SelfAttention_v2(d_in, d_out) print(sa_v2(inputs))
3. 使用因果注意力隐藏未来信息
因果注意力确保当前token只能关注之前的token,适用于自回归任务(如语言生成)。
3.1 应用因果注意力掩码
-
方法:在注意力得分上应用上三角掩码,将未来token的得分设为
-inf
,经softmax后变为0。 -
代码:
context_length = inputs.shape[0] # 例如6 mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) masked = attn_scores.masked_fill(mask.bool(), -torch.inf) attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
-
解释:
torch.triu
生成上三角矩阵,diagonal=1
表示主对角线以上为1。masked_fill
将掩码为1的位置替换为-inf
。- Softmax后,
-inf
变为0,屏蔽未来信息。
3.2 使用Dropout掩码额外注意力权重
- 目标:通过随机丢弃注意力权重,防止过拟合和对特定位置的依赖。
- 实现:在
attn_weights
后添加nn.Dropout
。
3.3 实现紧凑的因果自注意力类
-
完整实现:
class CausalAttention(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False): super().__init__() self.d_out = d_out self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) self.dropout = nn.Dropout(dropout) self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) def forward(self, x): b, num_tokens, d_in = x.shape # 支持批量输入 keys = self.W_key(x) # (b, num_tokens, d_out) queries = self.W_query(x) # (b, num_tokens, d_out) values = self.W_value(x) # (b, num_tokens, d_out) attn_scores = queries @ keys.transpose(1, 2) # (b, num_tokens, num_tokens) attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) attn_weights = self.dropout(attn_weights) context_vec = attn_weights @ values # (b, num_tokens, d_out) return context_vec torch.manual_seed(123) batch = torch.stack((inputs, inputs), dim=0) # (2, 6, 3) context_length = batch.shape[1] ca = CausalAttention(d_in, d_out, context_length, 0.0) context_vecs = ca(batch) print(context_vecs) print("context_vecs.shape:", context_vecs.shape) # (2, 6, 2)
-
关键点:
- 支持批量输入,
x.shape = (batch_size, num_tokens, d_in)
。 transpose(1, 2)
适应批量矩阵乘法。self.mask[:num_tokens, :num_tokens]
动态调整掩码大小。
- 支持批量输入,
4. 将单头注意力扩展到多头注意力
多头注意力通过并行计算多个注意力头,增强模型捕捉不同子空间信息的能力。
4.1 MultiHeadAttentionWrapper
-
实现:
class MultiHeadAttentionWrapper(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): super().__init__() self.heads = nn.ModuleList( [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)] ) def forward(self, x): return torch.cat([head(x) for head in self.heads], dim=-1) torch.manual_seed(123) mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2) context_vecs = mha(batch) print(context_vecs) print("context_vecs.shape:", context_vecs.shape) # (2, 6, 4)
-
解释:
- 每个头输出
(b, num_tokens, d_out)
,num_heads=2
时拼接后为(b, num_tokens, d_out * 2)
。 - 提高了模型的表达能力。
- 每个头输出
总结
- 简单注意力:通过点积和归一化计算上下文向量。
- 自注意力:引入可训练的Q、K、V矩阵。
- 因果注意力:通过掩码屏蔽未来信息,支持自回归任务。
- 多头注意力:并行多个注意力头,增强特征提取能力。