大多数浏览器和
Developer App 均支持流媒体播放。
-
基于 Apple GPU 训练机器学习和 AI 模型
了解如何借助适用于 PyTorch、JAX 和 TensorFlow 的 Metal 工具,在 Apple 芯片上训练模型。充分利用新的注意力操作和量化支持,提升设备端 Transformer 模型性能。
章节
- 0:00 - Introduction
- 1:36 - Training frameworks on Apple silicon
- 4:16 - PyTorch improvements
- 11:26 - ExecuTorch
- 13:19 - JAX features
资源
相关视频
WWDC24
WWDC23
WWDC22
WWDC21
-
下载
大家好 我叫 Yona Havocainen 是 GPU, Graphics and Display Software 团队的软件工程师 今天 我将讲解如何 基于 Apple 芯片 GPU 训练机器学习和 AI 模型 并介绍今年推出的新功能
Apple 芯片提供了 许多令人惊叹的功能 用于在设备上实现机器学习 强大的 GPU 能够出色地 完成优化现代神经网络 所需的各类计算
GPU 与统一内存架构相结合 可以直接访问大量内存
借助大量内存 你可以在设备上 本地训练和运行大型模型
或者在训练期间增加批次大小 这通常可以加快收敛速度
另外 由于无需在多台 机器上分配模型权重 从训练到部署的过程变得更加简单
训练是在 Apple 平台上 部署模型的第一步 模型完成训练后 必须做好 在设备上进行部署的准备
准备就绪后 就可以 将模型集成到应用程序中了
如果你还没有看过 关于部署机器学习模型的 整体流程的讲座 请观看有关 Apple 设备上 机器学习工作流程的视频
在这个讲座中 我将重点介绍训练 并展示一些能够 利用 Apple 芯片的 独特计算能力的框架
要使用功能强大的 GPU 你可以在某个热门机器学习 框架中使用 Metal 后端 这些框架包括 TensorFlow、 PyTorch、JAX 和 MLX
TensorFlow 是适用于许多 行业应用程序的可靠框架
Metal 后端支持各种功能 例如适用于大型项目的 分布式训练 以及可提高训练性能的混合精度 为 TensorFlow 启用 Metal 后端变得比以往更轻松 只需使用 Pip 等软件包管理器 安装 TensorFlow 然后将它导入项目中
请观看 WWDC21 上的相关视频 进一步了解 TensorFlow Metal 后端
另一个广泛使用的 框架是 PyTorch Metal 后端支持 自定运算和剖析等功能 可轻松针对网络性能 进行基准评测和改进 在 PyTorch 中开始使用 Metal 后端也很简单 只需安装 PyTorch 并将它导入项目 然后将默认设备设置为 mps
如需进一步了解 PyTorch、Metal 后端 请观看 WWDC22 上的相关视频
JAX 是最近新增的通过 Metal 后端支持的框架
它支持的功能包括即时编译 以及用于抹掉数据的 类似 Numpy 的界面等
只需安装 JAX Metal 并将 JAX 导入项目中 即可开始使用 JAX Metal 后端
WWDC23 上的相关视频详细介绍了 适用于 JAX 的 Metal 后端
MLX 是通过 Metal 后端 支持的最新框架
MLX 专门针对 Apple 芯片 进行了设计和优化 它以原生方式支持 类似 Numpy 的 API、 即时编译、 分布式训练和统一内存等功能
这个框架支持 Python、 Swift、C 和 C++ 绑定
你可以在我们的代码 存储库中找到运行各种 机器学习任务的示例 例如微调转换器模型、 图像生成和音频转写等
开始使用 MLX 和使用其他框架一样简单 只需在 Python 环境中 安装 wheel 并将它导入项目中
如需进一步了解 MLX 框架 请查看我们的文档和代码存储库
现在我们有了用于训练的 Apple 芯片主要工具 接下来我将开始介绍今天的主题 我想展示一些特别针对 其中两个框架推出的 新功能和改进功能 这两个框架是 PyTorch 和 JAX
我们先来讨论 PyTorch
在一年前的 WWDC23 上 MPS 后端开发推进到了 Beta 阶段
从那时起 我们开始逐步 增加支持 包括支持自定内核、 扩大运算覆盖范围
以及实现统一内存架构 此外 性能和功能方面 也进行了大量改进和修复 其中许多内容都是由 PyTorch 相关开源社区贡献的
在这一年中 我们还扩大了 各种网络的覆盖范围 例如 在用于存储先进转换器模型的 HuggingFace 存储库中 PyTorch-MPS 后端 现在一经安装就能提高 排名前 50 的热门网络的网速 其中涵盖许多在这一年中 变得广受欢迎的模型 例如 Stable diffusion、 Meta LlaMA 模型和 Gemma
说到改进 我想特别强调三项对转换器模型 产生影响的改进功能 第一项是支持 8 位 和 4 位整数量化 可在设备内存中 运行更大规模的模型
第二项是融合了缩放点积注意力 可提高许多常见模型的性能
第三项是统一内存支持 可在将计算调度到 GPU 时 移除冗余的张量副本
接下来 我将逐一 详细说明这些主题 32 位浮点数或这里显示的 16 位浮点数等数据格式 是训练模型时的常见格式 1 位表示值的符号 5 位表示指数 10 位表示小数 在训练期间 这种精度 在更新参数时非常有用 训练结束后 可以使用 一种称为量化的技术 来减少参数所需的内存
通过将相同的值表示为 8 位整数 可以将所需内存减少一半 这样做的好处包括 减少模型的内存占用、 提高用于计算的吞吐量 并且根据具体的模型 只需略微降低或不降低 输出精度即可实现这种效果
缩放点积注意力 是许多转换器模型的核心 这个运算一开始 需要输入标记化文本
输入被拆分为三个张量 分别称为查询张量、 键张量和值张量
然后通过一系列矩阵乘法、 缩放和 Softmax 运算 对这三个张量进行处理 通过将一系列运算 融合到单个内核调用中 可以避免将许多小型计算 调度到 GPU 的开销 从而提高许多网络的整体性能
我想强调的最后一个性能更新是 Apple 设备上的统一 内存架构带来的好处 它让张量可以直接存在于主内存中 并可供 CPU 和 GPU 访问 而无需将位从内存的 一个区域拷贝到另一个区域
接下来 我将展示一个 端到端工作流程 来总结关于 PyTorch 的讨论 这个流程展示了如何采用 语言模型、自定模型、针对用例 微调模型并在设备上运行模型
首先导入 torch 并锁定我的随机席位 以获得可重现的结果
我使用热门转换器库 来下载并设置模型和标记器 它让我可以便捷地 从 HuggingFace 存储库中 获取模型
我使用具有 30 亿个参数的 Open LLaMA2 作为任务的基础模型 我还需要模型训练时 使用的相应标记器
为了将微调适配器附加到模型 我将使用包含 lora_config 的 peft 库 我将定义适配器的参数 然后使用基础模型和配置 来创建新的 PeftModel
现在 我已经准备好将模型 发送到计算设备 MPS
接下来 需要选择 用于进行微调的数据 我将使用 Andrej Karpathy 提供的 tinyshakespeare 数据集 作为训练输入 数据集中包含 Shakespeare 的 多个作品 方便地串联成单个文件
数据集载入完成后 我可以 将它载入到数据集对象中 并指示应该针对数据 使用哪个标记器
我需要设置一些 训练参数来进行微调 我可以使用 Trainer 类 并设置批次大小以及 要使用的训练周期数量等参数
data_collator 对象 会为我们的训练器对象 形成训练批次
现在 我可以传递模型、 参数、数据整理器 和训练数据集来创建训练器对象
在开始训练之前 我先检查一下模型输出 然后再进行微调 我将添加一个小巧便捷的函数 利用它来接受输入文本、 将文本标记化以供模型使用、 生成输出结果 并将结果去标记化 恢复为可直接识别的文本
我用莎士比亚作品中的 某句话进行测试 看看模型会生成哪种回复
在微调之前 这个模型在我看来 就像某种字典条目 它一开始正确显示了引言的出处 然后就漫无目的地讨论 关于一家之主的话题
与字典对话很无聊 所以我们 来看看能不能通过微调 让模型变得生动起来
我首先使用 trainer 类 启动训练 它会使用我在前面定义的 参数来“消化”数据集
过了一会儿 在使用数据集 运行了 10 个周期后 训练就结束了
现在 我来尝试为它提供 和之前相同的输入
它引用了一段来自 Menenius 的话 内容很有趣 显然我们的微调取得了成效
现在 我来保存模型以供将来使用 我将适配器和基础模型 合并成一个实体 这样使用起来更方便 另外 确保将标记器 与模型一起存储
现在模型已经训练完毕 我想把它部署到设备上试用一下
对于大多数网络 首选部署方式是 使用 CoreML 进行部署
我们关于在设备上部署模型的 讲座中详细讨论了这一点
在这个示例中 我希望始终在 PyTorch 生态系统内操作 以便使用新的 ExecuTorch 框架来部署模型
ExecuTorch 用于在各种 设备上部署 PyTorch 模型 以进行推理 在 ExecuTorch 上部署时 你可以无缝使用 你在 PyTorch 训练中 定义的任何自定运算
ExecuTorch 使用 MPS 分区器来分析计算图形 并利用 MPS 设备 加快模式识别速度
这里显示了在本地设备上 设置 ExecuTorch 的方法
首先 将存储库克隆到机器上
然后更新子模块
最后 运行安装脚本 用于在构建 ExecuTorch 时 传递使用 MPS 绑定的选项 接下来我来展示一下如何 在 ExecuTorch 中部署模型 我将沿用 ExecuTorch 存储库中的示例 我使用 Meta LLaMA2 模型 作为测试模型 系统已使用分组量化方法 将这个模型转换为 4 位整数 数据类型 这让模型的规模变小 运行起来也更快
在 macOS 上 我将在存储库中 为 iOS 构建演示 App 并将 iPad Pro 作为部署目标
App 构建完毕后 我将选择要使用的模型 以及用于训练模型的相应标记器
接下来 我将问模型 是否能告诉我千层面的做法
我询问时使用的是 LLaMA2 提示模板 由于这个模型经过了微调 可以像聊天机器人一样运行 它预计会以这种形式进行回复
通过 ExecuTorch 在 iPad 上本地运行后 模型针对晚餐给出了 一些不错的建议 不过我想我会在食材中 再加一些西红柿和芝士
以上都是你可以使用 新功能和改进功能来加快 PyTorch 工作流程的方法 接下来 我将介绍我们向 MPS Graph 支持的 另一个热门机器学习框架 JAX 添加的新功能
去年在 WWDC23 上发布了 面向开发者的 JAX-Metal 插件 从那时起 这个插件不断得到改进 并增加了更多功能 进行了 与性能相关的更多更新
这些变更包括 改进了高级数组索引、
采用了官方 JAX 存储库中的 CI runner 工作流程
以及支持混合精度等
我想重点介绍一下 自发布以来采用 JAX-Metal 后端的几个平台 首先是 MuJoCo 这是一个开源框架 适用于需要快速进行 精确模拟的用例 例如机器人和生物力学
借助 JAX-Metal 后端 这个框架 能够为在 Mac 平台上 运行的用户提供最佳性能
其次是 AXLearn 这是一个用于开发 大规模深度学习模型的库 借助 Metal 后端 可以在本地设备上 快速交易和测试工作流程
你可以查看这两个库并测试一下 在 Mac 设备上使用这两个库时 JAX-Metal 后端如何带来 出色的使用体验
接下来 让我们更详细地了解一下 JAX-Metal 后端 新增的一些改进功能 我将介绍 JAX 中的混合精度、 NDArray 索引 和内边距
今年的一项更新是 JAX-Metal 框架现在 支持 BFloat16 数据类型
这种数据类型表示 浮点值的广泛动态范围 适用于混合精度训练等用例
在 JAX 中 可以像使用任何其他 数据类型一样使用这种新数据类型
例如 你可以使用 新数据类型创建张量
另一项改进是 借助 NDArray 索引和更新支持 你可以使用类似 Numpy 的语法来操作数组
例如 如果你创建了一个 包含两行和两列的小型数组 你可以使用 Numpy 索引语法 将第一列除以 10
JAX 可让你定义内边距策略 JAX-Metal 后端现在 支持这些内边距策略
你可以使用 JAX 在元素之间 添加内边距 也称为膨胀
为此 需要调用 pad 函数 它对于每个维度接受三个参数
你还可以使用负数内边距 从张量中移除元素
只需在内边距配置中 传递负值即可
接下来是关于 JAX 的最后一部分 我想展示一个使用 JAX 的小示例 我将使用之前讨论的 AXLearn 库 我将从库中选择参数为 fuji-7B 的模型来运行 并针对模型使用之前 讨论过的 BFloat16 数据类型
这个脚本将创建一个小型随机输入 来传递给模型 并要求 模型生成后续词元
输出结果会显示 logit 和输出词元
预测完成后 我将重新运行相同的脚本 但这次会使用环境变量 将 JAX 设置为在 CPU 上运行
在这里可以看到 CPU 输出 与 Metal 后端推理的输出一致 演示到此结束
关于 JAX 以及本次 WWDC 讲座的演示 也到此结束了 总结一下我们今天讨论的内容
Apple 芯片拥有 统一内存架构的优势 为许多机器学习任务 带来了很大的好处 它允许使用较大的模型 和较大的批次大小 也无需在 CPU 和 GPU 之间拷贝数据 因为这两者都可以访问相同的内存
借助我们为 PyTorch、JAX、 TensorFlow 和 MLX 等 热门框架提供的 Metal 后端 你可以使用 功能强大的 Apple 芯片 GPU 今年 我们针对新性能 进行了大量改进 以便为非常热门的 转换器类模型提供支持
你可以充分利用这些更新 只需确保你使用的是 最新版本的框架 并记得更新 macOS
感谢观看 我们迫不及待想看到你 使用这些新功能创造的种种精彩
-
-
正在查找特定内容?在上方输入一个主题,就能直接跳转到相应的精彩内容。
提交你查询的内容时出现错误。请检查互联网连接,然后再试一次。