为 Metal 应用优化机器学习性能
Optimize machine learning for Metal apps
2023年6月5日
一句话判断
PyTorch 2.0 MPS 后端进入 Beta 阶段,配合新的性能分析工具和自定义算子机制,以及自动混合精度训练支持——在 Mac 上用 GPU 跑深度学习模型的体验已经相当成熟。
这场 Session 讲了什么
GPU 软件工程师 Denis Vieriu 介绍了 Metal 机器学习 API 的 2023 年更新,重点覆盖 PyTorch、TensorFlow 和 JAX 的 Metal 加速后端,以及 MPSGraph 的新特性。
PyTorch MPS 后端:PyTorch 2.0 的 MPS 后端已进入 Beta 阶段。支持最常用的 60 个 Torch 算子,包括 grid sampler、triangular solve、topk 等。测试覆盖率大幅提升,包括梯度测试和 ModuleInfo 测试。多个知名模型已正式采用 MPS 作为 macOS 后端,包括 WhisperAI、YOLO、Stable Diffusion 等。
性能分析工具:PyTorch nightly 版本新增 MPS 算子分析支持,使用 OS signpost 在 Metal System Trace 中可视化算子执行时间、CPU-GPU 数据拷贝和 CPU 回退操作。
自定义算子:四步流程实现自定义 GPU 算子——用 Objective-C 和 Metal 实现算子、创建 Python 绑定、编译扩展、在训练脚本中导入使用。
自动混合精度(AMP):支持 FP16/BF16 混合精度训练,减少内存占用并加速训练,同时保持模型质量。
JAX GPU 加速:新增 JAX 的 Metal 加速后端,扩展了对更多 ML 框架的支持。
值得深挖的点
CPU 回退的性能陷阱:当某个算子不支持 MPS 后端时,PyTorch 会回退到 CPU 执行。这导致数据在 CPU 和 GPU 之间来回拷贝,GPU 空闲等待。Session 用 Softshrink 算子演示了这个问题——它在 Metal System Trace 中表现为 GPU 时间线上的大量空白间隙。
自定义算子的性能提升:为 Softshrink 编写自定义 Metal 内核后,消除了所有 CPU 回退和数据拷贝。模型运行效率显著提升。实现自定义算子的代码量很小——PYBIND11 绑定只需两行代码。
混合精度的数据类型支持:Metal 支持 FP32、FP16、BF16 等多种浮点格式。混合精度训练自动在计算密集的层使用低精度,在精度敏感的层保持高精度。BF16 相比 FP16 有更大的数值范围,不容易溢出。
社区贡献模式:PyTorch 生态中的开发者已经在为 MPS 后端贡献新的算子实现,包括 histogram、group_norm、signbit 等。这说明 Mac GPU 加速在社区中已经有了实质性的采纳。
代码片段
# PyTorch MPS 后端 - 性能分析
import torch
# 启用 MPS 性能分析
torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True
).start()
# 你的模型代码...
model.train()
# 结束分析
torch.profiler.profile().stop()
# 在 Metal System Trace 中查看结果
# 自定义 Metal 算子 - Python 绑定
import torch
from torch.utils.cpp_extension import load
# 编译自定义扩展
softshrink_ext = load(
name='softshrink_mps',
sources=['softshrink.mm', 'softshrink_kernel.metal'],
extra_cflags=['-ObjC++'],
extra_ldflags=['-framework', 'Metal', '-framework', 'Foundation']
)
# 在模型中使用自定义算子
model = nn.Sequential(
nn.Linear(784, 256),
softshrink_ext.SoftshrinkMPS(), # 替换原来的 CPU 版本
nn.Linear(256, 10)
).to('mps')
# 自动混合精度训练
from torch.cuda.amp import autocast, GradScaler
model = MyModel().to('mps')
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()
for data, target in dataloader:
data, target = data.to('mps'), target.to('mps')
optimizer.zero_grad()
with autocast(dtype=torch.float16): # 混合精度
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
最佳实践
- 用 Metal System Trace 定位瓶颈:先跑分析器,找到 CPU 回退和数据拷贝的位置,再有针对性地编写自定义算子。
- 优先检查模型是否已有 MPS 支持:WhisperAI、YOLO、Stable Diffusion 等模型已经官方支持 MPS 后端。用这些模型时直接
device='mps'就行。 - 混合精度训练可减少一半显存:在 M1/M2 Mac 上训练大模型时,混合精度能显著减少内存占用,让你用更大的 batch size。
- 自定义算子注意线程安全:使用 MPS 后端的 get_dispatch_queue API 确保多线程提交被序列化。
- 关注 PyTorch nightly 版本:最新的分析工具和算子支持通常先在 nightly 版本中出现。
还有什么值得关注
- MPSGraph 也在持续演进,为 CoreML 等 inference 框架提供底层支持。
- TensorFlow Metal 后端同样在更新,覆盖了更多模型。
- JAX GPU 加速是今年新增的,说明 Apple 在积极扩展 ML 框架的覆盖范围。
- Apple 在机器学习生态的投入是全方位的——从硬件(M 系列芯片的 GPU)到框架(MPSGraph)到工具(性能分析器)到上层框架集成(PyTorch/TF/JAX)。