Discover machine learning enhancements in Create ML
Machine Learning & AI 进阶 20m

探索 Create ML 的机器学习增强

Discover machine learning enhancements in Create ML

2023年6月5日

在 Apple 官方观看视频

一句话判断

Create ML 今年带来了三件重要的事:BERT 多语言文本分类器、多标签图像分类器(Multi-Label Image Classifier)、以及数据增强 API——用更少的数据训练出更好的模型。

这场 Session 讲了什么

Create ML 团队的 David Findlay 介绍了 Create ML 框架和应用的 2023 年更新。

文本分类改进:新的 BERT 嵌入模型在数十亿标注文本上预训练,支持多语言文本分类。这意味着你的训练数据可以包含多种语言,且单语言分类器的准确率也有所提升。BERT 模型支持 iOS 17、iPadOS 17 和 macOS Sonoma。

图像分类改进:最新版本的 Apple Neural Scene Analyzer 可作为特征提取器。新的提取器输出嵌入尺寸更小,带来更快的训练速度、更高的准确率和更低的内存占用。

全新的多标签图像分类器:传统图像分类每张图片只能有一个标签。多标签分类器可以为一张图片预测多个标签——比如一张图片同时包含”狗”、“玩具”、“草地”和”公园”。Session 用一个多肉植物分类器做了完整演示,包括数据准备、训练、评估和预览。

数据增强 API:当训练数据有限时,增强 API 可以通过旋转、翻转、裁剪等变换生成额外的训练样本,提升模型质量。

值得深挖的点

多标签 vs 单标签 vs 目标检测:单标签分类为每张图片选一个最佳标签。目标检测定位图片中的物体(画边界框)。多标签分类预测一组标签,不需要画框——适合描述场景属性(如”户外”、“公园”)和多个共存的物体。

MAP(平均精度均值)评估指标:多标签分类用 MAP 而不是准确率来评估。MAP 同时考虑精确率和召回率,对所有标签取平均。Session 中的多肉分类器在训练集上达到 97% MAP,验证集 93% MAP。

置信度阈值:多标签分类器为每个标签维护独立的置信度阈值。预测时,只有置信度超过阈值的标签才会被输出。Create ML 应用的 Metrics 标签页展示了每个标签的阈值、精确率、召回率和误报/漏报情况。

数据增强的实际效果:增强不是万能的——关键是选择与你的应用场景匹配的变换。比如识别植物种类时,水平翻转是合理的增强,但垂直翻转就不一定了。

代码片段

// 多标签图像分类器训练
import CreateML

let trainingData = try MLDataTable(contentsOf: trainingURL)
// JSON 格式: {"image": "path/to/image.jpg", "labels": ["aloe", "cactus", "indoors"]}

let classifier = try MLImageClassifier(
    trainingData: trainingData,
    modelParameters: MLImageClassifier.ModelParameters(
        validation: .split(strategy: .automatic),
        augmentation: [
            .flip(horizontal: true),  // 水平翻转
            .rotate(angle: .init(degrees: 15)),  // 随机旋转
            .crop(size: CGSize(width: 224, height: 224))  // 随机裁剪
        ]
    )
)

// 导出 Core ML 模型
let metadata = MLModelMetadata(
    author: "Your Name",
    shortDescription: "多肉植物多标签分类器",
    version: "1.0"
)
try classifier.write(to: modelURL, metadata: metadata)
// 使用训练好的多标签分类模型
let model = try SucculentClassifier(configuration: MLModelConfiguration())

let prediction = try model.prediction(image: inputImage)
// prediction 包含多个标签和各自的置信度
// 例如: ["aloe": 0.9, "cactus": 0.7, "indoors": 0.85]
// 使用 BERT 多语言文本分类
import CreateML

let trainingData = try MLDataTable(contentsOf: jsonURL)
// 数据格式: {"text": "文章内容", "label": "sports"}

let classifier = try MLTextClassifier(
    trainingData: trainingData,
    modelParameters: MLTextClassifier.ModelParameters(
        validation: .split(strategy: .automatic),
        algorithm: .transferLearning(
            featureExtractor: .language),  // 使用新的 BERT 嵌入
        language: .automatic  // 自动检测语言
    )
)

最佳实践

  • 多标签分类用 JSON 标注:每张图片标注为一个标签数组。可以混合单标签和多标签的样本。
  • 检查 Metrics 标签页的每个类别:关注精确率和召回率低的类别,分析误报和漏报样本,针对性补充训练数据。
  • 谨慎选择数据增强变换:只使用对你的应用场景有意义的变换。不当的增强可能引入噪声。
  • 利用 Preview 标签页测试:用训练集之外的图片测试模型,验证泛化能力。
  • BERT 模型考虑部署平台:BERT 需要 iOS 17+,如果需要支持旧系统,考虑使用动态嵌入模型。

还有什么值得关注

  • Create ML 应用现在可以直接预览多标签分类结果,包括每个标签的置信度。
  • Apple Neural Scene Analyzer 的更新不仅在 Create ML 中可用,也改善了系统级图像理解能力。
  • 数据增强 API 不仅支持图像变换,框架设计上为其他模态的增强留了空间。
  • 多标签分类器对遮挡场景的处理有限——Session 中一个被其他植物遮挡的多肉品种就没被正确识别。
WWDC 2023