mindspore.parameter_broadcast

View Source On Gitee
mindspore.parameter_broadcast(net, layout, cur_rank=0, initial_rank=0)[source]

Broadcast parameters to the other ranks in the data parallel dimension.

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • net (Cell) – The network whose parameters will be broadcast.

  • layout (Dict) – Parameter layout dictionary. It is obtained from mindspore.nn.Cell.parameter_layout_dict() or read from a file (for example, "strategy.ckpt" saved by using the strategy_ckpt_config parameter of mindspore.set_auto_parallel_context()). The key is the parameter name, and the value is the layout of that parameter.

  • cur_rank (int, optional) – Current rank id. Default: 0.

  • initial_rank (int, optional) – Start rank id for each pipeline. Default: 0.

Raises

Examples

>>> import os
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>> from mindspore.train import Model, Callback
>>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_context(max_device_memory="28GB")
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
...         self.matmul1 = ops.MatMul()
...         self.relu1 = ops.ReLU()
...         self.matmul2 = ops.MatMul()
...         self.relu2 = ops.ReLU()
...         self.matmul3 = ops.MatMul()
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.matmul1(x, self.fc1_weight)
...         x = self.relu1(x)
...         x = self.matmul2(x, self.fc2_weight)
...         x = self.relu2(x)
...         logits = self.matmul3(x, self.fc3_weight)
...         return logits
>>> net = Network()
>>> net.matmul1.shard(((2, 4), (4, 1)))
>>> net.relu1.shard(((4, 1),))
>>> net.matmul2.shard(((1, 8), (8, 1)))
>>> net.relu2.shard(((8, 1),))
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
>>> loss = nn.CrossEntropyLoss()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset)
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
>>> layout = model.train_network.parameter_layout_dict
>>> param_dict = load_checkpoint("./simple.ckpt")
>>> load_param_into_net(net, param_dict)
>>> rank_id = os.environ["RANK_ID"]
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
>>> class LossCallBack(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
>>> model.train(1, dataset, callbacks=[LossCallBack()])