Chunkwise 并行算法 —— 线性注意力

为了在现代硬件上实现高效训练,Mamba2 和 DeltaNet 都通过分块并行(Chunkwise Parallel)策略将 $O(L)$ 的线性递归改写为矩阵迭代形式。这种变换的核心思想是将长序列划分为大小为 $C$ 的块,利用矩阵乘法(Matmul)处理块内计算,通过隐藏状态传递块间信息。

1. Mamba2:状态空间二元性(SSD)的矩阵迭代推导

Mamba2 的核心贡献在于证明了选择性 SSM 实际上等价于一种带有特殊掩码的线性注意力机制。

1.1 从递归到块内状态展开

Mamba2 的状态更新公式(矩阵形式)为:

$$S_t = \alpha_t S_{t-1} + v_t k_t^T \in \mathbb{R}^{d_v \times d_k}, \quad o_t = S_t q_t \in \mathbb{R}^{d_v}$$

其中 $\alpha_t \in (0,1)$ 是数据相关的衰减因子。定义 $\gamma_j = \prod_{i=1}^j \alpha_i$ 为全局累积衰减,是一个标量因子。

如果只考虑端侧 CPU/GPU 推理,串行计算 LinearAttention 也可以接受,因为还有 Head 可以用于并行计算。sdot 指令集单 CPU 的计算密度为 $\frac{16 \ \text{MAC} \ \cdot \ 2 \ \text{op}/\text{MAC}}{128 \ \text{bit}} \cdot 2 = 4 \ \text{op}/\text{byte}$,4 核心 CPU 可以支持 16 个 Head 同时计算。

通过循环展开,Mamba 的自注意力输出 $o_t$ 可以使用矩阵表示为:

$$o_t = \sum_{i=1}^t \left( \frac{\gamma_t}{\gamma_i} v_i k_i^T \right) q_t = \sum_{i=1}^t v_i \left( \frac{\gamma_t}{\gamma_i} k_i^T q_t \right) \in \mathbb{R}^{d_v}, \quad \mathbf{O} = ((\mathbf{Q} \mathbf{K}^T) \odot \Gamma)\mathbf{V} \in \mathbb{R}^{L \times d_v} $$

其中,$\Gamma$ 为衰减感知掩膜矩阵。

1.2 矩阵迭代形式

对于一个大小为 $C$ 的块 $[t]$,假设初始状态为 $S_{[t]}$,我们可以展开块内的第 $r$ 个时刻的状态:

$$S_{[t]}^r = (\prod_{i=1}^r \alpha_{[t]}^i) S_{[t]} + \sum_{i=1}^r (\prod_{j=i+1}^r \alpha_{[t]}^j) v_{[t]}^i k_{[t]}^{iT} \in \mathbb{R}^{d_v \times d_k}$$

为了简化,令 $\gamma_{[t]}^r = \prod_{j=1}^r \alpha_{[t]}^j$。块末尾的最终状态 $S_{[t+1]}$ 为:

$$S_{[t+1]} = \gamma_{[t]}^C S_{[t]} + \sum_{i=1}^C \frac{\gamma_{[t]}^C}{\gamma_{[t]}^i} v_{[t]}^i k_{[t]}^{iT} = \gamma_{[t]}^C S_{[t]} + \sum_{i=1}^C v_{[t]}^i \left( \frac{\gamma_{[t]}^C}{\gamma_{[t]}^i} k_{[t]}^{iT} \right) \in \mathbb{R}^{d_v \times d_k}$$

引入衰减向量 $\overrightarrow{k_{[t]}^r} = \frac{\gamma_{[t]}^C}{\gamma_{[t]}^r} k_{[t]}^r$ 和块衰减矩阵 $\overrightarrow{S_{[t]}} = \gamma_{[t]}^C S_{[t]}$,状态更新可写为高效的矩阵乘法:

$$S_{[t+1]} = \overrightarrow{S_{[t]}} + V_{[t]}^T \vec{ K _ { [t] } } \in \mathbb{R}^{d_v \times d_k}, \quad \mathbf{O}_{ [t] } = \overleftarrow{\mathbf{Q}_{ [t] }}\mathbf{S}_{ [t] }^T + ( \mathbf{Q}_{ [t] } \mathbf{K}_{ [t] }^T \odot \mathbf{\Gamma}_{ [t] } ) \mathbf{V}_{ [t] } \in \mathbb{R}^{C \times d_v}$$

