mindspore.parameter_broadcast
- mindspore.parameter_broadcast(net, layout, cur_rank=0, initial_rank=0)[source]
Broadcast parameters to other cards along the data-parallel dimension.
Warning
This is an experimental API that may be changed or removed in the future.
- Parameters:
net (Cell) - The network whose parameters will be broadcast.
layout (Dict) - Parameter layout dictionary, obtained from
mindspore.nn.Cell.parameter_layout_dict()
or read from a file (e.g. the "strategy.ckpt" file saved via the strategy_ckpt_config argument of the mindspore.set_auto_parallel_context() interface). The key is the parameter name and the value is that parameter's layout.
cur_rank (int, optional) - Rank id of the current card. Default: 0.
initial_rank (int, optional) - Starting rank id of the current pipeline-parallel stage. Default: 0.
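To make the relationship between cur_rank and initial_rank concrete: with pipeline parallelism, initial_rank is the first rank id of the stage that cur_rank belongs to. The helper below is a minimal illustration, not part of the MindSpore API; it assumes devices are assigned to stages in contiguous, equally sized blocks.

```python
def stage_initial_rank(cur_rank: int, device_num: int, pipeline_stages: int) -> int:
    """Hypothetical helper: first rank id of the pipeline stage containing cur_rank.

    Assumes devices are split evenly and contiguously across stages, e.g.
    8 devices / 2 stages -> stage 0 holds ranks 0-3, stage 1 holds ranks 4-7.
    """
    devices_per_stage = device_num // pipeline_stages
    return (cur_rank // devices_per_stage) * devices_per_stage
```

For example, with 8 devices and 2 pipeline stages, rank 5 sits in the stage that starts at rank 4, so initial_rank would be 4.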
- Raises:
ValueError - cur_rank is not the rank id of the current card.
ValueError - initial_rank is not the starting rank id of the current pipeline stage.
ValueError - A parameter name in layout cannot be found in
mindspore.nn.Cell.parameters_dict()
.
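The last check amounts to a membership test between the keys of layout and the network's parameters_dict(). A minimal sketch of that validation logic (the function name and error message are illustrative, not MindSpore's actual implementation):

```python
def validate_layout_keys(layout: dict, parameters_dict: dict) -> None:
    """Illustrative check: every layout key must name a parameter in the network."""
    missing = [name for name in layout if name not in parameters_dict]
    if missing:
        raise ValueError(f"Parameters in layout not found in the network: {missing}")
```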
Examples:
>>> import os
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>> from mindspore.train import Model, Callback
>>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_context(max_device_memory="28GB")
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
...         self.matmul1 = ops.MatMul()
...         self.relu1 = ops.ReLU()
...         self.matmul2 = ops.MatMul()
...         self.relu2 = ops.ReLU()
...         self.matmul3 = ops.MatMul()
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.matmul1(x, self.fc1_weight)
...         x = self.relu1(x)
...         x = self.matmul2(x, self.fc2_weight)
...         x = self.relu2(x)
...         logits = self.matmul3(x, self.fc3_weight)
...         return logits
>>> net = Network()
>>> net.matmul1.shard(((2, 4), (4, 1)))
>>> net.relu1.shard(((4, 1),))
>>> net.matmul2.shard(((1, 8), (8, 1)))
>>> net.relu2.shard(((8, 1),))
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
>>> loss = nn.CrossEntropyLoss()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset)
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
>>> layout = model.train_network.parameter_layout_dict
>>> param_dict = load_checkpoint("./simple.ckpt")
>>> load_param_into_net(net, param_dict)
>>> rank_id = os.environ["RANK_ID"]
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
>>> class LossCallBack(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
>>> model.train(1, dataset, callbacks=[LossCallBack()])