mindspore.nn.ParameterUpdate

class mindspore.nn.ParameterUpdate(param)

Cell that updates a parameter.

With this Cell, one can manually update param by assigning the input Tensor to it.

Parameters

param (Parameter) – The parameter to be updated manually.

Inputs:
  • x (Tensor) - A tensor whose shape and type are the same as those of param.

Outputs:

Tensor, the input x.

Raises

KeyError – If the parameter with the specified name does not exist.

Supported Platforms:

Ascend GPU CPU

Examples

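>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> # Build a small layer and fetch its weight, whose shape is (4, 3).
>>> network = nn.Dense(3, 4)
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> # The new value must match the shape and type of param.
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> output = update(weight)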
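
The contract described above (the input is assigned to param, and the same input is returned) can be modelled in plain Python. This is only a conceptual sketch and not MindSpore's implementation: ToyParameter, ToyTensor and ToyParameterUpdate are hypothetical stand-ins, and raising ValueError on a shape or type mismatch is an assumption made for illustration, since the reference above does not say how a mismatch is reported.

```python
# Conceptual model only: plain-Python stand-ins, not MindSpore code.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a Tensor: shape, dtype and flat data."""
    shape: Tuple[int, ...]
    dtype: str
    data: List[float]


@dataclass
class ToyParameter:
    """Hypothetical stand-in for a Parameter: a named, mutable buffer."""
    name: str
    shape: Tuple[int, ...]
    dtype: str
    data: List[float]


class ToyParameterUpdate:
    """Hypothetical stand-in for the Cell documented above."""

    def __init__(self, param: ToyParameter) -> None:
        self._param = param

    def __call__(self, x: ToyTensor) -> ToyTensor:
        # The reference requires x to match param in shape and type.
        # Raising ValueError here is an assumption of this sketch.
        if x.shape != self._param.shape or x.dtype != self._param.dtype:
            raise ValueError("x must have the same shape and dtype as param")
        # Overwrite the parameter's contents, then return the input itself.
        self._param.data = list(x.data)
        return x


weight = ToyParameter("weight", (4, 3), "float32", [0.0] * 12)
update = ToyParameterUpdate(weight)
new_value = ToyTensor((4, 3), "float32", [float(i) for i in range(12)])
output = update(new_value)
```

After the call, weight.data holds the twelve new values and output is the same object that was passed in, which mirrors the "Tensor, the input x" entry under Outputs.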