mindspore.merge_sliced_parameter
- mindspore.merge_sliced_parameter(sliced_parameters, strategy=None)[源代码]
将参数切片合并为一个完整的参数,用于分布式推理。关于它的细节,请参考:保存和加载模型(HyBrid Parallel模式)。
- 参数:
sliced_parameters (list[Parameter]) - 参数切片,按rank id进行排列。
strategy (Optional[dict]) - 参数切片策略,key为参数名称,value为该参数的切片策略。如果 strategy 为None,则只需按0轴顺序合并参数切片。默认值:None。
- 返回:
合并后的参数,包含所有数据。
- 异常:
ValueError - 合并失败。
TypeError - sliced_parameters 不正确或 strategy 不是dict。
KeyError - 参数名称不在策略的key中。
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import Tensor, Parameter >>> >>> sliced_parameters = [ ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), ... "network.embedding_table"), ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), ... "network.embedding_table"), ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), ... "network.embedding_table"), ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), ... "network.embedding_table")] >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters) >>> print(merged_parameter) Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)