Source code for mindspore.parallel.nn.parallel_cell_wrapper

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from __future__ import absolute_import
from __future__ import division

from mindspore import nn
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.cell_wrapper import _MicroBatch

__all__ = ['PipelineCell', 'Pipeline', 'MicroBatchInterleaved', 'GradAccumulation']

class PipelineCell(Cell):
    """
    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.
        stage_config (dict, optional): The stage configuration for each cell's execution in
            pipeline parallel. Default: ``None``.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.PipelineCell(net, 4, stage_config={"cell_name_0": 0, "cell_name_1": 1})
    """
    def __init__(self, network, micro_size, stage_config=None):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'PipelineCell', the argument 'network' must be of Cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'PipelineCell', the argument 'micro_size' must be an integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'PipelineCell', the argument 'micro_size' must be larger than 0, "
                             "but got {}.".format(micro_size))
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

        # parse stage_config
        config_dict = {}
        if stage_config is not None:
            for cell_name, stage_num in stage_config.items():
                config_dict[cell_name] = stage_num

            # set pipeline_stage on each matching cell
            for cell_name, cell in self.network.cells_and_names():
                for config_cell_name, config_stage_num in config_dict.copy().items():
                    if not cell_name or not config_cell_name:
                        continue
                    if cell_name == config_cell_name:
                        setattr(cell, "pipeline_stage", config_stage_num)
                        del config_dict[config_cell_name]

            for config_cell_name, config_stage_num in config_dict.copy().items():
                if str(network) == config_cell_name:
                    setattr(network, "pipeline_stage", config_stage_num)
                    del config_dict[config_cell_name]

            # if there are any config elements left, print them
            if config_dict:
                for config_cell_name, config_stage_num in config_dict.items():
                    print("pipeline_cell stage_config set pipeline_stage fail!")
                    print("config cell name:" + str(config_cell_name) +
                          " config stage num:" + str(config_stage_num))
                print("network:" + str(self.network))
                print("cell name available:")
                for cell_name, cell in self.network.cells_and_names():
                    print(cell_name)
                raise KeyError("For 'PipelineCell', the argument 'stage_config' : {} is not "
                               "found in 'network' : {}".format(config_dict, network))

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret

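# Reference semantics (an illustrative sketch, not the graph that is actually built): for a
# single Tensor input x whose batch dimension is divisible by micro_size, the construct above
# behaves like slicing x along axis 0, running the network on each slice, and summing:
#
#     def _pipeline_cell_reference(network, micro_size, x):
#         chunk = x.shape[0] // micro_size
#         outputs = [network(x[i * chunk:(i + 1) * chunk]) for i in range(micro_size)]
#         return sum(outputs[1:], outputs[0])
#
# The real implementation slices each input with _MicroBatch and tags every Add with the
# "pipeline_end" attribute so that the compiler can schedule micro-batches across stages.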

class Pipeline(PipelineCell):
    """
    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.
        stage_config (dict, optional): Stage configuration for each cell's execution in
            pipeline parallel. Default: ``None``.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.parallel.nn import Pipeline
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = Pipeline(net, 4, stage_config={"cell_name_0": 0, "cell_name_1": 1})
    """

class MicroBatchInterleaved(Cell):
    """
    Split the input along the 0th dimension into ``interleave_num`` pieces and then perform
    the computation of the wrapped cell.

    Application scenario: with model parallelism in semi-auto parallel mode, while the first
    slice of data runs its forward computation, the second slice can execute its communication
    operators at the same time, accelerating training by overlapping communication with
    computation.

    Args:
        network (Cell): The target network to wrap.
        interleave_num (int, optional): Split num of batch size. Default: ``2`` .

    Inputs:
        tuple[Tensor]. It's the same with the input of the `network` .

    Outputs:
        The accumulated output of the wrapped `network` over all slices. The output of the
        input `network` should be a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.MicroBatchInterleaved(net, 2)
    """
    def __init__(self, network, interleave_num=2):
        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
        if not isinstance(interleave_num, int):
            raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be an integer, "
                            "but got the type : {}.".format(type(interleave_num)))
        if interleave_num <= 0:
            raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be larger than 0, "
                             "but got {}.".format(interleave_num))
        self.network = network
        self.interleave_num = interleave_num
        self.interleave_inputs = nn.CellList()
        self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True)
        for _ in range(interleave_num):
            interleave_data = _MicroBatch(interleave_num)
            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
            interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
            self.interleave_inputs.append(interleave_data)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        output = 0.0
        for i in range(self.interleave_num):
            interleave_input = self.interleave_inputs[i](i, *inputs)
            output = self.add(output, self.network(*interleave_input))
        return output
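
# Usage sketch (illustrative; the loss network below is a placeholder): MicroBatchInterleaved
# is usually wrapped around a cell that already returns a scalar loss, so the accumulated
# output over the interleave_num slices is still a single loss Tensor, e.g.
#
#     import mindspore.nn as nn
#     net_with_loss = nn.WithLossCell(LeNet5(), nn.SoftmaxCrossEntropyWithLogits())
#     net_with_loss = nn.MicroBatchInterleaved(net_with_loss, 2)
#
# While one slice runs its forward computation, the communication operators of the other slice
# can execute concurrently, which is the acceleration described in the class docstring.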

class GradAccumulation(Cell):
    """
    Wrap the network with MicroBatch to enable gradient accumulation.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.parallel.nn import GradAccumulation
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = GradAccumulation(net, 4)
    """
    def __init__(self, network, micro_size):
        super(GradAccumulation, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'GradAccumulation', the argument 'network' must be of Cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'GradAccumulation', the argument 'micro_size' must be an integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'GradAccumulation', the argument 'micro_size' must be larger than 0, "
                             "but got {}.".format(micro_size))
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("forward_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret
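
# End-to-end usage sketch for GradAccumulation (illustrative; the network, loss, optimizer and
# learning-rate settings below are placeholders, and the parallel context is assumed to be
# configured elsewhere):
#
#     from mindspore import nn
#     from mindspore.parallel.nn import GradAccumulation
#
#     net_with_loss = nn.WithLossCell(LeNet5(), nn.SoftmaxCrossEntropyWithLogits())
#     net_with_loss = GradAccumulation(net_with_loss, micro_size=4)
#     optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.01, momentum=0.9)
#     train_step = nn.TrainOneStepCell(net_with_loss, optimizer)
#
# Each train_step call then accumulates gradients over micro_size micro-batches before a single
# optimizer update, mirroring the forward-side accumulation in construct above.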