mindspore.amp.auto_mixed_precision

mindspore.amp.auto_mixed_precision(network, amp_level='O0')

Applies automatic mixed precision to a network by casting its operators according to amp_level.

Parameters
  • network (Cell) – Definition of the network.

  • amp_level (str) –

    Supports [“O0”, “O1”, “O2”, “O3”]. Default: “O0”.

    • “O0”: Do not change the network.

    • “O1”: [DEMO] Cast the operators in white_list to float16; the remaining operators are kept in float32.

    • “O2”: Cast the network to float16, but keep the operators in black_list running in float32.

    • “O3”: Cast the network to float16.

Raises

ValueError – If amp_level is not one of the supported values.

Examples

>>> from mindspore import amp, nn
>>> # LeNet5 is assumed to be an nn.Cell subclass defined elsewhere.
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)
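The four levels can also be read as a per-operator decision rule. The following is a minimal pure-Python sketch of that rule, written only to illustrate the descriptions in the Parameters section; it is not MindSpore's implementation. The names `plan_dtypes`, `WHITE_LIST`, and `BLACK_LIST`, and the operator names placed in the two sets, are hypothetical and chosen for illustration.

```python
# Illustrative sketch only: this does NOT reproduce MindSpore internals.
# plan_dtypes, WHITE_LIST and BLACK_LIST are hypothetical names, and the
# operators placed in each set are assumptions made for this example.
SUPPORTED_LEVELS = ("O0", "O1", "O2", "O3")
WHITE_LIST = {"Conv2d", "Dense"}           # assumed: cast to float16 under O1
BLACK_LIST = {"BatchNorm2d", "LayerNorm"}  # assumed: kept in float32 under O2


def plan_dtypes(op_names, amp_level="O0"):
    """Map each operator name to the dtype implied by the level descriptions."""
    if amp_level not in SUPPORTED_LEVELS:
        raise ValueError(
            f"amp_level must be one of {SUPPORTED_LEVELS}, got {amp_level!r}"
        )
    plan = {}
    for name in op_names:
        if amp_level == "O0":
            # O0: leave the operator as it is.
            plan[name] = "unchanged"
        elif amp_level == "O1":
            # O1: only white-listed operators are cast to float16.
            plan[name] = "float16" if name in WHITE_LIST else "float32"
        elif amp_level == "O2":
            # O2: everything is cast to float16 except black-listed operators.
            plan[name] = "float32" if name in BLACK_LIST else "float16"
        else:
            # O3: every operator is cast to float16.
            plan[name] = "float16"
    return plan


ops = ["Conv2d", "BatchNorm2d", "ReLU"]
for level in SUPPORTED_LEVELS:
    print(level, plan_dtypes(ops, level))
```

In this sketch an unsupported level such as "O4" raises ValueError before any operator is examined, mirroring the behavior listed under Raises.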