Qwen2.5-VL 中视觉模型的窗口注意力机制

背景

在算能的 LLM-TPU 里有对 Qwen2.5-VL 视觉模型的导出,但是并没有增加对窗口注意力的支持,因此跟 Qwen2.5-VL 原生相比,性能下降较大。

UPDATE: 算能官方已经支持 qwen2.5-vl 的窗口注意力,具体实现参考 tpu-mlir/llm

qwen2.5-vl 视觉编码器的注意力机制

qwen2.5-vl 中有两种不同的注意力机制:

  1. 全局注意力:用于处理全图特征,但是计算量较大,计算量增长随着图像尺寸的增加而呈平方增长,因此只在 4 个特殊的 layer 使用。
  2. 窗口注意力:只计算局部区域的注意力,计算量较小,适用于大多数层。

qwen2.5-vl 的视觉编码器注意力计算如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
...

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output

FlashAttention2 的窗口注意力实现

qwen2.5-vl 默认使用 FlashAttention2 作为注意力实现。FlashAttention2 原生支持窗口注意力,如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def flash_attn_varlen_func(
q: "Tensor[total_q, nheads, headdim]",
k: "Tensor[total_k, nheads_k, headdim]",
v: "Tensor[total_k, nheads_k, headdim]",
cu_seqlens_q: "Tensor[batch_size + 1], torch.int32",
cu_seqlens_k: "Tensor[batch_size + 1], torch.int32",
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float = 0.0,
softmax_scale: float = None,
causal: bool = False,
window_size: tuple = (-1, -1), # -1 means infinite context window
softcap: float = 0.0, # 0.0 means deactivated
alibi_slopes = None,
deterministic: bool = False,
return_attn_probs: bool = False,
block_table = None,
):
...

FlashAttention2 的 flash_attn_varlen_func 函数中,q 的 shape 为 $L \times H \times D$,其中 $L$ 是 query 的长度,$H$ 是 head 数量,$D$ 是每个 head 的维度。k 和 v 的 shape 类似。通过设置不同的 cur_seqlens_qcu_seqlens_k,可以实现对一个序列实现让 Attention 只关注 $[seqlens_i,\ seqlens_{i+1}]$ 的范围内的 token,间接实现了窗口注意力。

为什么只 FlashAttention 设置成只关注两个相邻 cur_seqlens 之间的 token 呢?因为实际上传入 q 的 tensor 是变长的,如果不设置为变长,则经过 padding 之后的 q 实际上为 $B \times L^\prime \times H \times D$,其中 $L^\prime$ 是最长序列的长度,$B$ 是 batch size,通过将 $B$ 个不等长的序列拼接成一个大 tensor,FlashAttention2 可以在计算时只关注每个序列的实际长度。如下图,

SeqLen_0
SeqLen_0
SeqLen_1
SeqLen_1
SeqLen_2
SeqLen_2
SeqLen_3
SeqLen_3
SeqLen_0
SeqLen_0
contiguous
contiguous
SeqLen_1
SeqLen_1
SeqLen_2
SeqLen_2
SeqLen_3
SeqLen_3
Padding To Batch
Padding To Batch
Concat As One Query
Concat As One Query
Require using cur_seqlens_q
Require using cur_seqlens_q
Text is not SVG - cannot display

因为 cur_seqlens_q 作为索引表示的是每个可变长度的索引值,每两个索引之间包含的就是一个独立 batch 的数据,所以也就不能关注到 cur_seqlens_q 之外的信息。

从 padding 成 batch 改为 cur_seqlens_q 的方式,还可以节约 padding 带来的显存开销,所以 FlashAttention 也可以在更大的 batch size 下运行。

使用 padding + EagerAttention 实现窗口注意力

qwen2.5-vl 也提供了没有安装 FlashAttention 的情况下的窗口注意力实现,使用了 padding + EagerAttention(直接实现 attention) 的方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Qwen2_5_VLVisionAttention(nn.Module):
...

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output

由于 cu_seqlens 等于 [0, 100, 200, 300, 320],所以 attention mask 的形状是 [1, 320, 320],其中 attention_mask[i, j, k] 表示第 i 个 batch 中的第 j 个 token 和第 k 个 token 之间的注意力权重。但是这种方式导致窗口外的 token 之间的注意力权重也被计算了出来,虽然在后续的 softmax 中被归一化为 0,但仍然会带来不必要的计算开销。


Qwen2.5-VL 中视觉模型的窗口注意力机制
http://hebangwen.github.io/2025/04/29/window-attn-in-qwen2-5vl/
作者
何榜文
发布于
2025年4月29日
许可协议