大多数浏览器和
Developer App 均支持流媒体播放。
-
在 Create ML 中改进对象检测模型
在 Create ML 中训练自定义 Core ML 模型用于对象检测时,你可以把图像理解引入自己的 app。了解迁移学习如何让你能够以更少的训练数据构建更小的模型。我们还会详细介绍 Create ML 中的一些高级参数,它们能帮助你控制输入图像的训练迭代次数、批量大小和网格大小,让你更好地控制模型的精确度。 关于对象检测的介绍,请观看 WWDC 2019 年度的视频《在 Create ML 中训练对象检测》。
资源
相关视频
WWDC19
-
下载
(在 Create ML 中改进对象检测模型) 大家好 我叫 Shreya Jain 是 Create ML 团队的工程师 今天我们要来看看 对象检测模板中的一些增强性能 并利用它们来创建更佳的模型 如果你们对 Create ML 中的对象检测 还不太熟悉 我建议你们看看 WWDC 2019 中的这段视频 (在 Create ML 中训练对象检测) 对象检测能够实现一些浸入式 app 体验 你们可以构建一款 app 来帮助人们进行垃圾分类 给宠物猫试戴虚拟眼镜
甚至是一款可根据检测到的食材 推荐食谱的 app 为这款 app 构建模型 是观察 Create ML 及其新功能 运行情况的好办法
我们为对象检测做了一些重要改进 现在你们可以用更少的训练数据 来训练精确和更小的模型 还可以列出更多的配置选项来自定义训练
我们开始吧
首先 我要从聚焦中打开 Create ML app 首先看到的是模板挑选器 我在这里选择对象检测 这会打开一个对话框 用于输入 关于 Create ML 项目的详细信息 我把这个项目命名为 “FindMyRecipe” 添加一条有助于检测到食材的描述 在创建该项目之前 我可以选择改变它的位置 接下来就到了设置标签页面 训练之前可以在这里调整数据和配置选项 加载数据之前 我给你们展示一下完整的数据准备过程
对象检测数据必须存储在一个文件夹里 这个文件夹要包含所有的训练图像 和 JSON 文件里的注释 我们可以用下面的图像为例 对 annotations.json 中的内容 加以理解 它包括两个对象: 一片面包和一个西红柿 每个对象注释都由该对象的标签 及其在图像中的定位组成 定位是以图像左上角作为参照的
训练数据图像里的所有对象 都可以用这种方式进行注释
所有这些注释都以这种格式 被添加到一个单一 JSON 文件中
我会使用该信息准备我的训练数据
在数据准备好之后 我可以在 Create ML app 里加载它
点击视图按钮会显示数据集的类分布 你们可以看到 我的类是 西红柿、奶酪、面包和罗勒 回到设置标签 这里可以有选择地提供验证数据 确保模型在不可见数据上顺利运行 我在这里把验证数据设置为自动 让 Create ML 使用一小部分我的数据集
这里还有新的训练参数 能让你更好地控制模型的训练 它们分别是算法、迭代次数 批量大小和网格大小
训练有两种算法
第一种是 full 网络 让我们更深入地看一下 full 网络
full 网络在 2019 年 被引入 Create ML 从那时起一直是默认训练算法
这种算法是建立在 YOLOv2 构架上的
该网络的所有参数 都使用了你的数据来进行训练
生成的 Core ML 模型 会把所有学到的参数进行编码
该 Core ML 模型已经被量化 来存储 16 位精度的权重 生成的模型大小 是我们之前得到模型的一半 所以一个之前大概 65 兆字节大小的模型 现在只有 33 兆字节
当你有大量的训练数据 比如每个类超过 200 个边界框时 推荐使用这种算法 生成的模型是向后兼容的 可以一直向后追溯到 iOS 12 我们希望能让你们用更少的训练数据 构建出高度精确的模型 所以我们为对象检测 引入了迁移学习算法
迁移学习会利用操作系统中 已有的机器学习模型 比如 Photos app 里 就有支持查询和回忆的模型
Photos 使用的预训练 backbone 之一 名为对象打印 它的训练使用了海量的不同数据 借助迁移学习 你们可以利用它来减少数据需求
Create ML 中的迁移学习算法 使用了对象打印 和 head 网络 只有 head 网络的训练使用了你的数据 减少了需要学习的参数数量
因此 Core ML 模型 只包含 head 网络参数 这让你的模型比 full 网络小五倍 在 2019 年同样的模型有 65 兆字节大 量化之后是 33 兆字节 使用迁移学习算法将只有 7 兆字节
在你只有有限的数据 而且想要一个轻量级模型时 迁移学习算法是一个很好的选择 哪怕每个类只有 80 个训练样本 它也能做得很好 生成的模型要求安装 iOS 14 才能使用 OS 附带的对象打印
算法只是新配置之一 同时增加的还有像迭代次数 和批量大小这样的参数
迭代次数指的是模型参数被更新的次数 根据数据集大小选出一个默认值 为了满足特定用例 如果模型尚未收敛 你可以增加迭代次数 或者如果模型开始阶段运行良好 可以减少它们的次数
批量大小指的是一次迭代中 使用的训练样本数量 默认值是根据硬件限制选择的 尽管批量大小的数值越高越好 但你可能只想使用默认值 或根据性能限制减少该数值
最后对于 full 网络 你可以自定义网格大小 理解网格大小这一概念需要知道 预测是如何为 full 网络工作的 让我们更深入地看一下
从这幅图像开始…
把它传递给一个 已经过训练的 full 网络模型…
生成一些带有边界框的预测对象
为了找到图像中的对象 模型使用了一个网格和一组锚框 指定的网格决定了输入图像的高宽比 以及模型要去哪找检测对象 举个例子 让我们看看模型 在 5 × 5 网格维数时如何工作
图像的大小会被重新调整以适应网格… 在本例中是一个正方形图像… 然后分隔成预定数量的单元
然后网络会生成预测 每个网格单元一个
每个预测都包括了下列信息: 单元内是否含有对象 该对象的类及其边界框 在每个对象都与一个网格单元关联时 YOLO 能够很好地处理多个对象 你们可以从这幅图像中看到 香蕉和小狗的中心落进了同一个单元里 由于每个单元只能预测一个类 所以只能要么选择香蕉 要么选择小狗
为了能够同时预测香蕉和小狗 就需要定义锚点框 锚点框具有指定的高宽比 可以在一个网格单元中检测多个对象
Create ML 使用了 13 × 13 的默认网格维数 一共有 169 个单元 对每个单元会固定评估一组 15 个 高宽比不同的锚点框 因此 默认模型每幅图像 会做出总共 2535 个预测
来看看这副骰子图像 想想如果采用 3 × 3 网格维数 对象检测会如何工作
由于一个单元里 出现了多个高宽比相近的骰子 只会对其中一个进行检测 (网格大小 9 × 9) 网格大小的数值越高 检测到的骰子就越多
但这样会增加每幅图像的预测数量 在改变网格大小时还要重点考虑 计算成本问题
对于这样一幅非正方形图像 其维数是 1500 × 800 对这副图像使用 8 × 8 网格 会导致信息丢失 以及对象的自然形状发生扭曲 这会妨碍模型在训练中 捕捉到更多的细粒度模式 并干扰其预测能力
选择 15 × 8 网格尺寸 能保留原始输入图像的高宽比 生成一个学习了更多信息的模型 取得更好的结果
再回过头来看看 FindMyRecipe 项目的模型训练 我可以选择迁移学习算法 迭代次数设置成 1000 批量大小设置成自动
点击播放按钮 模型就开始进行训练 训练标签会显示批量正在准备中 这一步会执行一组标准图像增强 帮助提高鲁棒性 和针对真实数据的泛化能力 很快就出现了一张图表 绘制出了每次迭代的损失值
随着训练的进行 可以点击 snapshot 按钮 获取一个当时的模型 snapshot 可以帮助检查训练进展
我可以利用这个模型 来预览针对一些图像的预测
每一幅图像的 模型预测都显示在预览标签中 我可以点击这些边界框 在底部查看每个类的置信值 还可以使用一个 snapshot 在 app 内进行试验
训练完成后 训练评估指标和验证数据 会显示在评估标签中 这些数字是什么意思?
对象检测模型的评估需要两层 我们想要的不仅是正确的标签 而且它们必须位于合适的位置 让边界框完全匹配注释框很困难 这时候就需要一个能捕获 预测框有多么 靠近注释框的数字
这种测量需要用到一个 SCORE 函数 名叫 intersection-over-union 它的值位于 0% 到 100% 之间…
0 代表没有重叠 100 代表完全重叠
一个被认为是正确的预测 需要具有正确的类标签 和大于一个预定义阈值的 intersection-over-union SCORE 函数
如果 intersection-over-union SCORE 函数小于阈值 或预测类不正确 整体预测就会不正确 这条信息被用来计算一个 名为平均精度均值 也叫 mAP 的指标
现在我要回到评估标签去看看这些数字
它们代表每个类用两个阈值 计算出来的的平均精度 一个被固定为 50% 另一个会随着不同的阈值变化 我们数据集的整体平均精度均值 显示在右上角 mAP 值越高 说明正确的预测越多
我们模型的 mAP 总体看来不错 我会在一些样本上预览该模型 确保模型预测看上去一切正常
看上去一切都很好 我现在可以把这个模型 放进我们的 app 里了
借助你们刚刚看到的扩充功能 利用 Create ML 创建对象检测模型非常简单
Create ML 可以通过对训练 提供更多的控制 帮助你自定义模型 它可以帮助你用较少的数据 和更小的输出尺寸构建精确的模型 我们非常期待你们利用这些全新的功能 呈现精彩的想法
-
-
正在查找特定内容?在上方输入一个主题,就能直接跳转到相应的精彩内容。
提交你查询的内容时出现错误。请检查互联网连接,然后再试一次。