mindspore.parallel.parameter_broadcast
- mindspore.parallel.parameter_broadcast(net, layout, cur_rank=0, initial_rank=0)[source]
Broadcast the parameters to the other cards along the data-parallel dimension.
Warning
This is an experimental API that may be changed or removed in the future.
- Parameters:
net (Cell) - The network whose parameters are to be broadcast.
layout (Dict) - Parameter layout dictionary, obtained from mindspore.nn.Cell.parameter_layout_dict; it can also be read from a file, for example the strategy.ckpt saved via AutoParallel.save_param_strategy_file(file_path). The keys of the dictionary are parameter names and the values are the parameters' Layouts.
cur_rank (int, optional) - The rank ID of the current device. Default: 0.
initial_rank (int, optional) - The starting rank ID of each pipeline-parallel stage (see the sketch after this list). Default: 0.
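For illustration, a minimal sketch of how the two rank arguments relate under pipeline parallelism. broadcast_ranks is a hypothetical helper, not part of this API; it assumes init() has already been called and that ranks are assigned contiguously to equal-sized stages:

>>> from mindspore.communication import get_rank, get_group_size
>>> def broadcast_ranks(pipeline_stages=1):
...     # Rank ID of this process; this is what cur_rank must be set to.
...     cur_rank = get_rank()
...     # With equal-sized contiguous stages, the starting rank ID of this
...     # process's stage is the stage boundary at or below cur_rank.
...     stage_size = get_group_size() // pipeline_stages
...     initial_rank = (cur_rank // stage_size) * stage_size
...     return cur_rank, initial_rank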
- Raises:
ValueError - cur_rank is not the rank ID of the current device.
ValueError - initial_rank is not the starting rank ID of the current pipeline-parallel stage.
ValueError - A parameter name in layout is not found in mindspore.nn.Cell.parameters_dict() (a pre-check sketch follows this list).
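To surface the last ValueError above with a clearer message before broadcasting, a defensive pre-check can be run first. check_layout_names is a hypothetical helper sketch, not part of this API:

>>> def check_layout_names(net, layout):
...     # Every key of `layout` must name a parameter of `net`.
...     params = net.parameters_dict()
...     missing = [name for name in layout if name not in params]
...     if missing:
...         raise ValueError(f"layout contains names not in the network: {missing}")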
- Supported Platforms:
Ascend
Examples:
>>> import os
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>> from mindspore.train import Model, Callback
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> from mindspore.parallel import parameter_broadcast
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.runtime.set_memory(max_size="28GB")
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
...         self.matmul1 = ops.MatMul()
...         self.relu1 = ops.ReLU()
...         self.matmul2 = ops.MatMul()
...         self.relu2 = ops.ReLU()
...         self.matmul3 = ops.MatMul()
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.matmul1(x, self.fc1_weight)
...         x = self.relu1(x)
...         x = self.matmul2(x, self.fc2_weight)
...         x = self.relu2(x)
...         logits = self.matmul3(x, self.fc3_weight)
...         return logits
>>> net = Network()
>>> net.matmul1.shard(((2, 4), (4, 1)))
>>> net.relu1.shard(((4, 1),))
>>> net.matmul2.shard(((1, 8), (8, 1)))
>>> net.relu2.shard(((8, 1),))
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
>>> loss = nn.CrossEntropyLoss()
>>> parallel_net = AutoParallel(net)
>>> model = Model(parallel_net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset)
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
>>> layout = model.train_network.parameter_layout_dict
>>> param_dict = load_checkpoint("./simple.ckpt")
>>> load_param_into_net(net, param_dict)
>>> rank_id = os.environ["RANK_ID"]
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
>>> class LossCallBack(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
>>> model.train(1, dataset, callbacks=[LossCallBack()])
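As noted in the parameter description, layout can also be read back from a saved strategy file instead of being taken from parameter_layout_dict. A minimal sketch of that variant, assuming the strategy was saved earlier via AutoParallel.save_param_strategy_file("./strategy.ckpt") and that mindspore.build_searched_strategy in your MindSpore version can parse the resulting file:

>>> # Hypothetical alternative: read the layout back from a saved strategy file.
>>> layout = ms.build_searched_strategy("./strategy.ckpt")
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)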