块内输出 $O_{[t]}$ 由两部分组成:

  1. 初始状态贡献:$\overleftarrow{Q_{[t]}} S_{[t]}^T$(其中 $\overleftarrow{q_{[t]}^r} = \gamma_{[t]}^r q_{[t]}^r$)。
  2. 块内因果贡献:$(Q_{[t]} K_{[t]}^T \odot \Gamma_{[t]}) V_{[t]}$,其中 $\Gamma_{[t]}$ 是衰减掩码矩阵 。

1.3 与 self-attention 的对比

注意到,矩阵迭代形式的 Linear-Attention 使用的状态公式为 $\text{LAttn}(Q, K, V) = (Q_{[t]} K_{[t]}^T \odot \Gamma_{[t]}) V_{[t]}$,Self-Attention 为 $\text{SelfAttn}(Q, K, V) = \text{softmax}(QK^T \odot \Gamma)V$,这两种格式下二者的时间复杂度均为 $O(L^2 d)$,即常说的 Self-Attention 的时间复杂度与序列的平方成正比,而此时 Linear-Attention 也与序列的平方成正比。

在 Linear Attention 中(为了简化,暂时忽略公式中的衰减项 $\Gamma$),核心计算形式可以抽象为:

$$\text{LAttn}(Q, K, V) = (Q K^T) V$$

根据矩阵乘法的结合律,我们有两种计算顺序:

  1. 方案 A (类似 Self-Attention):
    1. 先计算 $(Q K^T), Q \in \mathbb{R}^{L \times d}, K^T \in \mathbb{R}^{d \times L}$。
    2. 相乘结果是一个 $L \times L$ 的巨大矩阵。时间复杂度是 $O(L^2 d)$。
  2. 方案 B (Linear 模式):
    1. 先计算 $(K^T V), K^T \in \mathbb{R}^{d \times L}, V \in \mathbb{R}^{L \times d}$。
    2. 相乘结果是一个 $d \times d$ 的小矩阵,即 State $S$。时间复杂度是 $O(L d^2)$。

由于在长文本中 $L \gg d$(例如 $L=8192, d=128$),方案 B 的复杂度随序列长度 $L$ 线性增长。将 $S$ 视为一个“累积状态”,每输入一个新 Token,只需更新一次这个 $d \times d$ 的矩阵,这就是 RNN 风格的常数级推理。

在 Self-Attention 中,由于非线性函数 $\text{softmax}$ 的存在,$\text{SelfAttn}(Q, K, V) = \text{softmax}(QK^T \odot \Gamma)V$ 无法变形为 Linear 模式进行计算,因此 Self-Attention 总是 $O(L^2 d)$ 的时间复杂度。

在 Linear-Attention 训练时,如果完全采用 RNN 模式(纯线性扫描),虽然是 $O(L)$,但无法利用 GPU 的张量并行能力,速度极慢。因此引入了 Chunkwise 方案:

  1. 分块逻辑: 将长度为 $L$ 的序列分成 $L/C$ 个块。
  2. 块内(Intra-chunk): 在每一个大小为 $C$ 的块内部,我们依然采用类似 Self-Attention 的 $QK^T$ 方案。这里的局部复杂度是 $O((L/C) \cdot C^2) = O(LC)$。
  3. 块间(Inter-chunk): 块与块之间通过传递 State(你公式中的 $S$)来通信。每个块生成的 $S$ 会作为下一个块的初始状态。

最终结果:Chunkwise Parallel 不是简单的 $O(L^2 d)$。如果 $C$ 取常数(如 64 或 128),那么整体复杂度依然是关于 $L$ 的线性函数。它避开了全局 $L \times L$ 矩阵,只计算局部 $C \times C$ 矩阵。

Flash Linear Attention 为 Prefill 和 Decode 两个阶段设置了不同的内核:

  1. prefill 阶段:使用 mamba_chunk_scan_combined 进行分块扫描
  2. decode 阶段:使用 selective_state_update 进行序列式的状态更新,输入的形状中没有 seq_len 参数

2. DeltaNet:基于 WY 表示法的矩阵迭代推导

DeltaNet 的更新规则涉及 Householder 变换,其形式比 Mamba2 更复杂,无法直接通过简单的累积和来并行。

2.1 从增量规则到 Householder 形式

DeltaNet 的递归公式如下:

$$S_t = S_{t-1}(I - \beta_t k_t k_t^T) + \beta_t v_t k_t^T$$

