MHA MQA GQA 对比

MHA Multi-Head Attention,MQA Multi-Query Attention,GQA Group-Query Attention。在 transformers 中,主要体现在 config.json 里的 num_key_value_heads 设置上。

QKV shape 如下:

  • Query: $[B, L, D_{\text{model}}]$ -> $[B, L, N_\text{heads}, D_\text{head}]$ -> $[B, N_\text{heads}, L, D_\text{heads}]$
  • Key & Value: $[B, L, D_\text{kv}]$ -> $[B, N_\text{kv}, L, D_\text{head}]$

其中,$B$ 表示 batch size,$L$ 表示 sequence length,$D_\text{model}$ 为 hidden size,$N_\text{heads}$ 为 num heads,即多头注意力里的头数。

reshape 完成后,按照这个顺序去做 Self-Attention $Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$。(K 需要 transpose,以满足 self-attention 的公式)

num_key_value_headsnum_heads 时,上述计算是标准的 MHA;为 1 时,是 MQA,所有的 Query 共享一个 KV;介于二者之间,是 GQA。因此 MHA 和 MQA 可以视为 GQA 的特殊情况。对于非相等情况下的 MQA 和 GQA,调用 repeat_kv 函数复制 N-1 份 KV,就可以按照 MHA 的代码计算了。

由于 kv 减小了,因此 self-attention 前的 proj matrix 减小了,总的计算量有所下降。更重要的是由于 KV 减小,保存和访问 kv cache 所需的带宽也减小了,kv cache 减小到 MHA 的 1/group

MHA vs GQA vs MQA

上面这张图解释了三个代码的不同,此时 gqa 的 group size 为 2,mqa 的 group size 为 8。

MHA to MQA by mean-pooling

上图解释了如何将一个 MHA 转为 MQA,即使用一个 mean pool。

python 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

# init时k和v用self.num_key_value_heads * self.head_dim初始化,当self.num_key_value_heads小于self.num_heads时,参数量变少
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)

# forward时,通过repeat_kv方法,将hidden states 从(batch, num_key_value_heads, seqlen, head_dim) 变成 (batch, num_attention_heads, seqlen, head_dim),相当于是复制了self.num_key_value_groups份
self.num_key_value_groups = self.num_heads // self.num_key_value_heads

key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

ONNX 展示

qwen2.5

上图展示了 qwen2.5 的 GQA 在 ONNX 里的可视化结果。参数为:num_kv_heads = 4,hidden_size = 3584,head_dim = 128,seqlen = 2048。因此每个 kv cache 需要 repeat 3584 / 128 / 4 = 7 次。kv cache 的保存从 2048*28*128 下降到 2048*4*128,降低到 1/7,即 group size 为 7。

如何不使用 repeat_kv 方法

如果要跳过 repeat_kv,那么必须使用 slice 或者 gather 操作从 QKV 中提取出数据来。默认状态下 Q 的 shape 为:[B, N_head, L, D_head],其中 B 为 batch size,N_head 为 num_head,L 为 seqlen,D_head 为 head_dim;KV 的 shape 为:[B, N_kv, L, D_head],N_kv 为 num_kv_heads。per-head 和 per-group 是不同的 index 方式。计算矩阵乘法要求的是 Q 和 K 的后两维度,即[L, D_head],索引操作主要在 N_head 和 N_kv 上。

  • 优点:减少了 tensor 大小,能够支持较小的 local memory / 片上内存
  • 缺点:推理框架循环展开后计算图复杂

per-group

将每个 group 单独计算。将 N_head 切分为 N_group * N_kv,从上面的计算方式来看,这还是一个 strided-slice,即 Q[:, i::n_group],QKV shape 均为 [B, N_kv, L, D_head]。per-group 需要 stride-slice,合并输出时也复杂一点(stirided-scatter 或者 多个transpose+concat)

0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
Query
Query
0
0
1
1
2
2
3
3
Key
Key
0
0
1
1
2
2
3
3
Value
Value
Per-Group
Per-Group
0
0
2
2
4
4
6
6
0
0
1
1
2
2
3
3
0
0
1
1
2
2
3
3
1
1
3
3
5
5
7
7
0
0
1
1
2
2
3
3
0
0
1
1
2
2
3
3
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
0
0
2
2
4
4
6
6
1
1
3
3
5
5
7
7
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
Query
Query
Re-arrange
Re-arrange

Per-Group

per-group 按组分开,一次取出 num_kv_heads 个 query,跟 num_kv_heads 去做计算

Per-Group…
Text is not SVG - cannot display

per-head

将每个 kv-head 进行计算,连续地切分 Q 和 KV。即 Q[:, i*n_group:(i+1)*n_group] * K[:, i]。由于 KV 此时变成 [B, 1, L, D_head],而 Q 为 [B, N_group, L, D_head] ,pytorch 会自动进行广播,得到一个 [B, N_group, L, D_head] 的矩阵。结果进行 concat 即可。

0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
Query
Query
0
0
1
1
2
2
3
3
Key
Key
0
0
1
1
2
2
3
3
Value
Value
Per-Head
Per-Head
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
Query
Query
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
0
0
1
1
2
2
3
3
0
0
1
1
2
2
3
3
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
Concat
Concat

Per-Head

per-head 每次计算单个 head 的全部计算,按顺序取出 Query 和 Head 的数据

Per-Head…
Text is not SVG - cannot display

参考

LLM中 Attention 的实现方式汇总