问题域/PD-427

混合精度与 FP8 训练

Mixed Precision & FP8 Training

利用低精度浮点格式加速训练同时保持数值稳定性

子问题

1.动态缩放策略

2.前向/反向不同精度选择

3.torch.compile 兼容性

4.评估时精度切换

5.cuBLAS FP8 内核的内存布局要求(row-major vs column-major)

6.float64 中间精度确保 compile 与 eager 数值一致

7.meta device 零拷贝模块转换避免额外显存分配

各项目的解法1 solutions

Signals

横向对比

维度nanochat
协议标准直接调用 PyTorch 内置 torch._scaled_mm + float8 dtype,无外部依赖
能力声明仅支持 tensorwise 缩放,不支持 rowwise/axiswise
通信方式autograd.Function 封装三路 GEMM,@allow_in_graph 不透明编译
部署模式单文件 ~150 行即插即用,替换 nn.Linear 零拷贝
执行适配module_filter_fn 过滤不适合 FP8 的层(维度 < 128 或非 16 倍数)
量化策略动态 tensorwise 缩放,float64 除法确保 compile/eager 一致
精度分层forward e4m3fn+fast_accum,backward e5m2+精确累加
评估回退disable_fp8 上下文管理器临时换回 nn.Linear 共享权重

最佳实践

1.tensorwise 缩放比 rowwise 更快且足够准确

2.forward 用 use_fast_accum=True 加速,backward 用 False 保精度

3.@allow_in_graph 让 compile 视 FP8 为不透明节点,减少编译开销

4.评估时 disable_fp8 共享权重切回 nn.Linear,零内存拷贝