Get to know Create ML Components
Machine Learning & AI 进阶 20m

Create ML Components 初探:用组合构建自定义 ML 任务

Get to know Create ML Components

2022年6月6日

在 Apple 官方观看视频

一句话判断

Create ML Components 把预定义的 ML 任务拆成了可组合的零件——你终于可以不依赖 Create ML app 就构建图像回归、表格回归等自定义 pipeline,而且训练和推理用同一套 Swift API。

这场 Session 讲了什么

Create ML 提供了一系列预定义的 ML 任务:图像分类、声音分类、动作分类等。但如果你想做的任务不在列表里怎么办?比如给香蕉的成熟度打分(图像回归),或者预测牛油果的价格(表格回归)?

Create ML Components 就是为此而生的。它把 ML 任务拆成了两种基本单元:Transformer(对数据做变换)和 Estimator(从数据中学习并产生 Transformer)。你用 appending 方法把组件串起来,就像搭积木一样构建自定义任务。

Session 用两个完整 demo 展示了这个框架的威力。第一个是图像回归器:给香蕉照片的成熟度打 1-10 分。Feature extractor + Linear Regressor,几行代码就搞定。然后通过数据增强(随机旋转和缩放)和 Vision 框架的显著性裁剪来提升模型精度。第二个是表格回归器:预测牛油果价格。用 ColumnSelector 选出需要标准化的数值列,用 StandardScaler 做归一化,再接上 BoostedTreeRegressor。

值得深挖的点

你的模型就是你的代码。 Create ML Components 的一个独特设计是,训练好的模型不是一个独立的文件,而是代码 + 参数。你需要保存 task definition 和 trained parameters,推理时两者缺一不可。这和 Core ML 的「一个 .mlmodel 文件走天下」很不一样。好处是灵活性极高——你可以在 task definition 里嵌入自定义 transformer(比如显著性裁剪),推理时自动执行。代价是部署时需要更多代码。

不过你也可以选择导出为 Core ML 模型,这样就只需要一个文件。但 Session 明确说了,custom transformers 和 custom estimators 不支持 Core ML 导出。所以如果你的 pipeline 包含自定义组件,就只能用 Create ML Components 的原生方式部署。

数据加载的便利性。 AnnotatedFiles 类型让从文件名中提取标注变得非常简单——你只需要告诉它分隔符和标注所在的位置。然后 mapFeaturesmapAnnotations 做类型转换,randomSplit 做训练/验证集划分。整个数据准备流程非常流畅。

代码片段

构建图像回归器:

import CreateMLComponents
import CoreImage

struct BananaRipenessRegressor {
    // 组合 estimator:特征提取 + 线性回归
    static let estimator = ImageFeaturePrint()
        .appending(LinearRegressor())
    
    static func train() async throws -> some Transformer<CIImage, Float> {
        // 从文件名加载标注(banana-5.jpg → 标注 "5")
        let data = try AnnotatedFiles(
            labeledByNamesAt: trainingURL,
            separator: "-", index: 1, type: .image
        )
        .mapFeatures(ImageReader.read)        // URL → CIImage
        .mapAnnotations({ Float($0)! })       // String → Float
        
        let (training, validation) = data.randomSplit(by: 0.8)
        let model = try await estimator.fitted(to: training, validateOn: validation)
        try estimator.write(model, to: parametersURL)
        return model
    }
}

添加自定义 Transformer 提升精度:

// 自定义显著性裁剪 transformer
struct SaliencyCropper: Transformer {
    // 遵循 Transformer 协议只需实现 applied 方法
    func applied(to image: CIImage) async throws -> CIImage {
        // 用 Vision 框架找到最显著的物体并裁剪
        // 如果找不到,返回原图
        guard let salientRect = findSalientObject(in: image) else {
            return image
        }
        return image.cropped(to: salientRect)
    }
}

// 把自定义 transformer 插入 pipeline
static let estimator = SaliencyCropper()           // 先裁剪
    .appending(ImageFeaturePrint())                 // 再提取特征
    .appending(LinearRegressor())                   // 最后回归

// 注意:自定义 transformer 会同时用于训练和推理

表格回归预测牛油果价格:

// 用 ColumnSelector 处理特定列
let volumeNormalizer = ColumnSelector(
    columns: ["volume"],
    estimator: OptionalUnwrapper()
        .appending(StandardScaler<Double>())
)

// 数值归一化 + 提升树回归
let task = volumeNormalizer.appending(
    BoostedTreeRegressor<String>(
        annotationColumnName: "price",
        featureColumnNames: ["type", "region", "volume"]
    )
)

// 推理:构造只包含所需列的 DataFrame
func predict(type: String, region: String, volume: Double) async throws -> Double {
    let model = try task.read(from: parametersURL)
    let df: DataFrame = ["type": [type], "region": [region], "volume": [volume]]
    let result = try await model(df)
    return result[ColumnID("price", Double.self)][0]!
}

最佳实践

  • AnnotatedFileslabeledByNamesAt 从文件名或目录结构自动提取标注。
  • randomSplit(by:) 划分训练集和验证集,0.8 是个不错的起点。
  • fitted 方法中传入 event handler 监控训练过程中的验证指标。
  • 数据不够时,用 flatMap 做数据增强(旋转、缩放),注意只在训练集上增强。
  • 部署方式二选一:导出 Core ML 模型(不支持自定义组件),或者打包 Swift package(支持所有组件)。

还有什么值得关注

  • 分类器有 LogisticRegressionClassifier 和 FullyConnectedNetworkClassifier 可选,后者适合更复杂的分类边界。
  • One-hot encoding 适合类别少的列,Ordinal encoding 适合类别多的列。
  • Create ML Components 同时支持 macOS、iOS、iPadOS 和 tvOS 上的训练与推理。
WWDC 2022