注意力机制 - HappyLLM

注意力机制

传统神经网络

从 计算机视觉(Computer Vision,CV)为起源发展起来的神经网络,其核心架构有三种:

全连接神经网络(Feedforward Neural Network,FNN): 即每一层的神经元都和上下两层的每一个神经元完全连接,如图所示:

全连接神经网络

卷积神经网络(Convolutional Neural Network,CNN):训练参数量远小于全连接神经网络的卷积层来进行特征提取和学习。

卷积神经网络

循环神经网络(Recurrent Neural Network,RNN): 使用历史信息作为输入、包含环和自重复的网络,

循环神经网络

RNN 及 LSTM 的局限性

RNN、 LSTM 具有捕捉时序信息、适合序列生成的优点,却有两个缺陷:

  • ① 序列依序计算能很好模拟时序信息,但限制了计算机(GPU)并行计算的能力,导致模型参数量不算大,但计算时间成本高;
  • ② 难以捕捉长序列依赖关系。在 RNN 中,距离越远的输入之间的关系越难捕捉,需要将整个序列读入内存依次计算,限制了序列的长度。虽然 LSTM 中通过门机制进行了优化,但对于较远距离依赖关系的捕捉,依旧不如人意。

注意力机制

来源: 注意力机制源于CV领域,当关注一张图片,无需看清全部内容仅将注意力集中在重点部分。在NLP领域,将注意力集中在一个或几个 token,从而取得更好的计算效果。

核心变量: Query(查询值)、Key(键值)和 Value(真值)

  • 通过计算 QueryKey的相关性与Value加权求和,拟合序列中每个词与其他词的相关性。
  • Key 与 Query 相关性越高,则赋予的注意力权重越大;通过训练拟合,词向量能够表征语义信息,从而让语义相近的词在向量空间中距离更近,语义较远的词在向量空间中距离更远;
  • 欧式距离、点积可以衡量词向量的相似性
深入理解注意力机制

假设有一个字典, 字典的对应注意力机制中的键值 Key, 字典的就是真值 Value

1
2
3
4
5
{
"apple":10,
"banana":5,
"chair":2
}

现在想查询的值 Query"apple",那么通过将 QueryKey 匹配得到对应的 Value

但是当 Query 是一个包含多个 Key 的概念呢?比如,我们想要查找 "fruit",此时,应该将 "apple""banana" 都匹配到,但不能匹配到 "chair", 于是就会将 Key 对应的 Value 进行组合得到 Value

例如,当我们的 Query"fruit",可以分别给三个 Key 赋予如下权重:

1
2
3
4
5
{
"apple":0.6,
"banana":0.4,
"chair":0
}

那么,最终查询到的值应该是:value=0.610+0.45+02=8value=0.6∗10+0.4∗5+0∗2=8

给不同 Key 赋予的权重,就是我们所说的注意力分数, 即为了查询到 Query,应该赋予每一个 Key 多少注意力。但是,如何针对每一个 Query,计算出对应的注意力分数呢?

从直观上讲,我们可以认为 KeyQuery 相关性越高,则赋予的注意力权重越大。因此,如何找到一个合理的能够计算出正确的注意力分数的方法呢?

我们可以用点积(或欧式距离)衡量词向量的相似性,语义相近的词在向量空间中距离更近语义较远的词在向量空间中距离更远

vw=iviwiv·w = \sum_{i}v_iw_i

假设 Query“fruit”,对应的词向量为 qKey 对应的词向量为k=[vapple,vbanana,vchair]k = [v_{apple}, v_{banana}, v_{chair}] , 计算 Query 和每一个键的相似程度:

x=qKTx = qK^T

这里的 x 反映了 Query 和每一个 Key 的相似程度,然后使用一个 Softmax 层将其转化为和为 1 的权重:

softmax(x)i=exijexj\text{softmax}(x)_i = \frac{e^{xi}}{\sum_{j}e^{x_j}}

这样,得到的向量就能反映 Query 和每一个 Key 的相似程度,相加权重又为 1,即注意力分数。最后,再将得到的注意力分数和值向量做对应乘积即可。根据上述过程,就可以得到注意力机制计算的基本公式:

attention(Q,K,V)=softmax(qKT)vattention(Q,K,V) = softmax(qK^T)v

不过,此时的值是一个标量,因为只查询了一个 Query。我们可以将值转化为维度为dvd_v 的向量,将多个 Query词向量堆叠在一起形成矩阵 Q,得到公式:

attention(Q,K,V)=softmax(QKT)Vattention(Q,K,V) = softmax(QK^T)V

