mindspore.amp.build_train_network
mindspore.amp.build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs)
Build the mixed precision training cell automatically.
Note
After custom_mixed_precision or auto_mixed_precision has been used for precision conversion, converting the precision again is not supported. If build_train_network is used to train a network that has already been converted, set level to 'O0' to avoid a duplicated precision conversion.
Parameters
network (Cell) – Definition of the network.
optimizer (mindspore.nn.Optimizer) – The optimizer used to update the network's parameters.
loss_fn (Union[None, Cell]) – The loss function. If None, the network must already include the loss. Default: None.
level (str) – Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: 'O0'. For details on the amp levels, refer to mindspore.amp.auto_mixed_precision(). The values of keep_batchnorm_fp32, cast_model_type and loss_scale_manager implied by level may be overridden by the corresponding settings in kwargs.
boost_level (str) – Boost-mode training level, passed as the level argument of mindspore.boost. Supports ['O0', 'O1', 'O2']. Default: 'O0'.
'O0': Do not change.
'O1': Enable boost mode; performance improves by about 20%, with the same accuracy as the original.
'O2': Enable boost mode; performance improves by about 30%, with an accuracy drop of less than 3%.
If 'O1' or 'O2' is set, the boost-related library takes effect automatically.
The following settings can be passed through kwargs:
cast_model_type (mindspore.dtype) – Supports mstype.float16 or mstype.float32. If set, the network is cast to cast_model_type instead of the type determined by level.
keep_batchnorm_fp32 (bool) – Keep BatchNorm running in float32 when the network is cast to float16. If set, level has no effect on this property.
loss_scale_manager (Union[None, LossScaleManager]) – If not None, must be an instance of a subclass of mindspore.amp.LossScaleManager, used to scale the loss. If set, level has no effect on this property.
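As a rough mental model of how level and the keyword arguments interact, the sketch below starts from a per-level default table and then lets explicit kwargs win. It is a hypothetical, stdlib-only illustration: LEVEL_DEFAULTS, OVERRIDABLE and resolve_amp_config are made-up names, dtypes and managers are plain strings, and the per-level values are assumptions based on the auto_mixed_precision level descriptions (O1 and 'auto' omitted), not MindSpore's internal table.

```python
# Hypothetical sketch: merge per-level defaults with **kwargs overrides.
# LEVEL_DEFAULTS is an illustrative assumption (O1 and 'auto' omitted), not
# MindSpore's internal table; dtypes and managers are plain strings here.
LEVEL_DEFAULTS = {
    "O0": {"cast_model_type": "float32", "keep_batchnorm_fp32": False,
           "loss_scale_manager": None},
    "O2": {"cast_model_type": "float16", "keep_batchnorm_fp32": True,
           "loss_scale_manager": "dynamic"},
    "O3": {"cast_model_type": "float16", "keep_batchnorm_fp32": False,
           "loss_scale_manager": None},
}

# The three settings that, per the docs, kwargs may override.
OVERRIDABLE = ("cast_model_type", "keep_batchnorm_fp32", "loss_scale_manager")


def resolve_amp_config(level="O0", **kwargs):
    """Return effective settings: the level's defaults, then kwargs applied on top."""
    if level not in LEVEL_DEFAULTS:
        raise ValueError(f"unsupported level: {level!r}")
    config = dict(LEVEL_DEFAULTS[level])  # copy so the shared table is never mutated
    for key, value in kwargs.items():
        if key not in OVERRIDABLE:
            raise ValueError(f"unsupported keyword argument: {key!r}")
        config[key] = value  # an explicit setting replaces the level default
    return config


# In this assumed table 'O2' enables a dynamic loss scale; the kwarg turns it off.
print(resolve_amp_config("O2", loss_scale_manager=None))
```

Copying the table entry on every call keeps one call's overrides from leaking into later calls.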
Raises
ValueError – If the device is CPU and loss_scale_manager is neither None nor a mindspore.amp.FixedLossScaleManager with drop_overflow_update=False.
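The CPU condition is easiest to read as an allow-list. The sketch below restates it as a hypothetical check function; check_loss_scale_manager, StubFixedLossScaleManager and StubDynamicLossScaleManager are made-up names, and the stubs model only the attribute the rule inspects, not the real mindspore.amp classes.

```python
class StubFixedLossScaleManager:
    """Hypothetical stand-in: models only the drop_overflow_update flag."""

    def __init__(self, drop_overflow_update=True):
        # Default assumed to be True, as for mindspore.amp.FixedLossScaleManager.
        self.drop_overflow_update = drop_overflow_update


class StubDynamicLossScaleManager:
    """Hypothetical stand-in for a dynamic loss scale manager."""


def check_loss_scale_manager(device_target, loss_scale_manager):
    """Raise ValueError for the CPU combinations the Raises section rules out."""
    if device_target != "CPU" or loss_scale_manager is None:
        return  # the restriction only applies on CPU; None is always allowed
    if (isinstance(loss_scale_manager, StubFixedLossScaleManager)
            and not loss_scale_manager.drop_overflow_update):
        return  # fixed scale that never drops updates: allowed on CPU
    raise ValueError("on CPU, loss_scale_manager must be None or "
                     "FixedLossScaleManager(drop_overflow_update=False)")


# Allowed on CPU; a dynamic manager is rejected.
check_loss_scale_manager("CPU", StubFixedLossScaleManager(drop_overflow_update=False))
try:
    check_loss_scale_manager("CPU", StubDynamicLossScaleManager())
except ValueError as err:
    print(err)
```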
Examples
>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> amp_level = "O3"
>>> net = amp.build_train_network(network, net_opt, net_loss, level=amp_level)
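Loss scaling (the loss_scale_manager setting) matters once the network is cast to float16, because small gradient values underflow in half precision and large ones overflow. The stdlib sketch below illustrates the general technique rather than MindSpore's implementation: to_fp16 emulates a float16 round-trip with the struct module's 'e' format, scaled_fp16_grad simulates storing a scaled gradient in float16, and update_loss_scale is a generic dynamic rule with hypothetical parameters (shrink the scale after an overflow, grow it after a window of clean steps).

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Raises OverflowError (or struct.error) if |x| is too large for float16.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def scaled_fp16_grad(grad, loss_scale):
    """Store grad * loss_scale in float16, then unscale; return None on overflow."""
    try:
        stored = to_fp16(grad * loss_scale)
    except (OverflowError, struct.error):
        return None  # overflow: the caller should skip this update
    return stored / loss_scale  # unscale in double precision


def update_loss_scale(scale, overflowed, good_steps, factor=2.0, window=2000):
    """Generic dynamic loss-scale rule; returns (new_scale, new_good_steps)."""
    if overflowed:
        return max(scale / factor, 1.0), 0  # shrink the scale, reset the counter
    good_steps += 1
    if good_steps >= window:
        return scale * factor, 0  # a full window without overflow: grow the scale
    return scale, good_steps


grad = 1e-8
print(to_fp16(grad))                    # 0.0: the gradient underflows in float16
print(scaled_fp16_grad(grad, 1024.0))   # close to 1e-8 after scaling and unscaling
print(scaled_fp16_grad(grad, 2.0**45))  # None: the scaled value exceeds float16's range
```

A dynamic rule like this trades occasional skipped steps for the largest scale that does not overflow, while a fixed scale never adapts.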