mindspore.nn.WithGradCell

class mindspore.nn.WithGradCell(network, loss_fn=None, sens=None)[source]

Cell that returns the gradients.

Wraps the network with a backward cell to compute gradients. A network together with a loss function is required as the argument. If the loss function is None, the network passed in must already wrap both the network and the loss function. This Cell accepts ‘*inputs’ as inputs and returns the gradients for each trainable parameter.
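The contract described above (forward pass, loss, then one gradient per trainable parameter) can be illustrated with a small conceptual sketch. It is plain Python rather than MindSpore code, it approximates gradients with central differences instead of a backward cell, and every name in it (`linear_net`, `squared_error`, `with_grad`) is hypothetical.

```python
# Conceptual sketch only: plain Python, not MindSpore code.
# All names here (linear_net, squared_error, with_grad) are hypothetical.

def linear_net(params, x):
    """Toy single-output network: y = w * x + b."""
    w, b = params
    return w * x + b


def squared_error(pred, label):
    """Toy loss function."""
    return (pred - label) ** 2


def with_grad(network, loss_fn, params, eps=1e-6):
    """Return a function that maps (x, label) to one gradient per parameter."""

    def grad_fn(x, label):
        grads = []
        for i in range(len(params)):
            # Central difference: perturb one parameter at a time.
            plus = list(params)
            minus = list(params)
            plus[i] += eps
            minus[i] -= eps
            loss_plus = loss_fn(network(plus, x), label)
            loss_minus = loss_fn(network(minus, x), label)
            grads.append((loss_plus - loss_minus) / (2 * eps))
        return grads

    return grad_fn


params = [2.0, 0.5]                      # w, b
grad_fn = with_grad(linear_net, squared_error, params)
grads = grad_fn(3.0, 5.0)                # x = 3, label = 5
print(grads)
```

For this toy setup the analytic gradients are 9 for `w` and 3 for `b`, so the printed list should be close to those values. The returned list has one entry per parameter, mirroring the "gradients for each trainable parameter" wording above.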

Note

Run in PyNative mode.

Parameters
  • network (Cell) – The target network to wrap. The network only supports single output.

  • loss_fn (Cell) – Primitive loss function used to compute gradients. Default: None.

  • sens (Union[None, Tensor, Scalar, Tuple ...]) – The sensitivity (initial gradient) for backpropagation. Its type and shape must be the same as the network output. If None, it is filled with ones of the same type and shape as the output value. Default: None.
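As a rough illustration of what a sensitivity value does, the sketch below treats it as a per-output weight applied during backpropagation, defaulting to ones when None is passed. It is plain Python for a toy layer y = W x, not MindSpore code, and `weight_grads` is a hypothetical helper; the exact handling inside WithGradCell is not shown in this page.

```python
# Conceptual sketch only: plain Python, not MindSpore code.
# weight_grads is a hypothetical helper for a toy layer y = W x.

def weight_grads(x, n_out, sens=None):
    """Gradient of sum_i sens[i] * y[i] with respect to each W[i][j]."""
    if sens is None:
        # Default: sensitivity of ones, with the same shape as the output.
        sens = [1.0] * n_out
    if len(sens) != n_out:
        raise ValueError("sens must have the same shape as the output")
    # d y[i] / d W[i][j] = x[j], so the weighted gradient is sens[i] * x[j].
    return [[s * xj for xj in x] for s in sens]


x = [1.0, 2.0, 3.0]
default_grads = weight_grads(x, n_out=2)
scaled_grads = weight_grads(x, n_out=2, sens=[0.5, -1.0])
```

With the default, each row of the gradient equals `x`; with an explicit `sens`, row `i` is `x` scaled by `sens[i]`. The result has the same shape as the weight matrix, which matches the statement that the outputs have the same shapes as the trainable weights.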

Inputs:
  • (*inputs) (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

list, a list of Tensors with the same shapes as the trainable weights.

Raises

TypeError – If sens is not one of None, Tensor, Scalar or Tuple.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import nn
>>>
>>> # For a defined network Net without loss function
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> grad_net = nn.WithGradCell(net, loss_fn)
>>>
>>> # For a network wrapped with loss function
>>> net = Net()
>>> net_with_criterion = nn.WithLossCell(net, loss_fn)
>>> grad_net = nn.WithGradCell(net_with_criterion)