mindspore.parameter_broadcast
- mindspore.parameter_broadcast(net, layout, cur_rank=0, initial_rank=0)
Broadcast parameters to the other ranks along the data parallel dimension.
Warning
This is an experimental API that is subject to change or deletion.
- Parameters
net (Cell) – The network whose parameters will be broadcast.
layout (Dict) – Parameter layout dictionary. It comes from mindspore.nn.Cell.parameter_layout_dict() or is read from a file (for example, "strategy.ckpt" saved by using the strategy_ckpt_config parameter of mindspore.set_auto_parallel_context()). The key is the parameter name, and the value is the layout of that parameter.
cur_rank (int, optional) – Current rank id. Default: 0.
initial_rank (int, optional) – Start rank id for each pipeline. Default: 0.
- Raises
ValueError – cur_rank is not the rank id of the current rank.
ValueError – initial_rank is not the start rank id of the current pipeline stage.
ValueError – A parameter name in layout cannot be found in mindspore.nn.Cell.parameters_dict().
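The last ValueError condition can be checked before calling parameter_broadcast. The following is a minimal sketch, not part of the MindSpore API: check_layout_names is a hypothetical helper that compares the keys of a layout dictionary with a collection of parameter names and raises ValueError if any key is missing. The error message text is an assumption, and the parameter names are illustrative.

```python
def check_layout_names(layout, param_names):
    """Hypothetical pre-check (not a MindSpore function).

    Raises ValueError if a key of `layout` is not among `param_names`,
    mirroring the documented failure condition of parameter_broadcast.
    """
    missing = sorted(name for name in layout if name not in param_names)
    if missing:
        # The wording of this message is an assumption for illustration.
        raise ValueError(f"Parameters in layout not found in the network: {missing}")


# Illustrative names only; layout values are placeholders, not real layouts.
known_names = {"fc1_weight", "fc2_weight", "fc3_weight"}
check_layout_names({"fc1_weight": None, "fc2_weight": None}, known_names)  # passes silently

try:
    check_layout_names({"fc4_weight": None}, known_names)
except ValueError as err:
    print("layout check failed:", err)
```

With a real network, param_names would presumably be the keys of net.parameters_dict() for the same Cell that is passed as net.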
Examples
>>> import os
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>> from mindspore.train import Model, Callback
>>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_context(max_device_memory="28GB")
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
...         self.matmul1 = ops.MatMul()
...         self.relu1 = ops.ReLU()
...         self.matmul2 = ops.MatMul()
...         self.relu2 = ops.ReLU()
...         self.matmul3 = ops.MatMul()
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.matmul1(x, self.fc1_weight)
...         x = self.relu1(x)
...         x = self.matmul2(x, self.fc2_weight)
...         x = self.relu2(x)
...         logits = self.matmul3(x, self.fc3_weight)
...         return logits
>>> net = Network()
>>> # Set the sharding strategy of each operator.
>>> net.matmul1.shard(((2, 4), (4, 1)))
>>> net.relu1.shard(((4, 1),))
>>> net.matmul2.shard(((1, 8), (8, 1)))
>>> net.relu2.shard(((8, 1),))
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
>>> loss = nn.CrossEntropyLoss()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset)
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
>>> # Take the parameter layout from the compiled training network.
>>> layout = model.train_network.parameter_layout_dict
>>> param_dict = load_checkpoint("./simple.ckpt")
>>> load_param_into_net(net, param_dict)
>>> rank_id = os.environ["RANK_ID"]
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
>>> class LossCallBack(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
>>> model.train(1, dataset, callbacks=[LossCallBack()])