# Source code for mindspore.parallel.parameter_broadcast

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter broadcast"""
from __future__ import absolute_import

__all__ = ["parameter_broadcast"]

import numpy as np
import mindspore as ms
from mindspore.communication import get_rank, create_group, get_group_size


def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
    """
    Broadcast parameter to other rank in data parallel dimension.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        net (Cell): The network where the parameters will be broadcasted.
        layout (Dict): Parameter layout dictionary. Come from
            :func:`mindspore.nn.Cell.parameter_layout_dict`
            or read from file(for example: "strategy.ckpt" saved by using the
            `strategy_ckpt_config` parameter of
            :func:`mindspore.set_auto_parallel_context`).
            The key is param name, the value is the layout of this parameter.
        cur_rank (int, optional): current rank id. Default: ``0``.
        initial_rank (int, optional): Start rank id for each pipeline. Default: ``0``.

    Returns:
        None. The function returns early without touching `net` when `layout`
        is empty, when the current parallel mode is neither
        ``"semi_auto_parallel"`` nor ``"auto_parallel"``, or when no redundant
        parameter group involves `cur_rank`.

    Raises:
        ValueError: `cur_rank` is not rank id of current rank.
        ValueError: `initial_rank` is not the start rank id of current pipeline stage.
        ValueError: Parameter name in `layout` can not be found in
            :func:`mindspore.nn.Cell.parameters_dict`.

    Note:
        While broadcasting, the global parallel mode is temporarily switched to
        ``"hybrid_parallel"``. The original mode is restored when the function
        leaves the broadcast loop, including when an exception is raised.

    Examples:
        >>> import os
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops
        >>> from mindspore.communication import init
        >>> from mindspore.common.initializer import initializer
        >>> from mindspore.train import Model, Callback
        >>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
        >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> ms.set_context(max_device_memory="28GB")
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
        >>> init()
        >>> ms.set_seed(1)
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = ops.Flatten()
        ...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
        ...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
        ...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
        ...         self.matmul1 = ops.MatMul()
        ...         self.relu1 = ops.ReLU()
        ...         self.matmul2 = ops.MatMul()
        ...         self.relu2 = ops.ReLU()
        ...         self.matmul3 = ops.MatMul()
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.matmul1(x, self.fc1_weight)
        ...         x = self.relu1(x)
        ...         x = self.matmul2(x, self.fc2_weight)
        ...         x = self.relu2(x)
        ...         logits = self.matmul3(x, self.fc3_weight)
        ...         return logits
        >>> net = Network()
        >>> net.matmul1.shard(((2, 4), (4, 1)))
        >>> net.relu1.shard(((4, 1),))
        >>> net.matmul2.shard(((1, 8), (8, 1)))
        >>> net.relu2.shard(((8, 1),))
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.4.10/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> optim = nn.SGD(net.trainable_params(), 1e-2)
        >>> loss = nn.CrossEntropyLoss()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> model.train(1, dataset)
        >>> ms.save_checkpoint(net, "./simple.ckpt", False)
        >>> layout = model.train_network.parameter_layout_dict
        >>> param_dict = load_checkpoint("./simple.ckpt")
        >>> load_param_into_net(net, param_dict)
        >>> rank_id = os.environ["RANK_ID"]
        >>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
        >>> class LossCallBack(Callback):
        ...     def step_end(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
        >>> model.train(1, dataset, callbacks=[LossCallBack()])
    """
    # Nothing to broadcast without a layout; checked before the lazy imports
    # below so this path has no dependency on them.
    if not layout:
        return

    # Imported lazily inside the function, as in the original code.
    from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
    from mindspore.nn.wrap.cell_wrapper import AllreduceGraph

    origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
    if origin_parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
        return

    if cur_rank != get_rank():
        raise ValueError(f"For parameter broadcast, the cur_rank: {cur_rank} is wrong.")
    # initial_rank must be a multiple of the number of ranks per pipeline stage.
    if initial_rank % (get_group_size() / ms.get_auto_parallel_context("pipeline_stages")) != 0:
        raise ValueError(f"For parameter broadcast, the initial_rank: {initial_rank} is wrong.")

    param_redundancy = get_parameter_redundancy(layout, initial_rank)
    if not param_redundancy:
        return
    single_params = remove_param_redundancy(param_redundancy)
    if not single_params:
        return

    # Invert the mapping: rank group -> names of the parameters replicated on
    # that group, keeping only multi-rank groups that contain this rank.
    param_redundancy_reversed = {}
    for key, redundancy in param_redundancy.items():
        for item in redundancy:
            if len(item) == 1:
                continue
            if cur_rank in item:
                param_redundancy_reversed.setdefault(item, []).append(key)
    if not param_redundancy_reversed:
        return
    if cur_rank not in single_params:
        return

    net_param_dict = net.parameters_dict()
    # Parameter names kept on this rank after redundancy removal; loop-invariant.
    owned_params = single_params[cur_rank]

    ms.set_auto_parallel_context(parallel_mode="hybrid_parallel")
    try:
        for group, params in param_redundancy_reversed.items():
            create_group(str(group), list(group))
            allreduce_input = []
            for param in params:
                if param not in net_param_dict:
                    raise ValueError(f"For parameter broadcast, the param: {param} can not be found.")
                real_param = net_param_dict[param]
                if param not in owned_params:
                    # Zero-fill copies this rank does not own -- presumably so the
                    # all-reduce leaves every rank in the group with the owner's value.
                    real_param.set_data(ms.Tensor(np.zeros(real_param.shape), dtype=real_param.dtype))
                allreduce_input.append(real_param)
            if not allreduce_input:
                continue
            allreduce_graph = AllreduceGraph(allreduce_input, str(group))
            allreduce_graph()
    finally:
        # Always restore the caller's parallel mode, even when the loop raises;
        # otherwise the process would be left in "hybrid_parallel".
        ms.set_auto_parallel_context(parallel_mode=origin_parallel_mode)