但是,当 QK 对应的维度dkd_k 比较大,可能会导致softmax梯度消失,使不同值之间的差异较大,从而影响梯度的稳定性。因此,我们要将 QK 乘积的结果做一个放缩:

attention(Q,K,V)=softmax(QKTdk)Vattention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

这也就是注意力机制的核心计算公式了。

代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 注意力函数
def attention(query, key, value, dropout=None):
'''
args:

query: 查询值矩阵 (batch_size, seq_len_q, d_k)
key: 键值矩阵 (batch_size, seq_len_k, d_k)
value: 真值矩阵 (batch_size, seq_len_k, seq_len_v)
'''
# 获取键向量的维度 d_k, 一般与查询值的维度相同
d_k = query.size(-1)

# 计算 Q 与 K 的内积并除以缩放因子 (根号d_k)
# transpose 相当于转置矩阵
# key.transpose(-2,-1) 后: (batch_size, d_k, seq_len_k)
# scores 形状: (batch_size, seq_len_q, seq_len_k)
scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(d_k)

# softmax: 每一行(每个Query)独立做softmax,使得每个Query对所有Key的注意力权重之和为1。
p_attn = scores.softmax(dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
# 采样
# 根据计算结果对 value 进行加权求和
return torch.matmul(p_attn,value), p_attn

自注意力机制

注意力机制的本质是对两段序列的元素依次进行相似度计算,寻找一个序列的每个元素对另一个序列的每个元素的相关度,进行加权求和,即分配注意力。

在 Transformer 的 Decoder 结构中,Q 来自于 Decoder 的输入,K 与 V 来自于 Encoder 的输出,从而拟合了编码信息与历史信息之间的关系,便于综合这两种信息实现预测。

在 Transformer 的 Encoder 结构中,使用的是自注意力(self-attention)机制,即计算序列中每个元素(token)对该序列其他元素(token)的注意力分布,从而建模文本之间的依赖关系。

在代码实现中,self-attention 机制通过给 Q、K、V 的输入传入同一个参数实现:

1
2
# attention 为上文定义的注意力计算函数
attention(x, x, x)

Mask自注意力机制

掩码的作用: 遮蔽一些特定位置的 token,模型在学习的过程中,会忽略掉被遮蔽的 token。

动机: 让模型只能使用历史信息预测而不能看到未来信息。类似于 n-gram 模型,即对于一个文本序列,根据之前的 token 来预测下一个 token,直到将整个文本序列补全。

示例:

对于一个文本序列: 【BOS】I like you【EOS】,模型会按如下顺序进行预测和学习:

Step 1:输入 【BOS】,输出 I
Step 2:输入 【BOS】I,输出 like
Step 3:输入 【BOS】I like,输出 you
Step 4:输入 【BOS】I like you,输出 【EOS】

理论上,若学习的语料足够多,则模型可以学会任意一种文本序列的建模方式,对任意文本进行补全。

问题提出: 上述过程是一个串行过程,需要先完成 Step 1,才能做 Step 2,逐步完成整个序列的补全,Transformer 相对于 RNN 的核心优势之一就是并行计算。若对于每一个训练语料,模型都需要串行完成上述过程才能进行学习和预测,很明显并没有并行计算,且计算效率低;

解决方法: 掩码自注意力生成一串掩码 [ Mask ],并行地输入到模型中,每一行输入中,模型只看前面的(未 Mask) token,来预测下一个 token。

【MASK】【MASK】【MASK】【MASK】
I 【MASK】 【MASK】【MASK】
I like 【MASK】【MASK】
I like you 【MASK】
I like you

Mask方式: 创建一个和输入同等长度的上三角矩阵作为注意力掩码,遮蔽输入即可;

  • 当输入维度为 (batch_size, seq_len, hidden_size)时,Mask 矩阵维度为 (1, seq_len, seq_len)(通过广播实现同一个 batch 中不同样本的计算)。
1
2
3
4
5
# 创建一个上三角矩阵遮蔽未来信息。
# 先通过 full 函数创建一个 1 * seq_len * seq_len 的矩阵
mask = torch.full((1, args.max_seq_len, args.max_seq_len), float("-inf"))
# triu 函数的功能是创建一个上三角矩阵
mask = torch.triu(mask, diagonal=1)

生成的 Mask 矩阵上三角位置的元素均为 -inf,其他位置的元素置为0。

在注意力计算时,将注意力分数与掩码求和,再进行 softmax 操作

1
2
3
# scores 为注意力分数,mask 为上文生成的掩码矩阵
scores = scores + mask[:, :seqlen, :seqlen]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)

