mindspore.nn.ParameterUpdate
- class mindspore.nn.ParameterUpdate(param)
Cell that updates a parameter.
With this Cell, one can manually update param with the input Tensor.
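The update-and-return behavior described above can be illustrated without MindSpore. The following is a conceptual pure-Python sketch, not MindSpore's implementation: `Param` and `ManualUpdate` are hypothetical stand-ins for Parameter and ParameterUpdate, and a nested list stands in for a Tensor.

```python
from typing import List

Matrix = List[List[float]]


class Param:
    """Hypothetical stand-in for a named, mutable parameter (not mindspore.Parameter)."""

    def __init__(self, name: str, data: Matrix):
        self.name = name
        self.data = data


class ManualUpdate:
    """Hypothetical analogue of a cell that overwrites one specific parameter."""

    def __init__(self, param: Param):
        # Keep a reference to the parameter (not a copy), so the update is
        # visible to every other object that holds the same Param.
        self._param = param

    def __call__(self, x: Matrix) -> Matrix:
        # Overwrite the parameter's value with a copy of the input...
        self._param.data = [row[:] for row in x]
        # ...and pass the input through unchanged, mirroring "Outputs: the input x".
        return x


weight = Param("weight", [[0.0] * 3 for _ in range(4)])
update = ManualUpdate(weight)
new_value = [[float(3 * i + j) for j in range(3)] for i in range(4)]
out = update(new_value)
print(weight.data[3])    # [9.0, 10.0, 11.0]
print(out is new_value)  # True
```

Holding a reference rather than a copy is the design point: whoever else owns the same parameter (in MindSpore, the network that registered it) sees the new value after the call.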
- Parameters
param (Parameter) – The parameter to be updated manually.
- Inputs:
x (Tensor) - A tensor whose shape and type are the same as those of param.
- Outputs:
Tensor, the input x.
- Raises
KeyError – If the parameter with the specified name does not exist.
- Supported Platforms:
Ascend
GPU
CPU
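The Inputs entry requires x to match param in shape and type. The documentation does not say what error a mismatch produces, so the following is only a hypothetical pre-check a caller might run before invoking the cell; `validate_update` and its choice of ValueError/TypeError are assumptions, not part of the MindSpore API.

```python
from typing import Tuple


def validate_update(param_shape: Tuple[int, ...], param_dtype: str,
                    x_shape: Tuple[int, ...], x_dtype: str) -> None:
    """Hypothetical check: reject an update whose shape or dtype differs from param."""
    if tuple(x_shape) != tuple(param_shape):
        raise ValueError(f"shape mismatch: expected {param_shape}, got {x_shape}")
    if x_dtype != param_dtype:
        raise TypeError(f"dtype mismatch: expected {param_dtype}, got {x_dtype}")


# nn.Dense(3, 4) stores its weight as (out_channels, in_channels) = (4, 3),
# which is why the example reshapes np.arange(12) to (4, 3).
validate_update((4, 3), "float32", (4, 3), "float32")  # passes silently
try:
    validate_update((4, 3), "float32", (3, 4), "float32")
except ValueError as err:
    print(err)  # shape mismatch: expected (4, 3), got (3, 4)
```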
- Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> network = nn.Dense(3, 4)
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> output = update(weight)