Self-Attention 算法简介
MHA MQA GQA 对比
MHA Multi-Head Attention,MQA Multi-Query Attention,GQA Group-Query Attention。在 transformers 中,主要体现在 config.json 里的 num_key_value_heads 设置上。
QKV shape 如下:
- Query: $[B, L, D_{\text{model}}]$ -> $[B, L, N_\text{heads}, D_\text{head}]$ -> $[B, N_\text{heads}, L, D_\text{heads}]$
- Key & Value: $[B, L, D_\text{kv}]$ -> $[B, N_\text{kv}, L, D_\text{head}]$
其中,$B$ 表示 batch size,$L$ 表示 sequence length,$D_\text{model}$ 为 hidden size,$N_\text{heads}$ 为 num heads,即多头注意力里的头数。
reshape 完成后,按照这个顺序去做 Self-Attention $Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$。(K 需要 transpose,以满足 self-attention 的公式)
当 num_key_value_heads 为 num_heads 时,上述计算是标准的 MHA;为 1 时,是 MQA,所有的 Query 共享一个 KV;介于二者之间,是 GQA。因此 MHA 和 MQA 可以视为 GQA 的特殊情况。对于非相等情况下的 MQA 和 GQA,调用 repeat_kv 函数复制 N-1 份 KV,就可以按照 MHA 的代码计算了。
由于 kv 减小了,因此 self-attention 前的 proj matrix 减小了,总的计算量有所下降。更重要的是由于 KV 减小,保存和访问 kv cache 所需的带宽也减小了,kv cache 减小到 MHA 的 1/group。

上面这张图解释了三个代码的不同,此时 gqa 的 group size 为 2,mqa 的 group size 为 8。

上图解释了如何将一个 MHA 转为 MQA,即使用一个 mean pool。
python 代码如下:
1 |
|
ONNX 展示

上图展示了 qwen2.5 的 GQA 在 ONNX 里的可视化结果。参数为:num_kv_heads = 4,hidden_size = 3584,head_dim = 128,seqlen = 2048。因此每个 kv cache 需要 repeat 3584 / 128 / 4 = 7 次。kv cache 的保存从 2048*28*128 下降到 2048*4*128,降低到 1/7,即 group size 为 7。
如何不使用 repeat_kv 方法
如果要跳过 repeat_kv,那么必须使用 slice 或者 gather 操作从 QKV 中提取出数据来。默认状态下 Q 的 shape 为:[B, N_head, L, D_head],其中 B 为 batch size,N_head 为 num_head,L 为 seqlen,D_head 为 head_dim;KV 的 shape 为:[B, N_kv, L, D_head],N_kv 为 num_kv_heads。per-head 和 per-group 是不同的 index 方式。计算矩阵乘法要求的是 Q 和 K 的后两维度,即[L, D_head],索引操作主要在 N_head 和 N_kv 上。
- 优点:减少了 tensor 大小,能够支持较小的 local memory / 片上内存
- 缺点:推理框架循环展开后计算图复杂
per-group
将每个 group 单独计算。将 N_head 切分为 N_group * N_kv,从上面的计算方式来看,这还是一个 strided-slice,即 Q[:, i::n_group],QKV shape 均为 [B, N_kv, L, D_head]。per-group 需要 stride-slice,合并输出时也复杂一点(stirided-scatter 或者 多个transpose+concat)
per-head
将每个 kv-head 进行计算,连续地切分 Q 和 KV。即 Q[:, i*n_group:(i+1)*n_group] * K[:, i]。由于 KV 此时变成 [B, 1, L, D_head],而 Q 为 [B, N_group, L, D_head] ,pytorch 会自动进行广播,得到一个 [B, N_group, L, D_head] 的矩阵。结果进行 concat 即可。