MNN 自定义算子,以 AnyNet 为例

1. 介绍

本文主要介绍如何在 MNN 中添加自定义算子,以 AnyNet 为例。AnyNet 添加了一个自定义算子,虽然可以用 pytorch 表达,并且导出 ONNX 和 MNN,但是节点过多,可视化工具无法很好的展示,因此本文尝试将 AnyNet 中的自定义算子添加到 MNN 中。

自定义 MNN 算子需要先将 pytorch 计算逻辑导出成 ONNX 节点,并且由 MNN 来解释。这需要首先将 pytorch 计算包装成 torch.autograd.Function,然后通过 torch.onnx.export 导出 ONNX 节点。然后通过 MNN 的 MNN::OpConverter 来解释 ONNX 节点。

2. AnyNet

AnyNet

AnyNet 自定义了一个 SPNet 算子,并且使用 CUDA 实现了这个算子。SPN的结构,分为两种,分别是单路连接和三路连接:

spnet.png

实际上就是通过一个权重矩阵,上一个点迭代地给下一个点权重,对于单路连接和三路连接,有:

$$
\begin{align}
h_{k,t} &= (1 - p_{k,t}) \cdot x_{k,t} + p_{k,t} \cdot h_{k,t-1} \
h_{k,t} &= (1 - \sum_{k \in \mathbb{N}} p_{k,t}) x_{k,t} + \sum_{k \in \mathbb{N}}p_{k,t} h_{k,t-1}
\end{align}
$$

$p$是权重,$x$是当前点的值,$h$是传播后的值,$k$是行索引,$t$是列索引。AnyNet使用的是三路连接中的left to right,重复传播(recurrent propagation)要求下一个点的值依赖于上三个点(也就是左边三个点)的值。

下面是 CUDA 实现和 pytorch 实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

__global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) {
CUDA_1D_KERNEL_LOOP(index, count) {

int hc_count = height * channels;

int n,c,h,w;
int temp=index;
w = T;
n = temp / hc_count;
temp = temp % hc_count;
c = temp / height;
temp = temp % height;
h = temp;


float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);

float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
float h1_minus1 = g_data_1 * h_minus1_data_1;

float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1);
float h2_minus1 = g_data_2 * h_minus1_data_2;

float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
float h3_minus1 = g_data_3 * h_minus1_data_3;

float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;

float h_data = x_hype + h_hype;

set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);

}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class GateRecurrent(nn.Module):
def __init__(self):
super().__init__()

def forward(self, X, G1, G2, G3):
width = X.size(3)
H = torch.zeros_like(X)
for t in range(width):
g1 = G1[..., t]
g2 = G2[..., t]
g3 = G3[..., t]

h1 = F.pad(H[:, :, :-1, t-1], (1, 0))
h2 = H[:, :, :, t-1]
h3 = F.pad(H[:, :, 1:, t-1], (0, 1))

g = g1 + g2 + g3
x = X[..., t]
H[..., t] = (1 - g) * x + (h1 * g1 + h2 * g2 + h3 * g3)

return H

3. 自定义 ONNX 算子

待续


MNN 自定义算子,以 AnyNet 为例
http://hebangwen.github.io/2024/05/21/MNN-OP-AnyNet/
作者
何榜文
发布于
2024年5月21日
许可协议