背景
在算能的 LLM-TPU 里有对 Qwen2.5-VL 视觉模型的导出,但是并没有增加对窗口注意力的支持,因此跟 Qwen2.5-VL 原生相比,性能下降较大。
UPDATE: 算能官方已经支持 qwen2.5-vl 的窗口注意力,具体实现参考 tpu-mlir/llm。
qwen2.5-vl 视觉编码器的注意力机制
qwen2.5-vl 中有两种不同的注意力机制:
- 全局注意力:用于处理全图特征,但是计算量较大,计算量增长随着图像尺寸的增加而呈平方增长,因此只在 4 个特殊的 layer 使用。
- 窗口注意力:只计算局部区域的注意力,计算量较小,适用于大多数层。
qwen2.5-vl 的视觉编码器注意力计算如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| class Qwen2_5_VLVisionFlashAttention2(nn.Module): ...
def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) cos, sin = position_embeddings q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) q = q.squeeze(0) k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( seq_length, -1 ) attn_output = self.proj(attn_output) return attn_output
|
FlashAttention2 的窗口注意力实现
qwen2.5-vl 默认使用 FlashAttention2 作为注意力实现。FlashAttention2 原生支持窗口注意力,如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| def flash_attn_varlen_func( q: "Tensor[total_q, nheads, headdim]", k: "Tensor[total_k, nheads_k, headdim]", v: "Tensor[total_k, nheads_k, headdim]", cu_seqlens_q: "Tensor[batch_size + 1], torch.int32", cu_seqlens_k: "Tensor[batch_size + 1], torch.int32", max_seqlen_q: int, max_seqlen_k: int, dropout_p: float = 0.0, softmax_scale: float = None, causal: bool = False, window_size: tuple = (-1, -1), softcap: float = 0.0, alibi_slopes = None, deterministic: bool = False, return_attn_probs: bool = False, block_table = None, ): ...
|
FlashAttention2 的 flash_attn_varlen_func
函数中,q 的 shape 为 $L \times H \times D$,其中 $L$ 是 query 的长度,$H$ 是 head 数量,$D$ 是每个 head 的维度。k 和 v 的 shape 类似。通过设置不同的 cur_seqlens_q
和 cu_seqlens_k
,可以实现对一个序列实现让 Attention 只关注 $[seqlens_i,\ seqlens_{i+1}]$ 的范围内的 token,间接实现了窗口注意力。
为什么只 FlashAttention 设置成只关注两个相邻 cur_seqlens
之间的 token 呢?因为实际上传入 q 的 tensor 是变长的,如果不设置为变长,则经过 padding 之后的 q 实际上为 $B \times L^\prime \times H \times D$,其中 $L^\prime$ 是最长序列的长度,$B$ 是 batch size,通过将 $B$ 个不等长的序列拼接成一个大 tensor,FlashAttention2 可以在计算时只关注每个序列的实际长度。如下图,
因为 cur_seqlens_q
作为索引表示的是每个可变长度的索引值,每两个索引之间包含的就是一个独立 batch 的数据,所以也就不能关注到 cur_seqlens_q
之外的信息。
从 padding 成 batch 改为 cur_seqlens_q
的方式,还可以节约 padding 带来的显存开销,所以 FlashAttention 也可以在更大的 batch size 下运行。
使用 padding + EagerAttention 实现窗口注意力
qwen2.5-vl 也提供了没有安装 FlashAttention 的情况下的窗口注意力实现,使用了 padding + EagerAttention(直接实现 attention) 的方式。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| class Qwen2_5_VLVisionAttention(nn.Module): ...
def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.full( [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype ) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output
|
由于 cu_seqlens
等于 [0, 100, 200, 300, 320]
,所以 attention mask 的形状是 [1, 320, 320]
,其中 attention_mask[i, j, k]
表示第 i
个 batch 中的第 j
个 token 和第 k
个 token 之间的注意力权重。但是这种方式导致窗口外的 token 之间的注意力权重也被计算了出来,虽然在后续的 softmax 中被归一化为 0,但仍然会带来不必要的计算开销。