# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""parallel serialization"""
from __future__ import absolute_import
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F, composite as C, operations as P
import mindspore.common.dtype as mstype
from mindspore.common.sparse_tensor import Tensor
from mindspore.common.api import jit
from mindspore.common.parameter import Parameter
from mindspore.nn.layer import Identity
from mindspore.parallel._utils import _get_enable_parallel_optimizer
__all__ = ['PipelineGradReducer']

grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()
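
# Per-tensor gradient scaling handlers. PipelineGradReducer.construct applies one of them over the
# gradient tuple with HyperMap: shard_grad_scale when the parallel optimizer (optimizer sharding)
# is enabled, grad_scale otherwise. Both reset the accumulation buffers after scaling.
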
@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
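    """Scale the accumulated gradient by the reciprocal of ``scale`` and zero the accumulation buffer."""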
    accu_grad = F.depend(accu_grad, grad)
    new_grad = accu_grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
    return new_grad

@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
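    """Scale the gradient by the reciprocal of ``scale`` and zero the accumulation buffer (parallel optimizer case)."""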
    new_grad = grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
    return new_grad

class PipelineGradReducer(Cell):
"""
PipelineGradReducer is a gradient reducer for pipeline parallelism.
Args:
parameters (list): the parameters to be updated.
scale_sense (float): the scale sense of the gradient. Default: 1.0.
opt_shard(bool): if use parallel optimizer, set opt_shard True.
Raise:
RuntimeError: If the mode is not graph mode.
RuntimeError: If the parallel mode is not semi auto parallel or auto parallel.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
.. note::
Before running the following examples, you need to configure the communication environment variables.
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
Please see the `rank table Startup
<https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
for more details.
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
<https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
This example should be run with multiple devices.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, ops, Tensor
>>> from mindspore.communication import init
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.reset_auto_parallel_context()
>>> init()
>>> ms.set_seed(1)
>>>
        >>> class Network(nn.Cell):
        ...     def __init__(self, in_features, out_features, sens=1.0):
        ...         super().__init__()
        ...         self.layer1 = nn.Dense(in_features, 16)
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(16, 16)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(16, out_features)
        ...
        ...     def construct(self, x):
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> size, in_features, out_features = 16, 32, 10
        >>> net = Network(in_features, out_features)
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 2)
        >>> net_with_loss.set_train()
        >>> def forward_fn(inputs, target):
        ...     loss = net_with_loss(inputs, target)
        ...     return loss
        >>>
        >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
        >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
        >>>
        >>> @ms.jit
        ... def train_one_step(inputs, target):
        ...     loss, grads = grad_fn(inputs, target)
        ...     grads = pp_grad_reducer(grads)
        ...     optimizer(grads)
        ...     return loss, grads
        >>>
        >>> parallel_net = AutoParallel(train_one_step, parallel_mode="semi_auto")
        >>> parallel_net.pipeline(stages=2)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
        >>> loss, _ = train_one_step(inputs, label)
        >>> print(loss)
        46.36721
    """

    def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
        super(PipelineGradReducer, self).__init__(auto_prefix=False)
        self._check_mode()
        # Buffers holding the gradients accumulated over the micro batches of one step.
        self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
        self.grad_reducer = Identity()
        self.degree = Tensor(1, mstype.float32)
        self.scale_sense = Parameter(scale_sense, name='scale_sense')
        self.hyper_map = C.HyperMap()
        # If opt_shard is not given, follow the parallel-optimizer setting of the parallel context.
        if opt_shard is None:
            self.opt_shard = _get_enable_parallel_optimizer()
        else:
            self.opt_shard = opt_shard

    @jit
    def construct(self, grads):
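        """Scale the gradients by the reciprocal of ``scale_sense`` and zero the accumulation buffers."""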
        new_grads = None
        if self.opt_shard:
            # Parallel optimizer enabled: scale the incoming gradients and zero the accumulation buffers.
            grads = self.grad_reducer(grads)
            new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree),
                                       grads, self.accu_grads)
        else:
            # No parallel optimizer: scale the accumulated gradients and zero the accumulation buffers.
            accu_grads = self.grad_reducer(self.accu_grads)
            new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
        return new_grads

    def _check_mode(self):
        """Check that the execution mode is graph mode."""
        mode = context.get_context('mode')
        if mode != context.GRAPH_MODE:
            raise RuntimeError(f"PipelineGradReducer only supports graph mode, but got {mode}.")