通过求和,上三角区域(mask 位置)的注意力分数为 -inf, 下三角区域则不变。再进行 softmax 操作, -inf

经 softmax 后会被置为 0 ,从而忽略上三角区域注意力分数的计算。

多头注意力机制

问题提出: 注意力机制可以实现并行化、长依赖关系的拟合,但一次注意力计算只能拟合一种相关关系,单一的注意力机制很难全面拟合序列里的相关性。

多头注意力机制(Multi-Head Attention)通过将原始输入进行多组自注意力计算,然后拼接每一组得到的自注意力结果,使用一个线性层得到最后输出,从而拟合语句中的不同信息。

换句话说,n 个头有 n 组 3 个参数矩阵,每一组进行注意力计算,由于是不同的参数矩阵,因此反向传播实现了不同的注意力结果,然后将 n 个结果拼接起来输出。

但上述实现方法的时空复杂度较高,我们可以通过矩阵运算巧妙地实现多头并行计算,其核心逻辑在于使用三个组合矩阵来代替 n 个参数矩阵的组合,即先矩阵内积再拼接其实等同于先拼接矩阵再内积

MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O \\ \text{where}~\mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)

在原论文中,作者也通过实验证实,多头注意力计算中,每个不同的注意力头能够拟合语句中的不同信息,如图2.4所示:

多头注意力

上层与下层分别是两个注意力头对同一段语句序列进行自注意力计算的结果,可以看到,对于不同的注意力头,能够拟合不同层次的相关信息。通过多个注意力头同时计算,能够更全面地拟合语句关系。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch.nn as nn
import torch
import math
import torch.nn.functional as F

''' 多头注意力计算模块 '''
class MultiHeadAttention(nn.Module):

def __init__(self, args: Model_args, is_causal = False):
# args: 模型配置参数
super().__init__()
# 隐藏层必须是头数的整数倍,后续会将 d_model 拆分为 d_k = d_model / head
assert args.dim % args.n_heads == 0
# 每个头的维度
self.head_dim = args.dim // args.n_heads
self.n_heads = args.n_heads

# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 dim × dim
self.wq = nn.Linear(args.dim, self.n_heads * self.head_dim, bias = False)
self.wk = nn.Linear(args.dim, self.n_heads * self.head_dim, bias = False)
self.wv = nn.Linear(args.dim, self.n_heads * self.head_dim, bias = False)

# 输出权重矩阵,维度为 dim × dim(head_dim = dim / n_heads)
self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias = False)

# 注意力的 dropout
self.attn_dropout = nn.Dropout(args.dropout)

# 残差连接的 dropout
self.resid_dropout = nn.Dropout(args.dropout)
self.is_causal = is_causal

if is_causal:
# 创建一个上三角矩阵,用于 MASK 未来信息
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal = 1)
# 注册为模型缓冲区
self.register_buffer("mask",mask)

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

# 获取批次大小和序列长度, [batch_size, seq_len, dim]
batch_size, seq_len, _ = q.shape

# 计算 Q, K, V, 输入通过参数矩阵层
# shape: (batch_size, seq_len, dim) × (dim, dim) → (batch_size, seq_len, dim)
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)

# 将 Q, K, V 拆分成多头, 维度为(batch_size, seq_len, n_heads, dim // n_heads)
# 然后交换维度, 变成 (batch_size, n_heads, seq_len, dim // n_heads)
# 之所以要交换维度,因为计算注意力分数时主要是后两个维度参与,每个头关注整个序列的不同信息
# 为什么先按(batch_size, seq_len, n_heads, dim // n_heads)展开再互换1, 2的维度?
# 因为 view 的展开范式直接把输入全部排开,然后按要求构造,只有上述要求能实现将每个头对应的部分取出
xq = xq.view(batch_size, seq_len, self.n_heads, self.head_dim)
xk = xk.view(batch_size, seq_len, self.n_heads, self.head_dim)
xv = xv.view(batch_size, seq_len, self.n_heads, self.head_dim)

xq = xq.transpose(1,2)
xk = xk.transpose(1,2)
xv = xv.transpose(1,2)

# 注意力计算: Q * K^T / sqrt(d_k)
# (batch_size, n_heads, seq_len, head_dim) x (batch_size, n_heads, head_dim, seq_len)
# -> (batch_size, n_heads, seq_len, seq_len)
scores = torch.matmul(xq, xk.transpose(-2,-1)) / math.sqrt(self.head_dim)

