mindspore.amp.build_train_network

mindspore.amp.build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs)

Build the mixed precision training cell automatically.

Note

  • After using custom_mixed_precision or auto_mixed_precision for precision conversion, performing the conversion again is not supported. If build_train_network is used to train a network that has already been converted, level needs to be set to 'O0' to avoid duplicated precision conversion, as shown in the sketch below.
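
For example, a minimal sketch (assuming the LeNet5 network from the Examples section below) of training a network that has already been converted:

>>> from mindspore import amp, nn
>>> # The network was already converted by auto_mixed_precision, so pass
>>> # level='O0' here to avoid a second precision conversion.
>>> network = amp.auto_mixed_precision(LeNet5(), 'O2')
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> net = amp.build_train_network(network, net_opt, net_loss, level='O0')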

Parameters
  • network (Cell) – Definition of the network.

  • optimizer (mindspore.nn.Optimizer) – The optimizer used to update the parameters.

  • loss_fn (Union[None, Cell]) – The loss function. If None, the network should include the loss computation itself. Default: None.

  • level (str) –

    Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: 'O0'.

    For details on amp level, refer to mindspore.amp.auto_mixed_precision().

    The keep_batchnorm_fp32, cast_model_type and loss_scale_manager properties determined by the level setting may be overwritten by the corresponding settings in kwargs.

  • boost_level (str) –

    Option for the level argument of mindspore.boost, the level for boost mode training. Supports ['O0', 'O1', 'O2']. Default: 'O0'.

    • 'O0': Keep the network unchanged.

    • 'O1': Enable boost mode; performance improves by about 20% while accuracy remains the same as the original.

    • 'O2': Enable boost mode; performance improves by about 30% while accuracy drops by less than 3%.

    If 'O1' or 'O2' is set, the boost-related libraries take effect automatically.

  • cast_model_type (mindspore.dtype) – Supports mstype.float16 or mstype.float32. If set, the network is cast to cast_model_type (mstype.float16 or mstype.float32) rather than to the type determined by the level setting (see the sketch after this list).

  • keep_batchnorm_fp32 (bool) – Keep BatchNorm running in float32 when the network is cast to float16. If set, the level setting has no effect on this property.

  • loss_scale_manager (Union[None, LossScaleManager]) – If not None, it must be a subclass of mindspore.amp.LossScaleManager, used to scale the loss. If set, the level setting has no effect on this property.
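
As a sketch of how kwargs override the properties determined by level (reusing network, net_opt and net_loss from the Examples section below):

>>> from mindspore import amp
>>> from mindspore import dtype as mstype
>>> # Cast the network to float16 but keep BatchNorm in float32,
>>> # overriding whatever the level setting would otherwise choose.
>>> net = amp.build_train_network(network, net_opt, net_loss, level='O0',
...                               cast_model_type=mstype.float16,
...                               keep_batchnorm_fp32=True)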

Raises

ValueError – If the device is CPU and loss_scale_manager is neither None nor mindspore.amp.FixedLossScaleManager (with drop_overflow_update=False).
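
On CPU, a valid configuration therefore either leaves loss_scale_manager as None or uses a fixed manager that keeps parameter updates on overflow. A sketch (reusing network, net_opt and net_loss from the Examples section below):

>>> from mindspore import amp
>>> # drop_overflow_update=False is required on CPU when a manager is given.
>>> manager = amp.FixedLossScaleManager(loss_scale=1.0, drop_overflow_update=False)
>>> net = amp.build_train_network(network, net_opt, net_loss, level='O0',
...                               loss_scale_manager=manager)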

Examples

>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> amp_level = "O3"
>>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
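
The returned cell executes one training step per call. A minimal sketch of running it on dummy data (batch size and shapes assumed for LeNet5 with 10 classes):

>>> import numpy as np
>>> import mindspore as ms
>>> inputs = ms.Tensor(np.ones([32, 1, 32, 32]), ms.float32)
>>> # One-hot float labels, since SoftmaxCrossEntropyWithLogits defaults to sparse=False.
>>> label = ms.Tensor(np.zeros([32, 10]), ms.float32)
>>> output = net(inputs, label)  # one forward/backward pass with parameter update
>>> # When a loss-scale cell is built, the output may be a tuple of
>>> # (loss, overflow flag, loss scale) rather than a single loss value.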