mindspore.nn.PipelineGradReducer
- class mindspore.nn.PipelineGradReducer(parameters, scale_sense=1.0)
PipelineGradReducer is a gradient reducer for pipeline parallelism.
- Parameters
parameters (list) – the training parameters whose gradients are to be reduced.
scale_sense (float, optional) – the scale sense applied to the gradients. Default: 1.0.
- Raises
RuntimeError – If the mode is not graph mode.
RuntimeError – If the parallel mode is not ParallelMode.SEMI_AUTO_PARALLEL or ParallelMode.AUTO_PARALLEL.
- Supported Platforms:
Ascend GPU
Examples
Note
Before running the following example, you need to configure the communication environment variables.
For Ascend devices, users need to prepare the rank table and set rank_id and device_id; see the rank table startup documentation for more details.
For GPU devices, users need to prepare the host file and mpi; see the mpirun startup documentation for more details.
This example should be run with multiple devices; a sample launch command is sketched below.
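Since the example sets pipeline_stages=2, the script must be launched on at least two devices. As a hedged illustration (the script name example.py is hypothetical), a GPU launch with mpirun could look like: mpirun -n 2 python example.py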
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, ops, Tensor
>>> from mindspore.communication import init
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.reset_auto_parallel_context()
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
>>> init()
>>> ms.set_seed(1)
>>>
>>> class Network(nn.Cell):
...     def __init__(self, in_features, out_features, sens=1.0):
...         super().__init__()
...         self.layer1 = nn.Dense(in_features, 16)
...         self.relu1 = nn.ReLU()
...         self.layer2 = nn.Dense(16, 16)
...         self.relu2 = nn.ReLU()
...         self.layer3 = nn.Dense(16, out_features)
...
...     def construct(self, x):
...         x = self.layer1(x)
...         x = self.relu1(x)
...         x = self.layer2(x)
...         x = self.relu2(x)
...         logits = self.layer3(x)
...         return logits
>>>
>>> size, in_features, out_features = 16, 32, 10
>>> net = Network(in_features, out_features)
>>> # Assign each layer to one of the two pipeline stages.
>>> net.layer1.pipeline_stage = 0
>>> net.relu1.pipeline_stage = 0
>>> net.layer2.pipeline_stage = 0
>>> net.relu2.pipeline_stage = 1
>>> net.layer3.pipeline_stage = 1
>>> loss_fn = nn.CrossEntropyLoss()
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
>>> # Wrap the network with a loss and split each batch into 2 micro-batches.
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 2)
>>> net_with_loss.set_train()
>>> def forward_fn(inputs, target):
...     loss = net_with_loss(inputs, target)
...     return loss
>>>
>>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
>>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
>>>
>>> @ms.jit
... def train_one_step(inputs, target):
...     loss, grads = grad_fn(inputs, target)
...     # Reduce the micro-batch gradients before applying the optimizer.
...     grads = pp_grad_reducer(grads)
...     optimizer(grads)
...     return loss, grads
>>>
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
>>> loss, _ = train_one_step(inputs, label)
>>> print(loss)
46.36721
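Broadly, nn.PipelineCell(..., 2) splits each batch into two micro-batches, and PipelineGradReducer merges the gradients accumulated across those micro-batches before optimizer(grads) applies a single update. The printed loss is the value observed for this seed and configuration; the output on other devices or stages may differ.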