{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 分布式并行\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0/zh_cn/design/mindspore_distributed_training_design.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0/zh_cn/design/mindspore_distributed_training_design.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/source_zh_cn/design/distributed_training_design.ipynb)\n", "\n", "## 背景\n", "\n", "随着深度学习的快速发展,为了提升神经网络的精度和泛化能力,数据集和参数量都在呈指数级向上攀升。分布式并行训练成为一种解决超大规模网络性能瓶颈的发展趋势。\n", "\n", "为了应对数据集过大的问题,MindSpore引入了数据并行模式,利用多个设备的计算资源,同时处理更多的训练数据,加快模型训练速度。同时当数据过大或模型过大无法在单个计算节点上加载训练时,需要引入模型并行,每个计算节点只需要加载部分模型和数据,这样可以减少内存占用,提高训练效率。在分布式并行编程范式的演进中,传统的手动并行中,用户需要基于通信原语通过编码,手动把模型切分到多个节点上并行,用户需要感知图切分、算子切分、集群拓扑,才能实现最优性能。此种编程范式对于工程师存在一定的门槛要求,于是演进出了半自动并行:并行逻辑和算法逻辑解耦,用户按单卡串行的方式写算法代码,并行逻辑作为算法配置。用户只需要配置并行策略实现自动并行切分,无需额外编写代码;用户无需感知模型切片的调度及集群拓扑。全自动并行训练编程范式则更进一步,用户只需要写单卡串行算法,通过搜索算法来自动生成较优的切分策略。\n", "\n", "MindSpore通过集合通信的方式来实现并行训练过程中的数据通信和同步操作,在Ascend芯片上它依赖于华为集合通信库HCCL,在GPU上它依赖于英伟达集合通信库NCCL。MindSpore目前采用的是同步训练模式,同步模式能够保证所有设备上的参数保持一致,在每个训练迭代开始前所有设备上的参数都被同步。\n", "\n", "本篇设计文档将会集中介绍几种并行训练方式的设计原理,同时指导用户进行自定义开发。\n", "\n", "## 数据并行\n", "\n", "这个小节介绍了在MindSpore中`ParallelMode.DATA_PARALLEL`数据并行模式是如何工作的。\n", "\n", "### 数据并行原理\n", "\n", "![数据并行图解](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/data_parallel.png)\n", "\n", "1. 环境依赖\n", "\n", " 每次开始进行并行训练前,通过调用`mindspore.communication.init`接口初始化通信资源,并自动创建全局通信组`WORLD_COMM_GROUP`。\n", "\n", "2. 数据分发(Data distribution)\n", "\n", " 数据并行的核心在于将数据集在样本维度拆分并下发到不同的卡上。在`mindspore.dataset`模块提供的所有数据集加载接口中都有`num_shards`和`shard_id`两个参数,它们用于将数据集拆分为多份并循环采样的方式,采集`batch`大小的数据到各自的卡上,当出现数据量不足的情况时将会从头开始采样。\n", "\n", "3. 网络构图\n", "\n", " 数据并行网络的书写方式与单机网络没有差别,这是因为在正反向传播(Forward propagation & Backward Propagation)过程中各卡的模型间是独立执行的,只是保持了相同的网络结构。唯一需要特别注意的是为了保证各卡间训练同步,相应的网络参数初始化值应当是一致的,在`DATA_PRALLEL`和`HYBRID_PARALLEL`模式下建议通过使能`parameter_broadcast`达到权重广播的目的;在`AUTO_PRALLEL`和`SEMI_AUTO_PARALLEL`模式下,框架内部会自动分析参数的并行度,并设置相应的随机数种子,保证在数据并行维度的设备上参数初始化值一致。\n", "\n", "4. 梯度聚合(Gradient aggregation)\n", "\n", " 数据并行理论上应该实现和单机一致的训练效果,为了保证计算逻辑的一致性,在梯度计算完成后插入`AllReduce`算子实现各卡间的梯度聚合操作。MindSpore设置了`mean`开关,用户可以选择是否要对求和后的梯度值进行求平均操作,也可以将其视为超参项,打开开关等价于学习率倍数缩小。\n", "\n", "5. 参数更新(Parameter update)\n", "\n", " 因为引入了梯度聚合操作,所以各卡的模型会以相同的梯度值一起进入参数更新步骤。因此MindSpore实现的是一种同步数据并行训练方式。理论上最终每卡训练出来的模型是相同的,如果网络中含有在样本维度的归约类型操作,网络的输出可能会有所差别,这是由数据并行的切分性质决定的。\n", "\n", "### 数据并行代码\n", "\n", "1. 集合通信\n", "\n", " - [management.py](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/python/mindspore/communication/management.py):这个文件中涵盖了集合通信过程中常用的`helper`函数接口,例如获取集群数量和卡的序号等。当在Ascend芯片上执行时,框架会加载环境上的`libhccl.so`库文件,通过它来完成从Python层到底层的通信接口调用。\n", " - [comm_ops.py](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/python/mindspore/ops/operations/comm_ops.py):MindSpore将支持的集合通信操作都封装为算子的形式放在这个文件下,包括`AllReduce`、`AllGather`、`ReduceScatter`和`Broadcast`等。`PrimitiveWithInfer`中除了定义算子所需属性外,还包括构图过程中输入到输出的`shape`和`dtype`推导。\n", "\n", "2. 梯度聚合\n", "\n", " - [grad_reducer.py](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/python/mindspore/nn/wrap/grad_reducer.py):这个文件实现了梯度聚合的过程。对入参`grads`用`HyperMap`展开后插入`AllReduce`算子,这里采用的是全局通信组,用户也可以根据自己网络的需求仿照这个模块进行自定义开发。MindSpore中单机和分布式执行共用一套网络封装接口,在`Cell`内部通过`ParallelMode`来区分是否要对梯度做聚合操作。\n", "\n", "## 半自动并行\n", "\n", "这个小节介绍了在MindSpore中`ParallelMode.SEMI_AUTO_PARALLEL`半自动并行模式是如何工作的。\n", "\n", "### 半自动并行原理\n", "\n", "![自动并行图解](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/auto_parallel.png)\n", "\n", "1. 分布式算子和张量排布模型\n", "\n", " 在上面的架构图中,自动并行流程会对单机的正向计算图(ANF Graph)进行遍历,以分布式算子(Distributed Operator)为单位对张量进行切分建模,表示一个算子的输入输出张量如何分布到集群各个卡上(Tensor Layout)。这种模型充分地表达了张量和设备间的映射关系,用户无需感知模型各切片放到哪个设备上运行,框架会自动调度分配。\n", "\n", " 为了得到张量的排布模型,每个算子都具有切分策略(Shard Strategy),它表示算子的各个输入在相应维度的切分情况。通常情况下只要满足以2为基、均匀分配的原则,张量的任意维度均可切分。以下图为例,这是一个三维矩阵乘(BatchMatMul)操作,它的切分策略由两个元组构成,分别表示`input`和`weight`的切分形式。其中元组中的元素与张量维度一一对应,`2^N`为切分份数,`1`表示不切。当用户想表示一个数据并行切分策略时,即`input`的`batch`维度切分,其他维度不切,可以表达为`strategy=((2^N, 1, 1),(1, 1, 1))`;当表示一个模型并行切分策略时,即`weight`的非`batch`维度切分,这里以`channel`维度切分为例,其他维度不切,可以表达为`strategy=((1, 1, 1),(1, 1, 2^N))`;当表示一个混合并行切分策略时,其中一种切分策略为`strategy=((2^N, 1, 1),(1, 1, 2^N))`。\n", "\n", " ![算子切分定义](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/operator_split.png)\n", "\n", " 依据切分策略,分布式算子中定义了推导算子输入张量和输出张量的排布模型的方法。这个排布模型由`device_matrix`,`tensor_shape`和`tensor map`组成,分别表示设备矩阵形状、张量形状、设备和张量维度间的映射关系。分布式算子会进一步根据张量排布模型判断是否要在图中插入额外的计算、通信操作,以保证算子运算逻辑正确。\n", "\n", "2. 张量排布变换\n", "\n", " 当前一个算子的输出张量模型和后一个算子的输入张量模型不一致时,就需要引入计算、通信操作的方式实现张量排布间的变化。自动并行流程引入了张量重排布算法(Tensor Redistribution),可以推导得到任意排布的张量间通信转换方式。下面三个样例表示公式`Z=(X×W)×V`的并行计算过程,即两个二维矩阵乘操作,体现了不同并行方式间如何转换。\n", " 在样例一中,第一个数据并行矩阵乘的输出在行方向上存在切分,而第二个模型并行矩阵乘的输入需要全量张量,框架将会自动插入`AllGather`算子实现排布变换。\n", "\n", " ![tensor-redistribution1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/tensor_redistribution1.png)\n", "\n", " 在样例二中,第一个模型并行矩阵乘的输出在列方向上存在切分,而第二个数据并行矩阵乘的输入在行方向上存在切分,框架将会自动插入等价于集合通信中`AlltoAll`操作的通信算子实现排布变换。\n", "\n", " ![tensor-redistribution2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/tensor_redistribution2.png)\n", "\n", " 在样例三中,第一个混合并行矩阵乘的输出切分方式和第二个混合并行矩阵乘的输入切分方式一致,所以不需要引入重排布变换。但由于第二个矩阵乘操作中,两个输入的相关维度存在切分,所以需要插入`AllReduce`算子保证运算正确性。\n", "\n", " ![tensor-redistribution3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/tensor_redistribution3.png)\n", "\n", " 综上,1、2两点是自动并行实现的基础,总体来说这种分布式表达打破了数据并行和模型并行的边界,轻松实现混合并行。从脚本层面上,用户仅需构造单机网络,即可表达并行算法逻辑,框架将自动实现对整图切分。\n", "\n", "3. 分布式自动微分\n", "\n", " 传统的手动模型切分除了需要关注正向网络通信还需要考虑网络反向的并行运算,MindSpore通过将通信操作包装为算子,并利用框架原有的自动微分操作自动生成通信算子反向,所以即便在进行分布式训练时,用户同样只需关注网络的前向传播,真正实现训练的全自动并行。\n", "\n", "4. 支持多维混合并行\n", "\n", " 半自动并行支持多种并行模式的自动混合使用,分别有:\n", "\n", " **算子级并行**:算子级并行以神经网络中的算子为单位,将输入张量切分到多个设备上进行计算。通过这种方式,可以实现数据样本和模型参数在不同设备之间的分配,从而训练大规模的深度学习模型,并利用集群资源进行并行计算,提高整体速度。用户可以设置每个算子的切分策略,框架会根据算子的切分策略对每个算子及其输入张量进行切分建模,以保持数学等价性。这种方法可以有效地减少单个设备的负载,提高计算效率,适用于大规模深度神经网络的训练。详情参考:[算子级并行](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/operator_parallel.html)\n", "\n", " **流水线并行**:当集群设备数很多时,如果仅采用算子级并行的方式,则需要在整个集群的通信域上进行通信,这可能使得通信效率低,从而降低整体性能。而流水线并行能将神经网络结构切分成多个stage,每个stage跑在一部分设备内,将集合通信的通信域限定在这部分设备范围内,而stage间采用点对点通信。流水线并行的优点在于:能提升通信效率、能方便的处理按层堆叠的神经网络结构。缺点在于:同一时刻内,有些节点可能处于空闲状态。详情参考:[流水线并行](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/pipeline_parallel.html)\n", "\n", " **MoE并行**:MoE是将专家分布到不同的worker上,并且每个worker承担不同批次的训练数据。对于非MoE层来说,专家并行和数据并行一样。在MoE层中,序列中的token通过all-to-all通信被发送到它们相匹配的专家所对应的worker。在完成对应专家的计算后,再通过all-to-all重新传回到原来的worker,组织成原始序列,用于下一层的计算。由于MoE模型通常有大量的专家,专家并行度比模型并行度更能随模型规模的增大而增大。\n", "\n", " ![MoE并行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/MoE.png)\n", "\n", " **多副本并行**:将输入模型的数据按照batchsize维度进行切分,从而将现有的单副本形式修改成多副本的形式,使其底层在通信的时候,另一副本进行计算操作,无需等待,这样就能保证多副本的计算和通信的时间相互互补,提升模型性能,同时将数据拆成多副本的形式还能减少算子输入的参数量,从而减少单个算子的计算时间,对提升模型性能有很大帮助。\n", "\n", " ![多副本并行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/multi_copy.png)\n", "\n", " **优化器并行**:在数据并行或算子级并行训练时,模型的参数可能在多个设备上存在同一份副本。这使得优化器在更新该权重之时,在多个设备间存在冗余计算。在此情况下,可以通过优化器并行将优化器的计算量分散到多个设备上。它的优点在于:能减少静态内存消耗、减少优化器内的计算量。缺点在于:增加了通信开销。详情参考:[优化器并行](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/optimizer_parallel.html)\n", "\n", "### 半自动并行代码\n", "\n", "1. 张量排布模型\n", " - [tensor_layout](https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/ccsrc/frontend/parallel/tensor_layout):这个目录下包含了张量排布模型相关功能的定义及实现。其中`tensor_layout.h`中声明了一个张量排布模型需要具备的成员变量`tensor_map_origin_`,`tensor_shape_`和`device_arrangement_`等。在`tensor_redistribution.h`中声明了实现张量排布间`from_origin_`和`to_origin_`变换的相关方法,将推导得到的重排布操作保存在`operator_list_`中返回,并计算得到重排布所需的通信开销`comm_cost_`, 内存开销`memory_cost_`及计算开销`computation_cost_`。\n", "\n", "2. 分布式算子\n", " - [ops_info](https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/ccsrc/frontend/parallel/ops_info):这个目录下包含了分布式算子的具体实现。在`operator_info.h`中定义了分布式算子实现的基类`OperatorInfo`,开发一个分布式算子需要继承于这个基类并显式实现相关的虚函数。其中`InferTensorInfo`,`InferTensorMap`和`InferDevMatrixShape`函数定义了推导该算子输入、输出张量排布模型的算法。`InferForwardCommunication`,`InferMirrorOps`等函数定义了切分该算子需要插入的额外计算、通信操作。`CheckStrategy`和`GenerateStrategies`函数定义了算子切分策略校验和生成。根据切分策略`SetCostUnderStrategy`将会产生该策略下分布式算子的并行开销值`operator_cost_`。\n", "\n", "3. 设备管理\n", " - [device_manager.h](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/ccsrc/frontend/parallel/device_manager.h):这个文件实现了集群设备通信组的创建及管理。其中设备矩阵模型由`device_matrix.h`定义,通信域由`group_manager.h`管理。\n", "\n", "4. 整图切分\n", " - [step_auto_parallel.h](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h), [step_parallel.h](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/ccsrc/frontend/parallel/step_parallel.h):这两个文件包含了自动并行流程的核心实现。首先由`step_auto_parallel.h`调用策略搜索流程并产生分布式算子的`OperatorInfo`,然后在`step_parallel.h`中处理算子切分和张量重排布等流程,对单机计算图进行分布式改造。\n", "\n", "## 全自动并行\n", "\n", "半自动并行将用户从复杂的分布式代码开发中解放出来,大大减轻了用户开发分布式AI大模型的难度。尽管用户不再需要考虑设备间的数据存储和通信,但是仍然需要用户为每个算子指定合适的切分策略,因为不同的切分策略的训练性能相差很大。用户仍然需要具备相应的并行知识并根据网络结构、集群拓扑等计算分析,才能在巨大的搜索空间中定义合适的并行策略。而现实情况是AI框架的主要用户是AI研究员和工程师,恰恰不一定具备专业的并行知识。另一方面,面对巨大的搜索空间,为大模型找到合适的并行策略需要月级人工调优成本,且仍然不能保证策略最优。例如DeepSpeed、Megatron等针对transformer类的网络的专家定制策略,仍然需要用户定义dp、mp、pp等配置,更何况网络模型的结构不止transformer一种。基于以上两个原因,MindSpore提供了多种自动混合并行策略生成方案,尽量减轻用户对于并行配置的感知,让用户能够快速、高效、容易地训练大模型。\n", "\n", "这个小节介绍了在MindSpore中`ParallelMode.AUTO_PARALLEL`全自动并行模式是如何工作的。\n", "\n", "### 特性设计\n", "\n", "全自动并行是基于MindSpore半自动框架,以自动混合并行策略生成算法代替专家配置并行策略。下图展示了使用MindSpore分布式训练或推理一个神经网络的过程,用户使用Python语言开发自己的神经网络模型(或MindIR导入),经MindSpore解析成计算图(ANF图),自动混合并行策略生成模块通过算法搜索到较优的策略,传递给半自动并行模块,经过半自动模块分析张量排布,分布式算子分析,设备管理以及进行整图切分等操作,传递给后端进行计算。\n", "\n", "实际上,混合并行策略生成模块负责在给定神经网络模型和集群配置下,来找到适合的并行切分策略。所采用的关键技术是基于代价模型的策略搜索算法,即构建代价模型(Cost Model)来描述在分布式训练场景下所产生的计算代价(Computation Cost)与通信代价(Communication Cost),以内存开销(Memory Cost)为约束条件,通过计算图搜索算法,高效地搜索出性能较优地并行策略。\n", "\n", "![全自动并行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/auto.png)\n", "\n", "### 三种搜索算法\n", "\n", "全自动并行的实现难度非常大,MindSpore根据需要用户介入的程度,将提供的策略生成算法分成了L1级别和L2级别(此处我们假设手工配置全图策略SEMI_AUTO为L0级别,完全不需要用户参与的方案为L3级别)。\n", "\n", "L1级别的策略生成算法叫做策略广播(Sharding Propagation),在该种模式下,用户仅需要手工定义几个关键算子的策略,计算图中其余算子的策略由算法自动生成。因为关键算子的策略已被定义,该算法的cost model主要描述的算子之间的重排布代价(Redistribution Cost),优化目标为全图重排代价最小。因为已经定义了主要算子策略,相当于认为压缩了搜索空间,这种方案的搜索时间较短,其策略性能依赖于关键算子策略的定义,因此仍然要求用户具备分析定义策略的能力。详情参考:[切分策略传播](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/sharding_propagation.html)\n", "\n", "L2级别的策略生成算法有两种,分别是动态规划算法(Dynamic Programming)和符号化自动策略生成(Symbolic Automatic Parallel Planner 缩写SAPP)。两种方法各有优劣,动态规划算法能够搜索出代价模型刻画的最优策略,但是在搜索巨大网络的并行策略时耗时较长。而SAPP算法能够对于巨大网络以及大规模切分瞬间生成最优策略。\n", "动态规划算法的核心思路是建立全图的代价模型,包括计算代价和通信代价,来描述分布式训练过程中的绝对时延,使用边消除和点消除等等价方法压缩搜索时间,但是搜索空间随着设备数和算子数的增加实际上是指数级增长的,因此对于大模型大集群来说效率不高。\n", "SAPP基于并行原理建模,通过建立抽象机来描述硬件集群拓扑,通过符号化简优化代价模型。其代价模型比较的不是预估的绝对时延,而是不同并行策略的相对代价,因此能够大大压缩搜索空间,对于百卡集群能够保证分钟级的搜索时间。详情参考:[分布式并行总览](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/introduction.html)\n", "\n", "Sharding Propagation和SAPP目前支持手工定义Pipeline+自动算子级并行,且可与重计算、优化器并行等优化共同使用。Dynamic Programming算法仅支持算子级自动并行。\n", "\n", "### 全自动并行代码\n", "\n", "**策略搜索算法**:[auto_parallel](https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/ccsrc/frontend/parallel/auto_parallel)目录下实现了策略搜索的算法。`graph_costmodel.h`定义了构图信息,其中每个点表示一个算子`OperatorInfo`,有向边`edge_costmodel.h`表示算子的输入输出关系及重排布的代价。`operator_costmodel.h`中定义了每个算子的代价模型,包括计算代价、通信代价和内存代价。在`costmodel.h`中定义了cost和图操作的数据结构。\n", "\n", "- **dynamic_programming**:[dp_algo_costmodel.cc](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc)这个文件主要描述了动态规划算法的主要流程,由一系列图操作组成。\n", "- **sharding_propagation**:[graph_costmodel.cc](https://gitee.com/mindspore/mindspore/blob/r2.0/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc)这个文件实现了策略广播(Sharding Propagation),主要用BFS的遍历方法从点到面,将若干个点的策略,传播到整个图。\n", "- **symbolic_automatic_parallel_planner**:[rec_core](https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core)目录下实现了符号化自动策略生成算法(Symbolic Automatic Parallel Planner)。\n", "\n", "## 异构并行\n", "\n", "异构并行训练方法是通过分析图上算子内存占用和计算密集度,将内存消耗巨大或适合CPU逻辑处理的算子切分到CPU子图,将内存消耗较小计算密集型算子切分到硬件加速器子图,框架协同不同子图进行网络训练,使得处于不同硬件且无依赖关系的子图能够并行进行执行的过程。\n", "\n", "### 计算流程\n", "\n", "MindSpore异构并行训练典型的计算流程如下图所示:\n", "\n", "![heterogeneous-heter](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/heter.png)\n", "\n", "1. 用户设置网络执行的后端" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T02:15:56.790220Z", "start_time": "2022-01-05T02:15:55.114811Z" } }, "outputs": [], "source": [ "import mindspore as ms\n", "ms.set_context(device_target=\"GPU\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 用户设置特定算子执行后端" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T09:02:10.573036Z", "start_time": "2022-01-05T09:02:09.034905Z" } }, "outputs": [], "source": [ "from mindspore import ops\n", "\n", "prim = ops.Add()\n", "\n", "prim.set_device(\"CPU\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3. 框架根据计算图算子标志进行切图\n", "4. 框架调度不同后端执行子图" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当前典型使用异构并行计算的场景有:优化器异构、Embedding异构、PS异构。\n", "\n", "### 优化器异构\n", "\n", "在盘古或GPT3大模型训练过程中,优化器状态占用了大量内存,进而限制了可训练的模型规模。使用优化器异构,将优化器指定到CPU上执行,可以极大扩展可训练模型规模:\n", "\n", "![heterogeneous-heter-opt](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/heter-opt.png)\n", "\n", "如图所示,将Adam算子配置到CPU执行同时指定加速器进行FP16计算,可以将参数内存占用降低到原始的1/3。\n", "\n", "1. 配置优化器算子到CPU执行\n", "2. 初始化FP16的权重参数以及FP32的优化器状态变量\n", "3. 将输入优化器的梯度转为FP16(如果本来就是FP16梯度,可忽略这步)\n", "4. 权重和梯度转为FP32参与优化器运算\n", "5. 更新后的FP32权重赋值给FP16的权重\n", "\n", "优化器异构代码样例如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T09:02:10.635821Z", "start_time": "2022-01-05T09:02:10.574494Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore as ms\n", "import mindspore.ops as ops\n", "from mindspore.common.initializer import initializer\n", "from mindspore.nn import Optimizer\n", "_adam_opt = ops.MultitypeFuncGraph(\"adam_opt\")\n", "host_assign = ops.Assign()\n", "host_assign.set_device(\"CPU\")\n", "host_cast = ops.Cast()\n", "host_cast.set_device(\"CPU\")\n", "device_cast = ops.Cast()\n", "\n", "@_adam_opt.register(\"Function\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Number\", \"Tensor\", \"Tensor\", \"Tensor\",\n", " \"Tensor\", \"Bool\", \"Bool\")\n", "def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):\n", " \"\"\"\n", " Update parameters by AdamWeightDecay op.\n", " \"\"\"\n", " success = True\n", " if optim_filter:\n", " param32 = host_cast(param, ms.float32)\n", " gradient = device_cast(gradient, ms.float32)\n", " if decay_flags:\n", " next_param = opt(param32, m, v, lr, beta1, beta2, eps, weight_decay, gradient)\n", " else:\n", " next_param = opt(param32, m, v, lr, beta1, beta2, eps, 0.0, gradient)\n", " ret = host_assign(param, host_cast(ops.depend(param32, next_param), ops.dtype(param)))\n", " return ops.depend(success, ret)\n", " return success\n", "\n", "class AdamWeightDecayOp(Optimizer):\n", " def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):\n", " super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)\n", " self.beta1 = ms.Tensor(np.array([beta1]).astype(np.float32))\n", " self.beta2 = ms.Tensor(np.array([beta2]).astype(np.float32))\n", " self.eps = ms.Tensor(np.array([eps]).astype(np.float32))\n", " self.moments1 = self.clone_param32(prefix=\"adam_m\", init='zeros')\n", " self.moments2 = self.clone_param32(prefix=\"adam_v\", init='zeros')\n", " self.opt = ops.AdamWeightDecay()\n", " self.hyper_map = ops.HyperMap()\n", " self.opt.set_device(\"CPU\")\n", "\n", " def construct(self, gradients):\n", " \"\"\"AdamWeightDecayOp\"\"\"\n", " lr = self.get_lr()\n", " if self.is_group:\n", " if self.is_group_lr:\n", " optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),\n", " lr, self.weight_decay, self.parameters, self.moments1, self.moments2,\n", " gradients, self.decay_flags, self.optim_filter)\n", " else:\n", " optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),\n", " self.weight_decay, self.parameters, self.moments1, self.moments2,\n", " gradients, self.decay_flags, self.optim_filter)\n", " else:\n", " optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,\n", " self.weight_decay), self.parameters, self.moments1, self.moments2,\n", " gradients, self.decay_flags, self.optim_filter)\n", " return optim_result\n", "\n", " def clone_param32(self, prefix, init=None):\n", " new = []\n", " for old_param in self.parameters:\n", " param_init = init\n", " if init is None:\n", " param_init = old_param.init\n", " new_state = old_param.clone()\n", " new_state.set_dtype(ms.float32)\n", " new_state.set_data(initializer(param_init, shape=old_param.shape, dtype=ms.float32))\n", " new_state.name = prefix + '.' + new_state.name\n", " new.append(new_state)\n", " return ms.ParameterTuple(new)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "步骤4、5也可以直接融合到优化器算子中做进一步优化,完整的优化器异构训练流程可以参考: \n", "\n", "### Embedding异构\n", "\n", "在一些需要查Embedding大表的网络中,Embedding表往往有上百G的规模,受加速器内存大小限制,无法直接将整表加载到加速器上执行。通过将与权重表相连的算子放到CPU上执行,避免加速器由于内存限制而无法训练网络的问题。\n", "\n", "![heterogeneous-heter-embed](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/heter-embed.png)\n", "\n", "1. 配置EmbeddingLookup算子到CPU执行" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T09:02:10.663460Z", "start_time": "2022-01-05T09:02:10.636839Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "from mindspore.common.initializer import initializer\n", "class EmbeddingLookupNet(nn.Cell):\n", " def __init__(self, vocab_size, embedding_size, param_init='normal'):\n", " super(EmbeddingLookupNet, self).__init__()\n", " self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')\n", " self.embedding_table = ms.Parameter(initializer(param_init, [vocab_size, embedding_size]), name='embedding_table')\n", "\n", " def construct(self, indices):\n", " out = self.embeddinglookup(self.embedding_table, indices, 0)\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 配置EmbeddingLookup关联稀疏优化器到CPU执行" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T09:02:10.680690Z", "start_time": "2022-01-05T09:02:10.665043Z" } }, "outputs": [], "source": [ "from mindspore.nn.optim import LazyAdam\n", "net = EmbeddingLookupNet(1000, 100)\n", "params = net.trainable_params()\n", "optimizer = LazyAdam(params)\n", "optimizer.target = \"CPU\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "EmbeddingLookup算子设置代码样例如下:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-01-05T09:02:10.709005Z", "start_time": "2022-01-05T09:02:10.682761Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "from mindspore.common.initializer import initializer\n", "\n", "class EmbeddingLookup(nn.Cell):\n", " def __init__(self, vocab_size, embedding_size, param_init='normal',\n", " target='CPU', sparse=True):\n", " \"\"\"Initialize EmbeddingLookup.\"\"\"\n", " super(EmbeddingLookup, self).__init__()\n", " validator.check_value_type('sparse', sparse, [bool], self.cls_name)\n", " self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')\n", " self.target = target\n", " self.sparse = sparse\n", " if sparse:\n", " self.gatherv2 = ops.SparseGatherV2()\n", " else:\n", " self.gatherv2 = ops.Gather()\n", " self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')\n", " self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')\n", " self.embedding_table = ms.Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),\n", " name='embedding_table')\n", "\n", " def construct(self, indices):\n", " if self.target == \"CPU\":\n", " out = self.embeddinglookup(self.embedding_table, indices, 0)\n", " else:\n", " out = self.gatherv2(self.embedding_table, indices, 0)\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当前nn目录下的EmbeddingLookup、FTRL、LazyAdam等算子已经封装好异构接口,用户只需设置target属性为CPU或DEVICE即可切换执行后端。\n", "\n", "整体调用流程可以参考:\n", "\n", "### PS异构\n", "\n", "在EmbeddingTable达到T级别,单机内存无法放下时,使用Parameter Server,通过异构的Pull/Push算子进行权重的拉取和更新。\n", "\n", "![heterogeneous-heter-ps](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/docs/mindspore/source_zh_cn/design/images/heter-ps.png)\n", "\n", "Parameter Server封装异构流程,用户只需配置参数使用PS即可,具体配置流程请参考[Parameter Server训练流程](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/parallel/parameter_server_training.html)。\n", "\n", "此外,wide&deep网络中也有使用PS的流程,可参考:\n", "\n", "### 约束\n", "\n", "当前需要用户指定算子执行的后端,不支持根据网络进行自动化配置。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }