Train your machine learning and AI models on Apple GPUs
Machine Learning & AI 进阶 20m

在 Apple GPU 上训练机器学习与 AI 模型

Train your machine learning and AI models on Apple GPUs

2024年6月10日

在 Apple 官方观看视频

一句话判断

PyTorch MPS backend 今年在三个方面有实质性提升——int8/int4 量化、fused SDPA、统一内存支持——让你可以在 Mac 上用更大的模型做更快的训练,而不需要把 tensor 在 CPU 和 GPU 之间来回拷贝。

这场 Session 讲了什么

Apple Silicon 的 GPU 在机器学习训练上有天然优势:强大的并行计算能力、统一内存架构让 GPU 直接访问大量内存、不需要在多台机器间分布式部署模型权重。今年 Apple 对 PyTorch MPS backend 和 JAX Metal backend 都做了升级,但重点是 PyTorch。

目前 MPS backend 已经能在 HuggingFace 的 Top-50 最热门 transformer 模型上开箱即用,覆盖了 Stable Diffusion、Meta LLaMA、Gemma 等主流模型。Session 重点介绍了三项 PyTorch 改进:8-bit 和 4-bit 整数量化让大模型能塞进设备内存;fused scaled dot product attention 把多头注意力的一系列矩阵运算融合成单次 GPU kernel 调用;统一内存支持消除了 CPU 和 GPU 之间的冗余 tensor 拷贝。

Session 最后演示了一个端到端的工作流:下载 OpenLLaMA v2 3B 模型,用 LoRA adapter 做微调,在 MPS 设备上运行——整个流程只需要几十行代码。

值得深挖的点

统一内存:被低估的训练加速器

大多数 ML 框架在传统 GPU 架构上运行时,需要把 tensor 数据从 CPU 内存拷贝到 GPU 显存,计算完再拷回来。这个拷贝过程既浪费时间又浪费内存——同一份数据在系统中存在两份。Apple Silicon 的统一内存架构天生就不需要这个拷贝步骤,但直到 PyTorch MPS backend 显式支持这一点,开发者才能真正受益。

统一内存支持的直接影响是:训练时的内存占用大幅下降。对于一个 3B 参数的模型,如果训练过程中不需要维护 CPU 和 GPU 两份权重副本,你能用同样的设备跑更大的 batch size 或者更大的模型。更大的 batch size 通常意味着更快的收敛,这是一个正反馈循环。

量化不是推理专属——训练也能用

8-bit 和 4-bit 量化通常被认为只用于推理阶段,但 PyTorch MPS backend 的支持意味着你在训练流程中也能利用量化。具体来说,当你用 LoRA 或类似技术微调一个预训练模型时,基础模型的权重可以用 int8 甚至 int4 存储,只把 LoRA adapter 的参数用 fp16/bf16 保持精度。这样做的直接好处是:原本需要 24GB 内存才能加载的模型,可能 12GB 甚至 8GB 就能跑起来。

Trade-off 也很明确:量化会带来精度损失,虽然 Apple 声称”取决于模型,可能只有很小的甚至没有输出质量下降”,但这需要你针对自己的具体模型和任务做验证。对于 LoRA 微调这种场景,因为基础模型的权重是冻结的,量化带来的影响通常比全量训练小得多。

代码片段

在 MPS 设备上微调语言模型

用 OpenLLaMA v2 + LoRA 在 Mac 上做微调的完整流程。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

# 设置随机种子保证可复现
torch.manual_seed(42)

# 下载模型和分词器
model_name = "openlm-research/open_llama_3b_v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 配置 LoRA adapter
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                    # LoRA 秩
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]  # 只对 attention 做 LoRA
)

# 将 adapter 附加到模型上
model = get_peft_model(model, lora_config)

# 发送到 MPS 设备(关键步骤)
device = torch.device("mps")
model = model.to(device)

print(f"可训练参数量: {model.print_trainable_parameters()}")

坑:确保 PyTorch 版本 >= 2.1,旧版本的 MPS backend 支持不完整。设置 device = "mps" 后所有 tensor 操作都会在 GPU 上执行。

配置训练参数并启动微调

from transformers import Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling

# 加载训练数据(以 Tiny Shakespeare 数据集为例)
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="tiny_shakespeare.txt",
    block_size=128
)

# 数据收集器负责组装训练 batch
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # 因果语言模型不做 masked LM
)

# 训练参数
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    # 不需要指定 device,Trainer 会自动检测 MPS
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()

快速切换到 MPS 设备

# 最简单的启用方式:设置默认设备
import torch
torch.set_default_device("mps")

# 之后创建的所有 tensor 都会默认放在 MPS 上
x = torch.randn(3, 3)  # 自动在 MPS 上

坑:部分 numpy 操作返回的 tensor 仍在 CPU 上,需要手动 .to("mps")。另外,MPS backend 对某些算子的支持还在完善中,遇到不支持的操作会 fallback 到 CPU,此时会打印 warning。

最佳实践

如果你已经在用 PyTorch 做训练: 切换成本极低。把 device = "cuda" 改成 device = "mps" 就能跑起来。先用小模型验证流程正确性,再尝试大模型。如果你的训练脚本已经在用 torch.set_default_device(),只需要改一行。

新项目: 如果你的目标是在 Apple 设备上部署模型(Core ML),在 Mac 上训练是最高效的路径——省去了跨平台转换中可能遇到的问题。MLX 框架值得关注,它专为 Apple Silicon 设计,API 风格类似 NumPy,学习曲线平缓。

模型选型建议: 对于 LoRA 微调,3B 参数量的模型(如 OpenLLaMA 3B)在 16GB 内存的 Mac 上可以流畅运行。如果需要更大的模型(7B+),开启 int8 量化是必要的。先验证 baseline 精度,再开量化,对比精度差异。

还有什么值得关注

  • JAX Metal backend 也是今年的更新之一,支持 JIT 编译和类 NumPy 接口,适合已经在用 JAX 的团队
  • MLX 框架提供了 Python、Swift、C、C++ 多语言绑定,如果需要在 app 内直接集成训练逻辑可以考虑
  • HuggingFace Top-50 模型的 MPS 加速覆盖意味着大部分主流 LLM 和 Diffusion 模型已经可以直接在 Mac 上跑
WWDC 2024