mindspore.nn.ParameterUpdate
- class mindspore.nn.ParameterUpdate(param)[源代码]
更新参数的Cell。
使用输入的 Tensor 值更新 param 的值。
参数:
param (Parameter) - 输入的参数。
输入:
x (Tensor)- shape和type与 param 相同的Tensor。
输出:
Tensor,输入 x。
异常:
KeyError - 指定名称的参数不存在。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> import mindspore >>> from mindspore import nn, Tensor >>> network = nn.Dense(3, 4) >>> param = network.parameters_dict()['weight'] >>> update = nn.ParameterUpdate(param) >>> update.phase = "update_param" >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32) >>> output = update(weight) >>> print(output) [[ 0. 1. 2.] [ 3. 4. 5.] [ 6. 7. 8.] [ 9. 10. 11.]]