mindspore.nn.DistributedGradReducer
- class mindspore.nn.DistributedGradReducer(parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
A distributed optimizer.
Used in data-parallel mode to aggregate the gradients from all devices with AllReduce.
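- Parameters:
parameters (list) - The parameters to be updated.
mean (bool) - When mean is True, the gradients are averaged after AllReduce. If not specified, the "gradients_mean" setting in auto_parallel_context is used. Default: None.
degree (int) - The mean coefficient, usually equal to the number of devices. Default: None.
fusion_type (int) - The fusion type of the AllReduce operator. Default: 1.
group (str) - The communication group used by the AllReduce operator. To use a custom communication group, create it first by calling the create_group interface. Default: GlobalComm.WORLD_COMM_GROUP.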
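The interaction between mean and degree can be illustrated with a small pure-Python simulation. This is only a sketch of the arithmetic, not MindSpore code: the helper names all_reduce_sum and reduce_gradients are hypothetical, each "device" is just a list of floats, and the fallback of degree to the number of simulated devices is an assumption made for the illustration.

```python
def all_reduce_sum(per_device_grads):
    """Element-wise sum of one gradient vector per device (stand-in for AllReduce)."""
    return [sum(values) for values in zip(*per_device_grads)]


def reduce_gradients(per_device_grads, mean=True, degree=None):
    """Return the aggregated gradient every device would hold after reduction."""
    if degree is None:
        # Assumption for this sketch: fall back to the number of simulated devices.
        degree = len(per_device_grads)
    summed = all_reduce_sum(per_device_grads)
    if mean:
        # Averaging divides the summed gradient by the mean coefficient.
        return [g / degree for g in summed]
    return summed


# Four simulated devices, each holding a local gradient for the same 3 weights.
local_grads = [
    [1.0, 2.0, 3.0],
    [3.0, 2.0, 1.0],
    [0.0, 4.0, 8.0],
    [4.0, 0.0, 4.0],
]

summed = reduce_gradients(local_grads, mean=False)
averaged = reduce_gradients(local_grads, mean=True, degree=4)
print(summed)    # [8.0, 8.0, 16.0]
print(averaged)  # [2.0, 2.0, 4.0]
```

With mean=False every device ends up with the summed gradient, so the effective step size grows with the number of devices. With mean=True and degree equal to the device count, the update has the same scale as a single-device run over the combined batch.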
- Raises:
ValueError - If degree is not an int or is less than 0.
- Supported Platforms:
Ascend
GPU
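Examples:
Note
Before running the following example, configure the communication environment variables.
For Ascend devices, prepare a rank table and set rank_id and device_id; see the Ascend tutorial for details.
For GPU devices, prepare a host file and MPI; see the GPU tutorial for details.
This example must be run in a multi-device environment.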
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindspore import ops
>>> from mindspore import Parameter, Tensor
>>> from mindspore import nn
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.reset_auto_parallel_context()
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
>>>
>>> class TrainingWrapper(nn.Cell):
...     def __init__(self, network, optimizer, sens=1.0):
...         super(TrainingWrapper, self).__init__(auto_prefix=False)
...         self.network = network
...         self.network.add_flags(defer_inline=True)
...         self.weights = optimizer.parameters
...         self.optimizer = optimizer
...         self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
...         self.sens = sens
...         self.reducer_flag = False
...         self.grad_reducer = None
...         self.parallel_mode = ms.get_auto_parallel_context("parallel_mode")
...         self.depend = ops.Depend()
...         if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
...             self.reducer_flag = True
...         if self.reducer_flag:
...             mean = ms.get_auto_parallel_context("gradients_mean")
...             degree = ms.get_auto_parallel_context("device_num")
...             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
...
...     def construct(self, *args):
...         weights = self.weights
...         loss = self.network(*args)
...         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
...         grads = self.grad(self.network, weights)(*args, sens)
...         if self.reducer_flag:
...             # apply grad reducer on grads
...             grads = self.grad_reducer(grads)
...         return self.depend(loss, self.optimizer(grads))
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
>>>
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> net_with_loss = nn.WithLossCell(network, loss)
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(net_with_loss, optimizer)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> grads = train_cell(inputs, label)
>>> print(grads)
256.0
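The printed value 256.0 is the loss returned by the wrapper's construct method (despite the variable name grads), and it can be checked by hand. The following pure-Python sketch reproduces the forward pass of the example with plain lists instead of MindSpore tensors. It assumes the loss is a mean squared error averaged over all output elements, and it does not model the gradient reduction or the optimizer step.

```python
size, in_features, out_features = 16, 16, 10

# Same data as the example: all-ones inputs and weights, all-zeros labels.
inputs = [[1.0] * in_features for _ in range(size)]
weight = [[1.0] * out_features for _ in range(in_features)]
label = [[0.0] * out_features for _ in range(size)]

# Forward pass: output[i][j] = sum_k inputs[i][k] * weight[k][j] = 16.0
output = [
    [sum(inputs[i][k] * weight[k][j] for k in range(in_features))
     for j in range(out_features)]
    for i in range(size)
]

# Mean squared error over all size * out_features elements.
squared_errors = [
    (output[i][j] - label[i][j]) ** 2
    for i in range(size)
    for j in range(out_features)
]
loss = sum(squared_errors) / len(squared_errors)
print(loss)  # 256.0
```

Every output element equals 16.0, so each squared error is 256.0 and so is their mean. The value is computed from the weights before the optimizer update.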