bfloat16 精度损失(II)

在上一篇 记一个关于 RMSNorm 实现上的细节 中,我们讨论了 LayerNorm 和 RMSNorm 使用 float16 格式会导致的数值溢出问题,而 bfloat16 由于数值范围较大,通常不会出现溢出问题。那么本文我们讨论 bfloat16 较之于 float16,由于数据舍入而带来的精度损失问题。

浮点数表示方式

浮点数的表示方式是通过科学计数法来表示一个实数。float16 和 bfloat16 都是浮点数格式,但它们的精度和范围不同,具体比特划分如下。

比特位 float16 bfloat16
符号位 1 1
指数位 5 8
尾数位 10 7
max 65504 3.38953e+38
min 6.10351e-05 1.17549e-38
resolution 0.001 0.01

bf16 vs fp16

可以看到,bfloat16 在指数位上使用了 8 位,而 float16 只使用了 5 位,这使得 bfloat16 在表示更大范围的数值时更加灵活。但是 bfloat16 的尾数位更短,相比于 float16,它在进行数值运算时会引入更多的舍入误差。

实测

使用 qwen2.5-omni-3b 中的 audio tower 进行实测,将 nn.LayerNorm 全部配置为自定义的 LayerNorm,并观察替换前后的相似度。

1
2
3
4
5
6
7
diff.max() = tensor(0.9375, dtype=torch.bfloat16), diff.mean() = tensor(0.0204, dtype=torch.bfloat16), diff.std() = tensor(0.0277, dtype=torch.bfloat16)
cosine_sim[:10] = tensor([0.9961, 1.0000, 1.0000, 0.9961, 1.0000, 0.9883, 1.0000, 0.9766, 1.0000,
1.0000], dtype=torch.bfloat16)
l2_dist = tensor(22., dtype=torch.bfloat16)
l2_sim = tensor(0.0435, dtype=torch.bfloat16)
sqnr = tensor(26.3750, dtype=torch.bfloat16)
sqnr.item() = 26.375 dB

可以看到的是,替换后的 LayerNorm 和原始的 nn.LayerNorm 在数值上有一定的差异,最大差异为 0.9375,平均差异为 0.0204,标准差为 0.0277,接近 0.01 的数量级。余弦相似度接近 1,说明两者的方向相似,但在数值上存在一定的偏差。在部署层面,某些精度要求比较高的模型可能会受到影响。

bfloat16 的缺点是它的尾数位较短,这意味着在进行数值运算时,可能会引入更多的舍入误差。特别是在进行多次运算时,这些误差可能会累积,从而导致最终结果与预期不符。

float16 的缺点是可能会有溢出问题,尤其是在处理较大数值时。如果在不溢出的范围内,float16 的舍入误差会更小,并且可以完美地表示 bfloat16 的数值。

附录

替换代码见 replace_layernorm.py


bfloat16 精度损失(II)
http://hebangwen.github.io/2025/06/30/layernorm-part-II/
作者
何榜文
发布于
2025年6月30日
许可协议