sciai.common.TrainStepCell

class sciai.common.TrainStepCell(network, optimizer, grad_first=False, clip_grad=False, clip_norm=0.001)

A Cell that performs one gradient-descent training step. It is similar to nn.TrainOneStepCell, but it also accepts networks that return multiple losses.
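The multi-loss behaviour can be illustrated with a plain-Python sketch. This is a hypothetical illustration, not the SciAI implementation: the toy model, the helper names (`toy_losses`, `toy_grads`, `toy_train_step`) and the use of plain SGD are assumptions. It shows how one step can return every loss while descending on a single scalar: the first loss when `grad_first=True`, otherwise the sum of all losses.

```python
from typing import Tuple

# Hypothetical toy "network": one scalar weight w and two losses.
def toy_losses(w: float, x: float) -> Tuple[float, float]:
    """Return two losses, mimicking a network with multiple outputs."""
    loss_a = (w * x - 1.0) ** 2  # data-fit term
    loss_b = 0.1 * w ** 2        # regularisation term
    return loss_a, loss_b

def toy_grads(w: float, x: float) -> Tuple[float, float]:
    """Analytic d(loss)/dw for each loss returned by toy_losses."""
    grad_a = 2.0 * (w * x - 1.0) * x
    grad_b = 0.2 * w
    return grad_a, grad_b

def toy_train_step(w: float, x: float, lr: float = 0.1,
                   grad_first: bool = False) -> Tuple[float, Tuple[float, float]]:
    """One SGD step (assumed optimizer): update w, but report every loss."""
    losses = toy_losses(w, x)
    grads = toy_grads(w, x)
    # grad_first=True: only the first output drives the update;
    # otherwise the gradient of the summed losses is used.
    g = grads[0] if grad_first else sum(grads)
    new_w = w - lr * g
    return new_w, losses
```

With `w=1.0, x=2.0`, the step returns both losses `(1.0, 0.1)`. With `grad_first=True`, the weight moves to about `0.6`; otherwise it moves to about `0.58`, because the regulariser's gradient is included.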

Parameters
  • network (Cell) – The training network. It may return multiple outputs (losses).

  • optimizer (Union[Cell]) – Optimizer for updating the network parameters.

  • grad_first (bool) – If True, only the first output of the network participates in gradient descent. Otherwise, the sum of all outputs of the network is used. Default: False.

  • clip_grad (bool) – Whether to clip gradients. Default: False.

  • clip_norm (Union[float, int]) – The clipping ratio. It must be greater than 0. It takes effect only when clip_grad is True. Default: 1e-3.
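For intuition about `clip_grad` and `clip_norm`, the sketch below clips a list of gradients so that their global L2 norm stays within `clip_norm`. This page does not say which clipping scheme `TrainStepCell` applies (global norm, per-tensor norm, or by value). Global-norm clipping is an assumption here, as are the helper name `clip_by_global_norm_py` and the use of plain Python floats in place of tensors.

```python
import math
from typing import List

def clip_by_global_norm_py(grads: List[float], clip_norm: float) -> List[float]:
    """Hypothetical helper: scale grads so their global L2 norm is <= clip_norm."""
    if clip_norm <= 0:
        # Mirrors the documented requirement that clip_norm be greater than 0.
        raise ValueError("clip_norm must be greater than 0")
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm:
        return list(grads)  # already within the bound: left unchanged
    scale = clip_norm / global_norm
    # Every gradient is scaled by the same factor, preserving the direction.
    return [g * scale for g in grads]
```

For example, gradients `[3.0, 4.0]` have global norm 5. With `clip_norm=1.0` they are scaled to roughly `[0.6, 0.8]`. With the default `1e-3`, the same gradients would shrink to a norm of `1e-3`, which is a very strong restriction.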

Inputs:
  • *inputs (tuple[Tensor]) - Input tensors passed to network, each with shape \((N, \ldots)\).

Outputs:

Union(Tensor, tuple[Tensor]), the loss value(s) returned by network. Each is usually a scalar tensor with shape \(()\).

Raises

TypeError – If network or optimizer is not of the correct type.

Supported Platforms:

Ascend GPU CPU