混合精度与 FP8 训练
Mixed Precision & FP8 Training
利用低精度浮点格式加速训练同时保持数值稳定性
子问题
1.动态缩放策略
2.前向/反向不同精度选择
3.torch.compile 兼容性
4.评估时精度切换
5.cuBLAS FP8 内核的内存布局要求(row-major vs column-major)
6.float64 中间精度确保 compile 与 eager 数值一致
7.meta device 零拷贝模块转换避免额外显存分配
各项目的解法1 solutions
Signals
横向对比
| 维度 | nanochat |
|---|---|
| 协议标准 | 直接调用 PyTorch 内置 torch._scaled_mm + float8 dtype,无外部依赖 |
| 能力声明 | 仅支持 tensorwise 缩放,不支持 rowwise/axiswise |
| 通信方式 | autograd.Function 封装三路 GEMM,@allow_in_graph 不透明编译 |
| 部署模式 | 单文件 ~150 行即插即用,替换 nn.Linear 零拷贝 |
| 执行适配 | module_filter_fn 过滤不适合 FP8 的层(维度 < 128 或非 16 倍数) |
| 量化策略 | 动态 tensorwise 缩放,float64 除法确保 compile/eager 一致 |
| 精度分层 | forward e4m3fn+fast_accum,backward e5m2+精确累加 |
| 评估回退 | disable_fp8 上下文管理器临时换回 nn.Linear 共享权重 |
最佳实践
1.tensorwise 缩放比 rowwise 更快且足够准确
2.forward 用 use_fast_accum=True 加速,backward 用 False 保精度
3.@allow_in_graph 让 compile 视 FP8 为不透明节点,减少编译开销
4.评估时 disable_fp8 共享权重切回 nn.Linear,零内存拷贝