Torch 编译优化
Torch Compilation Optimization
解决手写 CUDA kernel 成本高的问题,通过 @torch.compile 装饰器让编译器自动进行算子融合优化
子问题
1.编译器兼容的代码模式
2.动态 shape 处理
3.编译缓存与预热
4.与 CUDA Graph 的兼容性
5.不可编译算子的等价替换(如 multinomial → Gumbel-max)
6.条件分支导致的 graph break 规避(拆分编译入口)
7.残差连接与归一化的跨算子融合
各项目的解法1 solutions
Signals
横向对比
| 维度 | nano-vllm |
|---|---|
| 编译粒度 | @torch.compile 装饰 4 个 leaf module 方法,不编译整个模型 |
| 融合策略 | 编译器自动融合 memory-bound 算子,GEMM 交给 cuBLAS |
| 采样实现 | Gumbel-max trick 替代不可编译的 torch.multinomial |
| 残差处理 | add_rms_forward 将残差加法与 RMSNorm 融合为单 kernel |
| 预热机制 | warmup_model 在 CUDA Graph capture 前触发全部编译 |
| 回退策略 | enforce_eager 配置项可完全禁用编译,回退到 eager 模式 |
最佳实践
1.在计算密集的小模块上使用 @torch.compile 而非整个模型
2.保持被编译函数的输入 shape 稳定以提高缓存命中率
3.将含 if/else 的 forward 拆成多个纯算术方法分别编译
4.使用 in-place 操作(mul_, add_, div_)减少编译后 kernel 的临时显存分配
5.在 CUDA Graph capture 前执行 warmup 确保编译完成