其中 $(I - \beta_t k_t k_t^T)$ 是广义 Householder 矩阵。这个公式可以抽象为 $S_t = S_{t-1} M_t + X_t$。

2.2 WY 表示法展开

数值线性代数证明,多个 Householder 矩阵的乘积仍具有低秩结构。对于块 $[t]$,块内状态可表示为:

$$S_{[t]}^r = S_{[t]} P_{[t]}^r + H_{[t]}^r$$

其中 $P$ 是累积变换矩阵,$H$ 是累积写入矩阵。根据 WY 表示法

  • $P_{[t]} = I - W_{[t]}^T K_{[t]}$
  • $H_{[t]} = U_{[t]}^T K_{[t]}$

2.3 引入 UT 变换计算 $W$ 和 $U$

为了避免 $O(C)$ 的序列计算,利用 UT 变换 将 $W$ 和 $U$ 转化为矩阵运算:

  1. 计算辅助矩阵 $T_{[t]} = [I + \text{strictLower}(\text{diag}(\beta_{[t]}) K_{[t]} K_{[t]}^T)]^{-1} \text{diag}(\beta_{[t]})$。
  2. 通过 $T$ 直接得到块内权重:$W_{[t]} = T_{[t]} K_{[t]}$ 和 $U_{[t]} = T_{[t]} V_{[t]}$。

虽然矩阵求逆比较困难,但是 $I + \text{strictLower}(\text{diag}(\beta_{[t]}) K_{[t]} K_{[t]}^T)$ 是标准的下三角矩阵,可以通过下面公式求解(其中 $N$ 为严格下三角矩阵,即对角线以上全部为 0):

$$\begin{align}& N^n = 0 \\\Rightarrow & I - N^n = I \\\Rightarrow & (I-N)(I - N + N^2 - N^3 + \cdots + (-1)^{n-1}N^{n-1}) = I \\\Rightarrow & Q = (I-N)^{-1} = I - N + N^2 - N^3 + \cdots + (-1)^{n-1}N^{n-1}\end{align}$$

更进一步,由于 $N$ 是一个严格下三角矩阵,因此每一个结果 $q_{i,j}$ 只取决于 $q_{\le i, \le j }$ 的部分,我们可以得到如下的迭代公式:

$$Q_{[t, \le t]} = Q_{[t, \le t ]} + \sum_{row} \text{broadcast}(Q_{[t, \le t]}^T)Q_{[\le t, \le t ]}, \quad Q = -N$$

  1. 按行求和的原因:$Q_{[t, \le t]}^T$ 已经把行中的每一列元素转置到了每一行,所以结果需要按行求和
  2. 初始条件为负号:保证每次相乘之后,可以出现正负交替的情况

2.4 矩阵迭代形式

最终,DeltaNet 的块间状态更新方程为:

$$S_{[t+1]} = S_{[t]} + (U_{[t]} - W_{[t]} S_{[t]}^T)^T K_{[t]}$$

块内输出方程为:

$$O_{[t]} = Q_{[t]} S_{[t]}^T + (Q_{[t]} K_{[t]}^T \odot M) (U_{[t]} - W_{[t]} S_{[t]}^T)$$

其中 $M$ 为标准的下三角因果掩码。

3. 算法对比总结

特性 Mamba2 (SSD) DeltaNet (WY Rule)
递归本质 带有标量衰减的选择性线性 RNN 带有矩阵擦除的线性回归更新
块内算子 衰减卷积 + 线性注意力 UT 变换 + 线性注意力
状态传递 $S_{[t+1]} = \text{decay} \cdot S_{[t]} + \text{KV_sum}$ $S_{[t+1]} = S_{[t]} \cdot \text{Householder_Prod} + \text{Weighted_V}$
硬件优化 重点在于计算累积衰减向量 重点在于计算 $T$ 矩阵以物化 $W$ 和 $U$

结论:Mamba2 的矩阵迭代形式依赖于对角阵约束($A=\alpha I$)来简化并行,而 DeltaNet 通过 WY 表达法 保持了更复杂的非对角交互。Gated Delta Net 则在此基础上,通过在 $T$ 矩阵计算中引入衰减掩码 $\Gamma$,将 Mamba 的门控特性无缝集成到了增量规则的矩阵形式中 。


Chunkwise 并行算法 —— 线性注意力
http://hebangwen.github.io/2026/03/06/chunkwise-parallel-algo/
作者
何榜文
发布于
2026年3月6日
许可协议