# 掩码注意力计算
if self.is_causal:
assert hasattr(self, 'mask')
# 这里截取到序列长度,因为有些序列可能比 max_seq_len 短
scores = scores + self.mask[:, :, :seq_len, :seq_len]

# 计算softmax, shape; (batch_size, n_heads, seq_len, seq_len)
scores = F.softmax(scores.float(), dim = -1)
# 注意力 dropout
scores = self.attn_dropout(scores)

# V * scores
# (batch_size, n_heads, seq_len, seq_len) × (batch, n_heads, seq_len, head_dim) → (batch, n_heads, seq_len, head_dim)
output = torch.matmul(scores, xv)

# 恢复时间维度并合并头
# 将多头的结果拼接起来,先交换维度再拼接
# contiguous 函数用于开拍一块新内存,因为 Pytorch 先 transpose 再 view 会报错,
# 因为 view 直接基于底层存储得到, 然而 transporse 并不会改变底层存储,因此需要额外存储
output = output.transpose(1,2).contiguous().view(batch_size, seq_len, -1)

# 最终投影回残差流
output = self.wo(output)
output = self.resid_dropout(output)

return output

注:n_embed(嵌入维度)和args.dim(模型维度)可能不相等,须在多头注意力模块外部做线性变换,保证进入模块的输入维度 = args.dim;

疑问?

为什么要对注意力机制进行 dropout ?
原因解释
防止过拟合注意力机制可能过度依赖某些特定的位置关系,dropout强制模型学习更鲁棒的注意力模式
正则化随机"丢弃"部分注意力连接,避免模型记住训练数据的特定模式
促进稀疏性鼓励模型学习更加分散的注意力分布,而不是集中在少数几个位置
  • 对注意力概率分布随机置零,被dropout的位置在输出中不会贡献任何信息;
多头注意力中的“矩阵内积再拼接等同于拼接矩阵再内积”是什么意思?

