mindspore.nn.WithGradCell
- class mindspore.nn.WithGradCell(network, loss_fn=None, sens=None)[source]
Cell that returns the gradients.
Wraps the network with a backward cell to compute gradients. A network together with a loss function is required. If loss_fn is None, the network must already combine the network and the loss function (for example, a nn.WithLossCell). This Cell accepts *inputs as inputs and returns gradients for each trainable parameter.
Note
Run in PyNative mode.
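Since this Cell runs in PyNative mode, the execution mode should be set before use. A minimal sketch, assuming the standard context API:

>>> from mindspore import context
>>> # Switch to PyNative (eager) execution, which WithGradCell requires.
>>> context.set_context(mode=context.PYNATIVE_MODE)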
- Parameters
network (Cell) – The target network to wrap. The network must have a single output.
loss_fn (Cell) – The loss function used to compute gradients. Default: None.
sens (Union[None, Tensor, Scalar, Tuple]) – The sensitivity (scale factor) applied during backpropagation; its type and shape must match the network output. If None, ones with the same type and shape as the network output are used (see the sketch after this parameter list). Default: None.
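As a sketch of the sens argument (net_with_criterion, data, and label are hypothetical placeholders, and the wrapped network is assumed to output a scalar loss), a scalar sensitivity of 2.0 scales every returned gradient by 2.0:

>>> # Hypothetical: scale all gradients by 2.0 via a scalar sens.
>>> grad_net = nn.WithGradCell(net_with_criterion, sens=2.0)
>>> grads = grad_net(data, label)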
- Inputs:
(*inputs) (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).
- Outputs:
list, a list of Tensors whose shapes match those of the trainable parameters.
- Raises
TypeError – If sens is not one of None, Tensor, Scalar or Tuple.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> # For a defined network Net without loss function
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> grad_net = nn.WithGradCell(net, loss_fn)
>>>
>>> # For a network wrapped with loss function
>>> net = Net()
>>> net_with_criterion = nn.WithLossCell(net, loss_fn)
>>> grad_net = nn.WithGradCell(net_with_criterion)
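Beyond the snippets above, a self-contained sketch may help; the tiny Net, the layer sizes, and the data below are illustrative assumptions, not part of the API:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, context
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # Assumption: a minimal single-output network for illustration.
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.dense = nn.Dense(4, 3)
...     def construct(self, x):
...         return self.dense(x)
...
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
>>> grad_net = nn.WithGradCell(net, loss_fn)
>>> data = Tensor(np.random.rand(2, 4).astype(np.float32))
>>> label = Tensor(np.eye(3, dtype=np.float32)[[0, 1]])  # one-hot labels
>>> grads = grad_net(data, label)  # one gradient per trainable parameter
>>> print([g.shape for g in grads])
[(3, 4), (3,)]

The gradients are returned in the order of the network's trainable parameters; here, the Dense layer's weight (3, 4) and bias (3,).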