mindspore.amp.custom_mixed_precision

mindspore.amp.custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16)

Customizes mixed precision by setting a white list or a black list. When white_list is provided, the primitives and cells in white_list are converted to the lower precision given by dtype. When black_list is provided, the cells that are not in black_list are converted. Provide exactly one of white_list and black_list.
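The opposite meanings of the two lists can be illustrated with a toy model. This is a minimal sketch, not MindSpore's implementation: the network is flattened into a list of layer type names (hypothetical strings standing in for cell classes such as nn.Dense), primitives and the nested cell tree are ignored, and the function name select_low_precision is invented for illustration.

```python
from typing import Iterable, List, Optional


def select_low_precision(layers: Iterable[str],
                         white_list: Optional[Iterable[str]] = None,
                         black_list: Optional[Iterable[str]] = None) -> List[str]:
    """Toy model (hypothetical): return the layer types that would be
    converted to lower precision under a white list or a black list."""
    if (white_list is None) == (black_list is None):
        # Mirrors the rule that exactly one list must be provided.
        raise ValueError("Provide exactly one of white_list and black_list.")
    layers = list(layers)
    if white_list is not None:
        allowed = set(white_list)
        # White list: only the listed types are converted.
        return [layer for layer in layers if layer in allowed]
    blocked = set(black_list)
    # Black list: every type except the listed ones is converted.
    return [layer for layer in layers if layer not in blocked]


net_layers = ["Conv2d", "ReLU", "MaxPool2d", "Flatten", "Dense", "Softmax"]
print(select_low_precision(net_layers, white_list=["Conv2d", "Dense"]))
# ['Conv2d', 'Dense']
print(select_low_precision(net_layers, black_list=["Softmax"]))
# ['Conv2d', 'ReLU', 'MaxPool2d', 'Flatten', 'Dense']
```

A white list is the conservative choice (opt in per layer type), while a black list converts everything by default and keeps only numerically sensitive layers at full precision.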

Note

  • Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can result in a larger network hierarchy and slower performance.

  • If interfaces such as Model and build_train_network are used to train a network that has already been converted by mixed-precision interfaces such as custom_mixed_precision and auto_mixed_precision, set amp_level to O0 to avoid duplicated precision conversion.

  • Primitives are not yet supported in black_list.
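The first two notes follow from the same effect: a conversion pass that wraps layers does not know whether a layer was already converted, so applying it again adds another level. The toy model below is a sketch under that assumption only; the classes Layer and CastWrapper and the function apply_mixed_precision are hypothetical and do not describe MindSpore's internal structures.

```python
class Layer:
    """Hypothetical leaf layer."""

    def __init__(self, name):
        self.name = name

    def depth(self):
        return 1


class CastWrapper:
    """Hypothetical wrapper that would cast a layer's inputs and outputs."""

    def __init__(self, inner):
        self.inner = inner

    def depth(self):
        # Each wrapper adds one level on top of whatever it wraps.
        return 1 + self.inner.depth()


def apply_mixed_precision(layer):
    # A naive pass: it does not check whether the layer is already wrapped.
    return CastWrapper(layer)


dense = Layer("Dense")
once = apply_mixed_precision(dense)
twice = apply_mixed_precision(once)
print(once.depth(), twice.depth())  # 2 3
```

Each extra level means extra cast work at run time, which is why a network already converted by custom_mixed_precision should not be converted again by Model or build_train_network.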

Parameters
  • network (Cell) – Definition of the network.

  • white_list (list[Primitive, Cell], optional) – White list of custom mixed precision. Default: None, which means the white list is not used.

  • black_list (list[Cell], optional) – Black list of custom mixed precision. Default: None, which means the black list is not used.

  • dtype (Type) – The type used in lower-precision calculations; can be mstype.float16 or mstype.bfloat16. Default: mstype.float16.

Returns

network (Cell) – A network supporting mixed precision.

Raises
  • TypeError – The network type is not Cell.

  • ValueError – Neither white_list nor black_list is provided.

  • ValueError – dtype is neither mstype.float16 nor mstype.bfloat16.

  • ValueError – Both white_list and black_list are provided.
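The four documented error conditions can be read as a set of argument checks. The sketch below reproduces them with hypothetical stand-ins (Cell, DType, check_custom_mp_args) so it runs without MindSpore installed; the order in which the checks run and the error messages are assumptions, not taken from the library.

```python
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for mindspore.common.dtype values."""
    float16 = "float16"
    bfloat16 = "bfloat16"
    float32 = "float32"


class Cell:
    """Hypothetical stand-in for mindspore.nn.Cell."""


def check_custom_mp_args(network, white_list=None, black_list=None,
                         dtype=DType.float16):
    """Raise the documented errors for invalid arguments (check order assumed)."""
    if not isinstance(network, Cell):
        raise TypeError(f"network must be a Cell, but got {type(network).__name__}.")
    if white_list is None and black_list is None:
        raise ValueError("One of white_list and black_list must be provided.")
    if white_list is not None and black_list is not None:
        raise ValueError("white_list and black_list cannot both be provided.")
    if dtype not in (DType.float16, DType.bfloat16):
        raise ValueError(f"dtype must be float16 or bfloat16, but got {dtype}.")


def raised(fn, *args, **kwargs):
    """Return the exception type raised by fn, or None if it returns normally."""
    try:
        fn(*args, **kwargs)
    except Exception as exc:  # demo only: report whatever was raised
        return type(exc)
    return None


net = Cell()
print(raised(check_custom_mp_args, object(), white_list=["Dense"]))  # <class 'TypeError'>
print(raised(check_custom_mp_args, net))  # <class 'ValueError'>
print(raised(check_custom_mp_args, net, white_list=["Dense"], black_list=["Softmax"]))  # <class 'ValueError'>
print(raised(check_custom_mp_args, net, white_list=["Dense"], dtype=DType.float32))  # <class 'ValueError'>
print(raised(check_custom_mp_args, net, white_list=["Dense"]))  # None
```

Testing with `is None` rather than truthiness means an empty list still counts as "provided", which matches the wording of the Raises entries.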

Examples

>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> custom_white_list = amp.get_white_list()
>>> custom_white_list.append(nn.Flatten)
>>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)