在多头注意力中:

  • 原始 Q/K/V 的维度:[batch_size, seq_len, d_model]d_model 是总维度,比如 512)
  • 多头拆分:把 d_model 拆成 h 个头,每个头的维度d_k = d_model/h(比如 h=8,d_k=64)
  • 每个头需要 3 个参数矩阵:WiQW_i^QWiKW_i^KWiVW_i^V(维度都是[d_model, d_k]

思路1:逐头计算再拼接(矩阵内积→拼接)

这种方法需要循环 h 次,计算每个头然后再拼接:

  1. 对第 1 个头:Q1=QW1QQ1 = Q · W_1^QK1=KW1KK1 = K · W_1^KV1=VW1VV1 = V · W_1^V → 计算head1=Attention(Q1,K1,V1)head1 = Attention(Q1,K1,V1)

  2. 对第 2 个头:Q2=QW2QQ2 = Q · W_2^QK2=KW2KK2 = K · W_2^KV2=VW2VV2 = V · W_2^V → 计算head2=Attention(Q2,K2,V2)head2 = Attention(Q2,K2,V2)

    1. … 直到第 h 个头;
    2. 最后把head1,head2,...headhhead1, head2,...headh 在最后一维拼接 →Concat(head1,...,headh)Concat(head1,...,headh)

思路 2:拼接矩阵再内积(拼接→矩阵内积)

这是一种并行写法,无循环,一次矩阵运算搞定:

  1. 把 h 个WiQW_i^Q 拼接成一个大矩阵WQW^QWQ=[W1Q,W2Q,...,WhQ]W^Q = [W_1^Q, W_2^Q, ..., W_h^Q] (维度[d_model, d_model],因为 h * d_k = d_model
  2. 一次矩阵内积完成所有头的QQ 变换:Qall=QWQQ_{all} = Q · W^Q(维度[batch_size, seq_len, d_model]
  3. QallQ_{all}在最后一维拆成 h 个头:Q1,Q2,...QhQ1, Q2,...Qh(每个维度[batch_size, seq_len, d_k]
  4. 对 K/V 做同样的操作,然后并行计算所有头的注意力,最后拼接结果;

等价性

假设 h=2 个头,d_model=128,d_k=64:

  • 逐头计算:Q1=QW1QQ1 = Q·W1^Q(64 维)、Q2=QW2QQ2 = Q·W2^Q(64 维)→ 拼接后[Q1,Q2][Q1, Q2](128 维)
  • 拼接矩阵计算:WQ=[W1Q,W2Q]W^Q = [W1^Q, W2^Q](128 维)→QWQ=Q[W1Q,W2Q]=[QW1Q,QW2Q]=[Q1,Q2]Q·W^Q = Q·[W1^Q, W2^Q] = [Q·W1^Q, Q·W2^Q] = [Q1, Q2]
多头注意力为什么要拆分 d_model?

这个问题其实源于 Transformer 设计的核心目标:用更低的计算成本,实现更细粒度的注意力建模

原始自注意力(单头):Q/K/V 维度都是 d_model,计算注意力的复杂度是O(n2dmodel)O(n^2⋅d_model)(n 是序列长度)。

多头注意力:把 d_model 拆成 h 个dk=dmodel/hd_k=d_{model}/h,每个头用 d_k 维度的 Q/K/V 计算注意力,最后拼接还原为 d_model

  1. 拆分 d_model 能够保证复杂度不暴涨

如果不拆分 d_model,直接做 h 个头的计算(每个头都用 d_model 维度),总复杂度会变成hO(n2dmodel)h⋅O(n^2⋅d_{model}),是原来的 h 倍,计算成本极高。

而拆分 d_model 后,每个头的复杂度为O(n2dk)=O(n2dmodel/h)O(n^2⋅d_k)=O(n^2⋅d_{model}/h),h 个头的总复杂度为hO(n2dmodel/h)=O(n2dmodel)h⋅O(n^2⋅d_{model}/h)=O(n^2⋅d_{model}), 和单头复杂度完全一致,但却获得了 h 倍的表达能力。

  1. d_model 代表语义特征的维度,拆分符合注意力的一个建模逻辑

注意力的本质是计算 Query 和 Key 的相似度,而 Q/K 的维度 d_model 代表的是语义特征的维度(每个 token 的语义向量),拆分 d_model 相当于把一个完整的语义向量拆成多个 “子语义向量”,每个子向量聚焦一个语义角度(比如语法、语义等),更符合人类理解语言的多维度的一个特性。

举个通俗的例子:

  • 把一个 token 的语义向量(d_model=512)比作 “一个人的完整信息”(身高、体重、年龄、职业等);
  • d_model 拆分成 8 个头(d_k=64),相当于让 8 个 “分析师” 分别关注不同的信息维度(分析师 1 看身高 / 体重、分析师 2 看年龄 / 职业、分析师 3 看兴趣 / 爱好……);
  • 最后将 8 个分析师的结论拼接,得到一个人更全面的信息。
  1. 维度对齐的要求

多头注意力的最后一步是拼接所有头的输出,再通过一个线性层 WO(维度 dmodel×dmodel)还原为 d_model 维度,为了让拼接后的维度刚好等于 d_model,因此必须满足h×dk=dmodelh × d_k=d_{model}, 因此拆分 d_model 是唯一能保证 “拼接后维度和输入维度一致” 的方式,避免维度不匹配导致的矩阵运算错误。

简单来说,拆分 d_model 是效率和效果的最优解,也就是花和单头一样的计算成本,去拿到 h 倍的语义建模能力。

为什么要对输入进行 Wq, Wk, Wv 的线性变换?

解耦 Q/K/V 的语义空间

函数的输入 q, k, v 是 x,原始输入x直接作为 Q/K/V,三者共享同一语义空间,注意力的表达能力会被限制(相当于强制 Q=K=V)

通过 Wq/Wk/Wv 三个独立的线性层,让 Q/K/V 映射到不同的子语义空间

  • Q(查询):聚焦 “当前 token 需要找什么信息”;
  • K(键):聚焦 “其他 token 能提供什么信息”;
  • V(值):聚焦 “其他 token 的核心信息是什么”;

这种解耦是注意力能学到有效关联的关键。

最终投影回残差流是什么意思?为什么这里有 Wo?

残差流

Transformer 的核心设计之一是残差连接(Residual Connection),公式为:Output=Attention(x)+xOutput=Attention(x)+x

  • x:注意力模块的原始输入(token 嵌入向量);
  • Attention(x):注意力模块的输出;

两者相加的前提是维度完全一致(否则无法做逐元素加法);

维度对齐:把多头拼接后的输出映射回残差流维度

语义融合:整合多头信息

多头注意力的计算过程中:

  1. 每个头的输出维度是[batch, n_heads, seq_len, head_dim]
  2. 转置 + 拼接后,维度变为[batch, seq_len, n_heads*head_dim](即[batch, seq_len, args.dim]);

此时的输出是 “多个头的简单拼接”,特征之间没有交互,Wo 是一个可学习的线性层,作用是:

  • 把 “多头拼接的特征” 映射回 “残差流维度”,虽然维度数值相等,但语义空间需要融合;
  • 学习多头特征的加权组合,让模型自主选择哪些头的信息更重要。