Caffe2 中 float32 与 float16 的转换

generated by Copilot and translated by DeepL

这段 C++ 代码是一个名为 cpu_float2half_rn 的函数,用于将单精度浮点数(32 位)转换为半精度浮点数(16 位)。该函数将浮点数作为参数,并返回一个 float16 对象。转换过程涉及多个步骤,包括处理 NaN 和无穷大等特殊情况,从输入中提取符号、指数和尾数,以及四舍五入到最接近的偶数。

  1. 函数首先声明了一个 float16 对象 ret,用于保存转换结果。然后使用 static_assert 来确保无符号 int 和 float 的大小相同。这一点很重要,因为函数使用位操作来处理 float 的二进制表示,这需要将 float 解释为无符号整数。
  2. 函数使用 reinterpret_cast 获得指向 float 二进制表示形式的指针。然后创建该二进制表示的副本,屏蔽符号位,得到浮点数 u 的绝对值。
  3. 函数会处理特殊情况。如果 u 大于 0x7f800000,输入值为 NaN,函数会将结果设置为特殊的 NaN 值并返回。如果 u 大于 0x477fefff 或小于 0x33000001,则输入值为无穷大或零,函数会将结果设置为适当的值并返回。
  4. 函数从浮点数的二进制表示中提取符号、指数和尾数。符号是最高位,指数是接下来的 8 位,尾数是剩余的 23 位。
  5. 函数会调整指数和尾数,以适应较小的半精度格式。如果指数大于 0x70,则从指数中减去 0x70,并将尾数右移 13 位。否则,它会将指数设为 0,设置尾数的最高位,并计算尾数右移的位数。
  6. 函数将尾数四舍五入为最接近的偶数。计算移位后的最小有效位及其以下的位数。如果最小有效位以下的位数大于一半或正好是一半,且最小有效位为 1,则尾数递增。如果尾数溢出,则递增指数并重置尾数。
  7. 函数将符号、指数和尾数合并为半精度格式,并将结果存储在 ret 中。然后返回 ret。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
inline float16 cpu_float2half_rn(float f) {
float16 ret;

static_assert(
sizeof(unsigned int) == sizeof(float),
"Programming error sizeof(unsigned int) != sizeof(float)");

unsigned* xp = reinterpret_cast<unsigned int*>(&f);
unsigned x = *xp;
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;

// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
ret.x = 0x7fffU;
return ret;
}

sign = ((x >> 16) & 0x8000);

// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
ret.x = sign | 0x7c00U;
return ret;
}
if (u < 0x33000001) {
ret.x = (sign | 0x0000);
return ret;
}

exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);

if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
} else {
shift = 0x7e - exponent;
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);

// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
mantissa = 0;
}
}

ret.x = (sign | (exponent << 10) | mantissa);

return ret;
}

Caffe2 中 float32 与 float16 的转换
http://hebangwen.github.io/2024/03/12/float32-to-float16-caffe/
作者
何榜文
发布于
2024年3月12日
许可协议