Machine Learning & AI 进阶 1m
开始使用 MLX 进行 Apple 芯片机器学习
Get started with MLX for Apple silicon
2025年6月9日
一句话判断
MLX 不是 PyTorch 的替代品,是 Apple Silicon 上做 ML 的正确姿势——统一内存、惰性计算、函数变换,加上 Swift 和 Python 双 API,从 NumPy 到 LLM 都能跑。
这场 Session 讲了什么
MLX 是 Apple 自研的开源数组框架,专为 Apple Silicon 的统一内存架构设计。核心理念很颠覆:CPU 和 GPU 共享内存,不需要手动搬数据。
和传统框架的根本区别:传统框架是”计算跟着数据走”(数组在 CPU 内存就在 CPU 算,在 GPU 内存就在 GPU 算),MLX 是”你在操作里指定设备”(同一个数组可以同时在 CPU 和 GPU 上做不同操作)。
关键特性:
- 惰性计算:操作调用时不实际执行,先构建计算图,需要结果时才 evaluate。可以做图优化,只为用到的部分付费。
- 函数变换:
mx.grad可以对任意函数求梯度,而且可以任意组合(二阶导就是mx.grad(mx.grad(f)))。 - mlx.nn 和 mlx.optimizers:高级 API 类似 PyTorch,nn.Module 基类、Linear 层等几乎一样。
- mx.compile:把计算图融合成单个 GPU kernel,减少 memory bandwidth 和 launch 开销。
- mx.fast:RMS Norm、Scaled Dot Product Attention 等 transformer 核心操作的高性能实现。
- 量化:
mx.quantize支持 4-bit 量化,大幅降低内存占用和推理延迟。 - 分布式计算:
mx.distributed支持多机计算,通过以太网或 Thunderbolt 连接。 - Metal Kernel 自定义:可以直接写 Metal shader 集成到计算图里。
值得深挖的点
-
统一内存是 MLX 的根本优势。不需要把数据从 CPU 拷到 GPU,在 Apple Silicon 的 Mac、iPhone、iPad、Vision Pro 上都能跑。这对 on-device ML 特别关键。
-
API 和 PyTorch 几乎一样。如果你会 PyTorch,写 MLX 的 nn.Module 只有两个小差别。迁移成本极低。
-
Swift API 是一等公民。不是 Python 的包装,而是原生 Swift 实现。可以直接在 iOS/iPadOS/visionOS app 里用 MLX Swift 做推理。
-
Hugging Face 社区很活跃。MLX 社区组织里每天都有新模型上传,LLM、图像生成、语音识别都有现成的。
代码片段
MLX vs PyTorch 的 nn.Module 对比:
# MLX —— 和 PyTorch 几乎一样
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.layer1 = nn.Linear(in_dim, hidden_dim)
self.layer2 = nn.Linear(hidden_dim, out_dim)
def __call__(self, x):
x = self.layer1(x)
x = nn.relu(x) # 差异 1:函数式调用
x = self.layer2(x)
return x # 差异 2:不需要 nn.Softmax
用 mx.compile 加速:
@mx.compile # 装饰器即可,图融合成单个 kernel
def gelu(x):
return 0.5 * x * (1 + mx.tanh(mx.sqrt(2 / mx.pi) * (x + 0.044715 * x**3)))
量化模型:
model = load_model()
model = nn.quantize(model, bits=4, group_size=64)
# 单行命令,所有 Linear 层被量化
Swift 侧使用 MLX:
// 添加 MLX Swift package 到 Xcode 项目
// API 和 Python 几乎一样
import MLX
let a = MLXArray([1, 2, 3])
let b = MLXArray([4, 5, 6])
let c = a + b // 惰性计算,需要时才执行
c.eval()
最佳实践
- 用 pip install mlx 开始,先在 Python 里熟悉 API,再考虑 Swift 集成。
- 需要 on-device 推理时用 MLX Swift,API 一致,不用重新学。
- 大模型用量化,4-bit 量化能大幅降低内存占用,对生成速度也有提升。
- 用 mx.compile 包装反复调用的函数,特别是激活函数、normalization 等。
- 用 mx.fast 里的现成实现(RMS Norm、SDPA),别自己写,它们是高度优化过的。
- 多机训练用 mx.distributed + mlx.launch,支持以太网和 Thunderbolt。
还有什么值得关注
- LM Studio 已经用 MLX 做 on-device LLM 推理,生态比较成熟。
- MLX 的 C++ API 也存在,适合需要极致控制的场景。
- “Explore large language models on Apple silicon with MLX” session 详细讲了 LLM 场景。
- 所有代码都在 MIT 许可下开源在 GitHub。
机器学习