Optimize machine learning for Metal apps
Machine Learning & AI 进阶 20m

为 Metal 应用优化机器学习性能

Optimize machine learning for Metal apps

2023年6月5日

在 Apple 官方观看视频

一句话判断

PyTorch 2.0 MPS 后端进入 Beta 阶段,配合新的性能分析工具和自定义算子机制,以及自动混合精度训练支持——在 Mac 上用 GPU 跑深度学习模型的体验已经相当成熟。

这场 Session 讲了什么

GPU 软件工程师 Denis Vieriu 介绍了 Metal 机器学习 API 的 2023 年更新,重点覆盖 PyTorch、TensorFlow 和 JAX 的 Metal 加速后端,以及 MPSGraph 的新特性。

PyTorch MPS 后端:PyTorch 2.0 的 MPS 后端已进入 Beta 阶段。支持最常用的 60 个 Torch 算子,包括 grid sampler、triangular solve、topk 等。测试覆盖率大幅提升,包括梯度测试和 ModuleInfo 测试。多个知名模型已正式采用 MPS 作为 macOS 后端,包括 WhisperAI、YOLO、Stable Diffusion 等。

性能分析工具:PyTorch nightly 版本新增 MPS 算子分析支持,使用 OS signpost 在 Metal System Trace 中可视化算子执行时间、CPU-GPU 数据拷贝和 CPU 回退操作。

自定义算子:四步流程实现自定义 GPU 算子——用 Objective-C 和 Metal 实现算子、创建 Python 绑定、编译扩展、在训练脚本中导入使用。

自动混合精度(AMP):支持 FP16/BF16 混合精度训练,减少内存占用并加速训练,同时保持模型质量。

JAX GPU 加速:新增 JAX 的 Metal 加速后端,扩展了对更多 ML 框架的支持。

值得深挖的点

CPU 回退的性能陷阱:当某个算子不支持 MPS 后端时,PyTorch 会回退到 CPU 执行。这导致数据在 CPU 和 GPU 之间来回拷贝,GPU 空闲等待。Session 用 Softshrink 算子演示了这个问题——它在 Metal System Trace 中表现为 GPU 时间线上的大量空白间隙。

自定义算子的性能提升:为 Softshrink 编写自定义 Metal 内核后,消除了所有 CPU 回退和数据拷贝。模型运行效率显著提升。实现自定义算子的代码量很小——PYBIND11 绑定只需两行代码。

混合精度的数据类型支持:Metal 支持 FP32、FP16、BF16 等多种浮点格式。混合精度训练自动在计算密集的层使用低精度,在精度敏感的层保持高精度。BF16 相比 FP16 有更大的数值范围,不容易溢出。

社区贡献模式:PyTorch 生态中的开发者已经在为 MPS 后端贡献新的算子实现,包括 histogram、group_norm、signbit 等。这说明 Mac GPU 加速在社区中已经有了实质性的采纳。

代码片段

# PyTorch MPS 后端 - 性能分析
import torch

# 启用 MPS 性能分析
torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True
).start()

# 你的模型代码...
model.train()

# 结束分析
torch.profiler.profile().stop()
# 在 Metal System Trace 中查看结果
# 自定义 Metal 算子 - Python 绑定
import torch
from torch.utils.cpp_extension import load

# 编译自定义扩展
softshrink_ext = load(
    name='softshrink_mps',
    sources=['softshrink.mm', 'softshrink_kernel.metal'],
    extra_cflags=['-ObjC++'],
    extra_ldflags=['-framework', 'Metal', '-framework', 'Foundation']
)

# 在模型中使用自定义算子
model = nn.Sequential(
    nn.Linear(784, 256),
    softshrink_ext.SoftshrinkMPS(),  # 替换原来的 CPU 版本
    nn.Linear(256, 10)
).to('mps')
# 自动混合精度训练
from torch.cuda.amp import autocast, GradScaler

model = MyModel().to('mps')
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

for data, target in dataloader:
    data, target = data.to('mps'), target.to('mps')
    optimizer.zero_grad()

    with autocast(dtype=torch.float16):  # 混合精度
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

最佳实践

  • 用 Metal System Trace 定位瓶颈:先跑分析器,找到 CPU 回退和数据拷贝的位置,再有针对性地编写自定义算子。
  • 优先检查模型是否已有 MPS 支持:WhisperAI、YOLO、Stable Diffusion 等模型已经官方支持 MPS 后端。用这些模型时直接 device='mps' 就行。
  • 混合精度训练可减少一半显存:在 M1/M2 Mac 上训练大模型时,混合精度能显著减少内存占用,让你用更大的 batch size。
  • 自定义算子注意线程安全:使用 MPS 后端的 get_dispatch_queue API 确保多线程提交被序列化。
  • 关注 PyTorch nightly 版本:最新的分析工具和算子支持通常先在 nightly 版本中出现。

还有什么值得关注

  • MPSGraph 也在持续演进,为 CoreML 等 inference 框架提供底层支持。
  • TensorFlow Metal 后端同样在更新,覆盖了更多模型。
  • JAX GPU 加速是今年新增的,说明 Apple 在积极扩展 ML 框架的覆盖范围。
  • Apple 在机器学习生态的投入是全方位的——从硬件(M 系列芯片的 GPU)到框架(MPSGraph)到工具(性能分析器)到上层框架集成(PyTorch/TF/JAX)。
WWDC 2023