mindspore.nn.ParameterUpdate

class mindspore.nn.ParameterUpdate(param)[源代码]

更新参数的Cell。

使用输入的 Tensor 值更新 param 的值。

参数:
  • param (Parameter) - 输入的参数。

输入:
  • x (Tensor)- shape和type与 param 相同的Tensor。

输出:

Tensor,输入 x

异常:
  • KeyError - 指定名称的参数不存在。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> network = nn.Dense(3, 4)
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> output = update(weight)
>>> print(output)
[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]
 [ 9. 10. 11.]]