mindspore.sync_pipeline_shared_parameters
- mindspore.sync_pipeline_shared_parameters(net)[源代码]
在流水线并行场景下,部分参数可能会被不同的stage共享。例如 embedding table 被 VocabEmbedding 和 LMHead 两层共享,而这两层通常会被切分到不同的stage上。因此,在流水线并行推理时,当 embedding table 发生变更后,有必要在stage之间进行权重同步。
说明
网络需要先编译,再执行stage之间权重同步。
- 参数:
net (nn.Cell) - 推理网络。
- 支持平台:
Ascend
样例:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication.management as D
>>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Embedding(nn.Cell):
...     def __init__(self, shape):
...         super().__init__()
...         self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w')
...         self.matmul = ops.MatMul().shard(((1, 1), (1, 1)))
...     def construct(self, x):
...         return self.matmul(x, self.w), self.w
...
>>> class LMHead(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1)))
...     def construct(self, x, w):
...         return self.matmul(x, w)
...
>>> class Network(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super().__init__()
...         shape = (4, 4)
...         self.word_embedding = Embedding(shape)
...         self.lm_head = LMHead()
...         self.word_embedding.pipeline_stage = 0
...         self.lm_head.pipeline_stage = 1
...     def construct(self, x):
...         x, embed = self.word_embedding(x)
...         return self.lm_head(x, embed)
...
>>> class PipelineCellInference(nn.Cell):
...     def __init__(self, network, micro_batch_num):
...         super().__init__()
...         self.network = network
...         self.micro_batch_num = micro_batch_num
...         self.concat = ops.Concat()
...     def construct(self, x):
...         ret = ()
...         for i in range(self.micro_batch_num):
...             micro_batch_size = x.shape[0] // self.micro_batch_num
...             start = micro_batch_size * i
...             end = micro_batch_size * (i + 1)
...             micro_input = x[start:end]
...             y = self.network(micro_input)
...             ret = ret + (y,)
...         ret = self.concat(ret)
...         return ret
...
>>> D.init()
>>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
>>> net = Network()
>>> net = PipelineCellInference(net, 2)
>>> net.set_train(False)
>>> x = Tensor(np.ones((2, 4)), ms.float32)
>>> net.compile(x)
>>> ms.sync_pipeline_shared_parameters(net)
>>> print(net.network.word_embedding.w.asnumpy())
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]