mindspore.amp.build_train_network
- mindspore.amp.build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs)
Build the mixed precision training cell automatically.
Note
After custom_mixed_precision or auto_mixed_precision has been used for precision conversion, performing the precision conversion again is not supported. If build_train_network is used to train a converted network, level needs to be set to 'O0' to avoid a duplicated precision conversion.
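The rule in the note can be expressed as a small check. This is a minimal, hypothetical guard: the function name check_level_after_conversion is illustrative, is not part of mindspore.amp, and MindSpore itself may not perform such a check. It only encodes the documented advice that a network already converted by custom_mixed_precision or auto_mixed_precision should be passed with level='O0'.

```python
# Hypothetical guard, not MindSpore code: encodes the note's rule that a
# network already converted by custom_mixed_precision or auto_mixed_precision
# must not be converted again by build_train_network.


def check_level_after_conversion(already_converted: bool, level: str) -> None:
    """Raise ValueError if a pre-converted network would be converted twice."""
    # Every level except 'O0' (including 'auto') triggers a conversion.
    if already_converted and level != "O0":
        raise ValueError(
            f"network was already converted; use level='O0' instead of {level!r}"
        )


# Allowed: converted network with 'O0', or an unconverted network with any level.
check_level_after_conversion(True, "O0")
check_level_after_conversion(False, "O3")
```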
- Parameters
network (Cell) – Definition of the network.
optimizer (mindspore.nn.Optimizer) – Define the optimizer to update the Parameter.
loss_fn (Union[None, Cell]) – Define the loss function. If None, the network should have the loss inside. Default: None.
level (str) – Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: 'O0'.
'O0': Do not change.
'O1': Cast the operators in white_list to float16; the remaining operators are kept in float32. The operators in the white list: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose, Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
'O2': Cast the network to float16, keep the mindspore.nn.BatchNorm series interfaces, mindspore.nn.LayerNorm and loss_fn (if set) running in float32, and use dynamic loss scaling.
'O3': Cast the network to float16, with the additional property keep_batchnorm_fp32=False.
'auto': Set level to the recommended level for the current device: 'O2' on GPU and 'O3' on Ascend. The recommended level is based on expert experience and is not applicable to all scenarios; users should specify level explicitly for special networks.
The properties keep_batchnorm_fp32, cast_model_type and loss_scale_manager determined by the level setting may be overridden by settings in kwargs.
boost_level (str) – Option for argument level in mindspore.boost, the level for boost mode training. Supports ['O0', 'O1', 'O2']. Default: 'O0'.
'O0': Do not change.
'O1': Enable boost mode; performance improves by about 20%, and accuracy is the same as the original accuracy.
'O2': Enable boost mode; performance improves by about 30%, and accuracy is reduced by less than 3%.
If 'O1' or 'O2' is set, the boost-related library takes effect automatically.
cast_model_type (mindspore.dtype) – Supports mstype.float16 or mstype.float32. If set, the network is cast to cast_model_type (mstype.float16 or mstype.float32) instead of the type determined by the level setting.
keep_batchnorm_fp32 (bool) – Keep BatchNorm running in float32 when the network is cast to float16. If set, the level setting has no effect on this property.
loss_scale_manager (Union[None, LossScaleManager]) – If not None, must be a subclass of mindspore.amp.LossScaleManager used for scaling the loss. If set, the level setting has no effect on this property.
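The interplay between level, the 'auto' device mapping, the O1 white list and the kwargs overrides can be summarized in plain Python. This is a hypothetical sketch, not MindSpore's implementation: resolve_level, resolve_config, o1_dtype and LEVEL_PRESETS are illustrative names, dtypes are plain strings, and each preset holds only the properties the parameter descriptions above state. Raising for 'auto' on devices other than GPU and Ascend is an assumption, since the documentation names no recommendation for them.

```python
# Hypothetical sketch, not MindSpore code: models how the documented level
# presets, the 'auto' mapping, the O1 white list and kwargs overrides combine.

O1_WHITE_LIST = frozenset({
    "Conv1d", "Conv2d", "Conv3d", "Conv1dTranspose", "Conv2dTranspose",
    "Conv3dTranspose", "Dense", "LSTMCell", "RNNCell", "GRUCell", "MatMul",
    "BatchMatMul", "PReLU", "ReLU", "Ger",
})

# 'auto' picks the recommended level per device.
AUTO_LEVEL = {"GPU": "O2", "Ascend": "O3"}

# Only the properties the documentation states for each level.
LEVEL_PRESETS = {
    "O0": {},  # do not change
    "O1": {"white_list": O1_WHITE_LIST},
    "O2": {"cast_model_type": "float16", "keep_batchnorm_fp32": True,
           "loss_scale_manager": "dynamic"},
    "O3": {"cast_model_type": "float16", "keep_batchnorm_fp32": False},
}

# kwargs documented as able to override the level's choice.
OVERRIDABLE = frozenset({"cast_model_type", "keep_batchnorm_fp32",
                         "loss_scale_manager"})


def resolve_level(level: str, device: str) -> str:
    """Turn 'auto' into a concrete level and validate everything else."""
    if level == "auto":
        # Assumption: no documented recommendation for other devices.
        if device not in AUTO_LEVEL:
            raise ValueError(f"no documented 'auto' level for {device!r}")
        return AUTO_LEVEL[device]
    if level not in LEVEL_PRESETS:
        raise ValueError(f"unsupported level {level!r}")
    return level


def resolve_config(level: str, device: str, **kwargs) -> tuple:
    """Return (concrete level, settings), with explicit kwargs taking priority."""
    unknown = set(kwargs) - OVERRIDABLE
    if unknown:
        raise TypeError(f"unexpected kwargs: {sorted(unknown)}")
    concrete = resolve_level(level, device)
    config = dict(LEVEL_PRESETS[concrete])  # copy so the presets stay intact
    config.update(kwargs)  # kwargs override the level-determined properties
    return concrete, config


def o1_dtype(op_name: str) -> str:
    """Under O1, white-listed operators run in float16 and the rest in float32."""
    return "float16" if op_name in O1_WHITE_LIST else "float32"


level, cfg = resolve_config("auto", "GPU", keep_batchnorm_fp32=False)
print(level, cfg["cast_model_type"], cfg["keep_batchnorm_fp32"])
```

Here level='auto' on GPU resolves to 'O2', and the explicit keep_batchnorm_fp32=False replaces the True that 'O2' would otherwise choose, which mirrors the statement that kwargs may override level-determined properties.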
- Raises
ValueError – If the device is CPU and the property loss_scale_manager is neither None nor a mindspore.amp.FixedLossScaleManager (with property drop_overflow_update=False).
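The CPU restriction can be read as a simple predicate. The sketch below is hypothetical and does not call MindSpore: FixedLossScaleStub and DynamicLossScaleStub are illustrative stand-ins for mindspore.amp.FixedLossScaleManager and mindspore.amp.DynamicLossScaleManager, the stub's default value is chosen only for illustration, and check_cpu_loss_scale is not a library function.

```python
# Hypothetical sketch, not MindSpore code: encodes the documented CPU rule for
# loss_scale_manager using illustrative stand-in classes.
from dataclasses import dataclass


@dataclass
class FixedLossScaleStub:
    # Stand-in for FixedLossScaleManager; default chosen for illustration.
    drop_overflow_update: bool = True


class DynamicLossScaleStub:
    """Stand-in for DynamicLossScaleManager."""


def check_cpu_loss_scale(device: str, loss_scale_manager) -> None:
    """Raise ValueError for loss scale managers the docs disallow on CPU."""
    if device != "CPU" or loss_scale_manager is None:
        return
    if (isinstance(loss_scale_manager, FixedLossScaleStub)
            and not loss_scale_manager.drop_overflow_update):
        return
    raise ValueError(
        "on CPU, loss_scale_manager must be None or a FixedLossScaleManager "
        "with drop_overflow_update=False"
    )


check_cpu_loss_scale("CPU", FixedLossScaleStub(drop_overflow_update=False))
check_cpu_loss_scale("GPU", DynamicLossScaleStub())
```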
Examples
>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> amp_level = "O3"
>>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)