mindspore.sync_pipeline_shared_parameters

mindspore.sync_pipeline_shared_parameters(net)

Synchronize parameters that are shared between pipeline parallel stages. A parameter may be used by layers placed on different stages. For example, the embedding table is shared by the WordEmbedding layer and the LMHead layer, which are usually split into different stages. Whenever the embedding table changes, it must be synchronized across those stages.
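Conceptually, each stage holds its own copy of a shared parameter, so a change made on one stage is not visible on another until the copies are synchronized. The NumPy-only simulation below illustrates that idea. It is a hypothetical sketch, not MindSpore's implementation: the `Stage` class and the `sync_shared` helper are invented for illustration, and the rule that the earliest stage holding the parameter is the source of truth is an assumption made for simplicity.

```python
import numpy as np

# Hypothetical simulation: each pipeline stage keeps a local copy of the
# parameters it uses. This is NOT how MindSpore implements synchronization.
class Stage:
    def __init__(self, stage_id, params):
        self.stage_id = stage_id
        self.params = params  # name -> np.ndarray (this stage's local copy)


def sync_shared(stages, name):
    """Copy the earliest holder's value of `name` to every later holder.

    Assumption for illustration only: the first stage holding the
    parameter owns the authoritative value.
    """
    holders = [s for s in stages if name in s.params]
    source = holders[0].params[name]
    for s in holders[1:]:
        s.params[name] = source.copy()


vocab_size, embedding_size = 4, 4
table = np.ones((vocab_size, embedding_size), dtype=np.float32)

# Stage 0 (word embedding) and stage 1 (LM head) each start with their own copy.
stage0 = Stage(0, {"embedding": table.copy()})
stage1 = Stage(1, {"embedding": table.copy()})

# The table changes on stage 0 only, e.g. after new weights are loaded there.
stage0.params["embedding"] *= 2.0
print("before sync, copies equal:",
      np.array_equal(stage0.params["embedding"], stage1.params["embedding"]))

sync_shared([stage0, stage1], "embedding")
print("after sync, copies equal:",
      np.array_equal(stage0.params["embedding"], stage1.params["embedding"]))
```

Without the synchronization step, the LM head on stage 1 would keep computing with the stale table.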

Note

The network should be compiled before synchronizing the parameters shared between pipeline parallel stages.

Parameters

net (nn.Cell) – the inference network.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import lazy_inline, nn, ops, Parameter, Tensor
>>> class VocabEmbedding(nn.Cell):
...     def __init__(self, vocab_size, embedding_size):
...         super().__init__()
...         self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size]), ms.float32),
...                                          name='embedding')
...         self.gather = ops.Gather()
...
...     def construct(self, x):
...         output = self.gather(self.embedding_table, x, 0)
...         output = output.squeeze(1)
...         return output, self.embedding_table.value()
...
>>> class LMHead(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul(transpose_b=True)
...
...     def construct(self, state, embed):
...         return self.matmul(state, embed)
...
>>> class Network(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super().__init__()
...         self.word_embedding = VocabEmbedding(vocab_size=4, embedding_size=4)
...         self.head = LMHead()
...
...     def construct(self, x):
...         x, embed = self.word_embedding(x)
...         x = self.head(x, embed)
...         return x
>>>
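>>> # Pipeline parallelism needs a distributed launch; this example assumes 2 devices.
>>> from mindspore.communication import init
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)
>>> net = Network()
>>> net.word_embedding.pipeline_stage = 0
>>> net.head.pipeline_stage = 1
>>> net.set_train(False)
>>> # Gather indices must be integers; shape (8, 1) lets squeeze(1) produce shape (8, 4).
>>> x = Tensor(np.ones((8, 1)), ms.int32)
>>> net.compile(x)
>>> ms.sync_pipeline_shared_parameters(net)
>>> print(net.word_embedding.embedding_table.asnumpy())
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]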