sciai.common.TrainStepCell
- class sciai.common.TrainStepCell(network, optimizer, grad_first=False, clip_grad=False, clip_norm=0.001)
Cell that performs one gradient descent training step, similar to nn.TrainOneStepCell, but accepts a network that returns multiple losses.
- Parameters
network (Cell) – The training network. The network may return multiple outputs.
optimizer (Union[Cell]) – Optimizer for updating the network parameters.
grad_first (bool) – If True, only the first output of the network participates in gradient descent. Otherwise, the sum of all outputs of the network is used. Default: False.
clip_grad (bool) – Whether to clip the gradients. Default: False.
clip_norm (Union[float, int]) – The clipping ratio; it must be greater than 0. Only takes effect when clip_grad is True. Default: 1e-3.
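The effect of grad_first and clip_norm can be illustrated with a minimal plain-Python sketch. It is not the SciAI implementation and does not use MindSpore: the helper names select_objective and clip_by_global_norm are hypothetical, losses and gradients are modelled as plain floats, and the global-norm clipping strategy is an assumption, since the description above only calls clip_norm a "clipping ratio".

```python
import math


def select_objective(losses, grad_first=False):
    """Return the scalar that would be differentiated (illustrative only).

    `losses` stands in for the network outputs: a single number or a
    sequence of numbers.
    """
    if isinstance(losses, (int, float)):
        return float(losses)
    if grad_first:
        # Only the first output participates in gradient descent.
        return float(losses[0])
    # Otherwise the sum of all outputs is used.
    return float(sum(losses))


def clip_by_global_norm(grads, clip_norm=1e-3):
    """Rescale `grads` so their joint L2 norm is at most `clip_norm`.

    ASSUMPTION: global-norm clipping is one common strategy; the
    documentation does not state which one TrainStepCell applies.
    """
    if clip_norm <= 0:
        raise ValueError("clip_norm must be greater than 0")
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm:
        return list(grads)
    scale = clip_norm / global_norm
    return [g * scale for g in grads]


losses = (1.0, 2.0, 3.0)
summed = select_objective(losses)                     # sum of all outputs
first = select_objective(losses, grad_first=True)     # first output only
clipped = clip_by_global_norm([3.0, 4.0], clip_norm=1.0)
```

With the default clip_norm of 1e-3, this kind of clipping would shrink any gradient whose norm exceeds 0.001, so the threshold has a strong effect on the step size whenever clip_grad is enabled.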
- Inputs:
*inputs (tuple[Tensor]) - Tuple of input tensors with shape \((N, \ldots)\).
- Outputs:
Union(Tensor, tuple[Tensor]), the loss value or a tuple of loss values. Each loss usually has shape \(()\).
- Raises
TypeError – If network or optimizer is not of the correct type.
- Supported Platforms:
Ascend
GPU
CPU