# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
from __future__ import absolute_import
import inspect
import types
from typing import Any
import functools
import collections
import mindspore as ms
from mindspore import nn
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context
from mindspore.ops import operations as P
from mindspore.ops import Primitive
from mindspore.ops import auto_generate as gen
from mindspore import log as logger
from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel
AMP_WHITE_LIST = [
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.Conv1dTranspose,
nn.Conv2dTranspose,
nn.Conv3dTranspose,
nn.Dense,
nn.LSTMCell,
nn.RNNCell,
nn.GRUCell,
P.Conv2D,
P.Conv3D,
P.Conv2DTranspose,
P.Conv3DTranspose,
P.Conv2DBackpropInput,
P.MatMul,
P.BatchMatMul,
P.PReLU,
P.ReLU,
P.Ger,
]
AMP_BLACK_LIST = [
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
]
AMP_AUTO_WHITE_LIST = [
P.Conv2D,
P.Conv3D,
P.Conv2DTranspose,
P.Conv3DTranspose,
gen.Convolution,
P.MatMul,
gen.MatMulExt,
P.BatchMatMul,
gen.BatchMatMulExt,
gen.PReLU,
P.Einsum,
gen.Dense,
gen.Addmm,
]
AMP_AUTO_BLACK_LIST = [
gen.Pow,
gen.ACos,
gen.Asin,
gen.Cosh,
P.Erfinv,
P.Exp,
P.Expm1,
P.Log,
P.Log1p,
P.Reciprocal,
P.Rsqrt,
P.Sinh,
P.Tan,
P.Softplus,
gen.SoftplusExt,
P.LayerNorm,
gen.LayerNormExt,
P.BatchNorm,
gen.GroupNorm,
P.KLDivLoss,
P.SmoothL1Loss,
P.MultilabelMarginLoss,
P.SoftMarginLoss,
P.TripletMarginLoss,
P.MultiMarginLoss,
P.BCEWithLogitsLoss,
P.Pdist,
P.Cdist,
P.Renorm,
gen.MSELossExt,
]
# Indicates which inputs of primitives need to be converted
AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})
# Primitives in inner amp black list will not be converted in O2/O3
_INNER_AMP_BLACK_LIST = []
MS_AMP_BY_REWRITE = False
def amp_cast(value, dtype):
"""This function is used to insert cast operators for tensors during auto mixed precision."""
if isinstance(value, ms.Tensor) and value.dtype in mstype.float_type:
return P.Cast()(value, dtype)
return value
_amp_cast_op = amp_cast
class _OutputTo16(nn.Cell):
"""Wrap cell for amp. Cast network output back to float16."""
def __init__(self, backbone, dtype=mstype.float16):
super(_OutputTo16, self).__init__(auto_prefix=False)
self._backbone = backbone
self.dtype = dtype
self._get_attr_from_cell(backbone)
def construct(self, *args, **kwargs):
return F.cast(self._backbone(*args, **kwargs), self.dtype)
class _OutputTo32(nn.Cell):
"""Wrap loss for amp. Cast network output back to float32."""
def __init__(self, backbone):
super(_OutputTo32, self).__init__(auto_prefix=False)
self._backbone = backbone
self._get_attr_from_cell(backbone)
def construct(self, *args, **kwargs):
out = self._backbone(*args, **kwargs)
return F.mixed_precision_cast(mstype.float32, out)
def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
"""
Check whether current node is a operator that need to be casted. Follow conditions need to be satisfied:
1) Type of node is CallPrimitive and type of instance is Primitive
2) Type of instance is not P.Cast
3) force_cast is True, which means one of upper layer cells is under casting
4) white_list exist and type of node is in white_list
5) black_list exist and type of node is in not black_list
"""
if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
return False
if not inspect.isclass(node.get_instance_type()):
return False
if not issubclass(node.get_instance_type(), Primitive):
return False
if issubclass(node.get_instance_type(), P.Cast):
return False
if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
return False
if force_cast:
return True
if white_list is not None and node.get_instance_type() in white_list:
return True
if black_list is not None and node.get_instance_type() not in black_list:
return True
return False
def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
"""Check whether cell precision is set by user."""
for flag in ["fp32", "fp16", "bf16"]:
if hasattr(cell_inst, flag) and getattr(cell_inst, flag):
return True
return False
def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
"""
Check whether current node is type of tree whose network needs to be casted. Follow conditions need to
be satisfied:
1) Type of node is Tree and type of instance is Cell
2) Cell.to_float(xxx) is not set by user
3) force_cast is True, which means one of upper layer networks is under casting
4) white_list exist and type of node is in white_list
5) black_list exist and type of node is in not black_list
"""
if node.get_node_type() != ms.rewrite.NodeType.Tree:
return False
if not inspect.isclass(node.get_instance_type()):
return False
if not issubclass(node.get_instance_type(), nn.Cell):
return False
if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
return False
if _precision_set_by_user(node.get_instance()):
return False
if force_cast:
return True
if white_list is not None and node.get_instance_type() in white_list:
return True
if black_list is not None and node.get_instance_type() not in black_list:
return True
return False
def _insert_cast_for_operator(node, dtype):
"""insert cast pair for node."""
dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
stree = node.get_symbol_tree()
# insert cast fp16/bf16 for inputs of node
for idx, arg in enumerate(node.get_args()):
if arg.type != ms.rewrite.ValueType.NamingValue:
continue
incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, "mindspore"])
arg_providers = node.get_arg_providers()
if not arg_providers or idx not in arg_providers or \
len(arg_providers[idx][0].get_target_users(arg_providers[idx][1])) > 1:
# create new target names when argument is used by other node
incast_targets = [stree.unique_name(f"{arg.value}_var")]
else:
incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
stree.insert(stree.before(node), incast_node)
node.set_arg_by_node(idx, incast_node)
# insert cast fp32 for outputs of node
for _, target in enumerate(node.get_targets()):
if target.type != ms.rewrite.ValueType.NamingValue:
continue
outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
[target.scope, "mindspore"])
outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
stree.insert(stree.after(node), outcast_node)
def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
"""insert cast for operators not in black_list."""
# get all nodes of stree exclude nodes in subtree.
all_nodes = stree.all_nodes(False)
for node in all_nodes:
if not node.get_targets():
continue
if _operator_need_cast(node, force_cast, white_list, black_list):
_insert_cast_for_operator(node, dtype)
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
force_cast_ = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
if not _precision_set_by_user(node.get_instance()):
subtree = node.get_sub_tree()
_insert_cast_for_operators(subtree, dtype, force_cast_, white_list=white_list, black_list=black_list)
def _need_removed_cast_pair(node, dtype):
"""check whether the cast pairs should be removed."""
dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "float32"], ["mindspore", "mindspore"])
cast_dtype_f16 = cast_dtypes[0]
cast_dtype_f32 = cast_dtypes[1]
# current node should be cast fp32
if node.get_instance_type() != _amp_cast_op:
return False
node_cast_type = node.get_args()[1]
if node_cast_type != cast_dtype_f32:
return False
# all user nodes should be cast fp16/bf16
if not node.get_users():
return False
all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
for user in node.get_users():
# If ControlFlow node(e.g. if, for, while) exists between current node and user node,
# cast pair should not be removed.
middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
return False
if user.get_instance_type() != _amp_cast_op:
return False
user_cast_type = user.get_args()[1]
if user_cast_type != cast_dtype_f16:
return False
# cast pair detected, check next user
continue
return True
def _remove_duplicated_cast(stree, dtype):
"""remove the duplicated cast operators."""
all_nodes = list(stree.nodes(all_nodes=True))
for node in all_nodes:
if _need_removed_cast_pair(node, dtype):
incast_nodes = node.get_users()
# remove cast fp16/bf16 nodes
for incast_node in incast_nodes:
# get_target_users() return {target0: [(user0, arg_idx), ...], ...}
target_users = list(incast_node.get_target_users().values())
if not target_users or not target_users[0]:
continue
for user_node, arg_idx in target_users[0]:
user_node.set_arg(arg_idx, incast_node.get_args()[0])
stree.erase(incast_node)
# remove the cast fp32 node
stree.erase(node)
def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
"""Implement auto mixed precision by rewrite"""
if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
# enable rewrite configs for amp
ms.rewrite.common.namespace._ms_cells_to_subtree = True
ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
# insert casts by rewrite
stree = ms.rewrite.SymbolTree.create(network)
_insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
_remove_duplicated_cast(stree, dtype)
new_net = stree.get_network()
# disable rewrite configs
ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
ms.rewrite.common.namespace._ms_cells_to_subtree = False
ms.rewrite.common.config.clear_caches()
return new_net
def _auto_black_list(network, black_list, dtype):
"""process the black list of network."""
network.to_float(dtype)
cells = network.name_cells()
change = False
for name in cells:
subcell = cells[name]
if subcell == network:
continue
if isinstance(subcell, tuple(black_list)):
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
change = True
else:
_auto_black_list(subcell, black_list, dtype)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
return network
class amp_decorator:
"""
Auto mixed precision decorator.
Type of lists: List[Tuple[str, List[int]]]
"""
def __init__(self, amp_level, amp_dtype, white_list, black_list):
self.amp_level = amp_level
self.amp_dtype = amp_dtype
self.white_list = white_list
self.black_list = black_list
def __enter__(self):
push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
pop_amp_strategy()
def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
"""
Set auto mixed precision context decorator for object.
Type of lists: List[Tuple[str, List[int]]]
"""
if inspect.isfunction(obj) or inspect.ismethod(obj):
@functools.wraps(obj)
def wrapper(*args, **kwargs):
with amp_decorator(amp_level, amp_dtype, white_list, black_list):
return obj(*args, **kwargs)
return wrapper
if isinstance(obj, nn.Cell):
obj.construct = types.MethodType(
_set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
return obj
raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")
[docs]def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
"""
Returns a network processed with auto mixed precision.
This interface will automatically perform mixed-precision processing on the input network, and the cells
and operators in the processed network will add precision conversion operations to calculate with lower
precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
converted to lower precision float, and calculation results are converted back to full precision float,
i.e. ``mstype.float32`` .
The `amp_level` and its corresponding lists determine which cells and operators are converted.
When `amp_level` is set to ``O0``, no cells and operators are converted.
When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.
When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.
When `amp_level` is set to ``O3``, all cells will be converted to low precision.
When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
`promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
not listed will run in the type defined by their inputs.
Operators in `auto_whitelist` are:
``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
Operators in `auto_blacklist` are:
``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
``Norm``
Operators in `promote_list` are:
``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
``BiasAdd``
For details on automatic mixed precision, refer to
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
Note:
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
can result in a larger network hierarchy and slower performance.
- If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
- When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
may need to manually convert the type to avoid type inconsistency errors of the loss function.
- When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
specified by `to_float` takes effect first.
.. warning::
``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.
Args:
network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
is set to ``auto`` .
amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
- "O0": Do not change.
- "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
precision operations for the rest.
- "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
to lower precision operations.
- "O3": Cast network to lower precision.
- "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
`auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
to the higher accuracy float type of the operator inputs, and operators not listed will run in the
type defined by their inputs.
dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
default: ``mstype.float16`` .
Raises:
TypeError: If `network` is not a Cell or a function.
ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
ValueError: If `amp_level` is not within the supported range.
Examples:
>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)
"""
if not isinstance(network, nn.Cell):
if amp_level == "auto":
if not inspect.isfunction(network) and not inspect.ismethod(network):
raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
# function is supported for amp_level 'auto'
else:
raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")
if dtype not in (mstype.float16, mstype.bfloat16):
raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
if amp_level == "O0":
return network
# Return network if the same amp level has already been configurated
if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
f"degradation.")
if amp_level == "O1":
network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
elif amp_level == "O2":
if MS_AMP_BY_REWRITE:
network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
else:
network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
network = _OutputTo32(network)
elif amp_level == "O3":
if MS_AMP_BY_REWRITE:
network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
else:
network.to_float(dtype)
network = _OutputTo32(network)
elif amp_level == "auto":
white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
# set amp_strategy attribute for the object
amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
setattr(network, "amp_strategy", amp_strategy)
# set amp_strategy context decorator for the object
network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
else:
raise ValueError(f"The amp level {amp_level} is not supported")
setattr(network, "_amp_level", amp_level)
return network
def _do_keep_batchnorm_fp32(network):
"""Do keep batchnorm fp32."""
cells = network.name_cells()
change = False
for name in cells:
subcell = cells[name]
if subcell == network:
continue
elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
change = True
else:
_do_keep_batchnorm_fp32(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
_config_level = {
"O0": {
"keep_batchnorm_fp32": False,
"cast_model_type": mstype.float32,
"loss_scale_manager": None},
"O1": {
"keep_batchnorm_fp32": False,
"cast_model_type": mstype.float32,
"loss_scale_manager": None},
"O2": {
"keep_batchnorm_fp32": True,
"cast_model_type": mstype.float16,
"loss_scale_manager": DynamicLossScaleManager()},
"O3": {
"keep_batchnorm_fp32": False,
"cast_model_type": mstype.float16,
"loss_scale_manager": None},
"auto": {
"keep_batchnorm_fp32": False,
"cast_model_type": mstype.float32,
"loss_scale_manager": None}}
def _check_kwargs(key_words):
"""Check kwargs."""
for arg in key_words:
if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
raise ValueError(f"Unsupported arg '{arg}'")
if 'cast_model_type' in key_words:
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
[mstype.float16, mstype.float32], None)
if 'keep_batchnorm_fp32' in key_words:
validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
if 'loss_scale_manager' in key_words:
loss_scale_manager = key_words['loss_scale_manager']
if loss_scale_manager:
validator.check_value_type('loss_scale_manager', loss_scale_manager,
[LossScaleManager, boost.GroupLossScaleManager])
def _check_level(level, boost_level):
"""Check level."""
if not isinstance(level, str):
raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],"
f"but got type {type(level)}.")
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
enable_boost = False
if boost_level in ["O1", "O2"]:
enable_boost = True
return level, enable_boost
def _add_loss_network(network, loss_fn, cast_model_type):
"""Add loss network."""
class WithLossCell(nn.Cell):
"""Wrap loss for amp. Cast network output back to float32."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
self._get_attr_from_cell(backbone)
def construct(self, data, label):
out = self._backbone(data)
label = F.mixed_precision_cast(mstype.float32, label)
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
if cast_model_type in (mstype.float16, mstype.bfloat16) or \
(hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
network = WithLossCell(network, loss_fn)
else:
network = nn.WithLossCell(network, loss_fn)
return network
def _is_grad_accumulation(mcell):
if mcell.cls_name == "GradAccumulationCell":
return True
for cell in mcell.cells():
if _is_grad_accumulation(cell):
return True
return False
def _auto_mixed_precision_process(network, config, level):
"""Auto mixed precision process."""
if MS_AMP_BY_REWRITE:
if config["cast_model_type"] == mstype.float16 or level == "O2":
level = "O2" if config["keep_batchnorm_fp32"] else "O3"
elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
# cast_model_type set by kwargs
level = "O0"
network = auto_mixed_precision(network, level)
else:
if config["cast_model_type"] == mstype.float16:
network.to_float(mstype.float16)
if config["keep_batchnorm_fp32"]:
_do_keep_batchnorm_fp32(network)
elif not config["keep_batchnorm_fp32"] and level == "O2":
network.to_float(mstype.float16)
elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
pass
else:
network = auto_mixed_precision(network, level)
return network
[docs]def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.
Note:
- After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
to perform the precision conversion again. If `build_train_network` is used to train a converted network,
`level` need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
Args:
network (Cell): Definition of the network.
optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
Default: ``None`` .
level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .
For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.
Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
setting may be overwritten by settings in `kwargs`.
boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
- 'O0': Do not change.
- 'O1': Enable the boost mode, the performance is improved by about 20%, and
the accuracy is the same as the original accuracy.
- 'O2': Enable the boost mode, the performance is improved by about 30%, and
the accuracy is reduced by less than 3%.
If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be casted
to the type determined by `level` setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32` when the network is set to cast to `float16` . If
set, the `level` setting will take no effect on this property.
loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
:class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
take no effect on this property.
Raises:
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
:class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).
Examples:
>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> amp_level="O3"
>>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
"""
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))
level, enable_boost = _check_level(level, boost_level)
_check_kwargs(kwargs)
config = dict(_config_level.get(level), **kwargs)
network = _auto_mixed_precision_process(network, config, level)
if loss_fn:
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
loss_scale = None
if config["loss_scale_manager"] is not None:
loss_scale_manager = config["loss_scale_manager"]
loss_scale = loss_scale_manager.get_loss_scale()
update_cell = loss_scale_manager.get_update_cell()
if update_cell is not None:
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
raise ValueError("Only `loss_scale_manager=None` or "
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
"are supported on device `CPU`. ")
if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
network = _TrainGradAccuWithLossScaleCell(network, optimizer,
scale_sense=update_cell).set_train()
elif enable_boost:
network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
scale_sense=update_cell).set_train()
else:
network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
scale_sense=update_cell).set_train()
return network
if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
network = _TrainGradAccuStepCell(network, optimizer).set_train()
elif enable_boost:
network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
else:
network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
return network
[docs]def get_white_list():
"""
Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.
The current built-in whitelist contents are:
[:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
:class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
:class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
:class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
:class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
:class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
:class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
Returns:
list, A copy of internal white list.
Examples:
>>> from mindspore import amp
>>> white_list = amp.get_white_list()
>>> print(white_list)
[<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
<class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
<class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
<class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
<class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
<class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
<class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
<class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
<class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
<class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
<class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
<class 'mindspore.ops.operations.math_ops.Ger'>]
"""
white_list = AMP_WHITE_LIST.copy()
return white_list
[docs]def get_black_list():
"""
Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.
The current built-in blacklist contents are:
[:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
:class:`mindspore.nn.LayerNorm`]
Returns:
list, A copy of internal black list.
Examples:
>>> from mindspore import amp
>>> black_list = amp.get_black_list()
>>> print(black_list)
[<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
<class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
"""
black_list = AMP_BLACK_LIST.copy()
return black_list
[docs]def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
"""
When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
When the `black_list` is provided, cells that are not in `black_list` will perform the pereision conversion.
Only one of `white_list` and `black_list` should be provided.
Note:
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
can result in a larger network hierarchy and slower performance.
- If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level` or `level`
need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
- Primitives for blacklist is not support yet.
Args:
network (Cell): Definition of the network.
Keyword Args:
white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
white list is not used.
black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
black list is not used.
dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
default: ``mstype.float16`` .
Returns:
network (Cell), A network supporting mixed precision.
Raises:
TypeError: The network type is not Cell.
ValueError: Neither `white_list` nor `black_list` is provided.
ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
ValueError: Both `white_list` and `black_list` are provided.
Examples:
>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> custom_white_list = amp.get_white_list()
>>> custom_white_list.append(nn.Flatten)
>>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
"""
if not isinstance(network, nn.Cell):
raise TypeError("The network type should be Cell.")
if white_list is None and black_list is None:
raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")
if white_list is not None and black_list is not None:
raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
"at the same time, please provide one or the other.")
if dtype not in (mstype.float16, mstype.bfloat16):
raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
if white_list is not None:
_list_check(white_list, "white_list")
network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
else:
_list_check(black_list, "black_list")
if MS_AMP_BY_REWRITE:
network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
else:
network = _auto_black_list(network, black_list, dtype)
network = _OutputTo32(network)
return network
def _list_check(custom_list: list, list_name: str):
"""
check whether custom list is valid
Raises:
TypeError: The type of custom_list is not list.
TypeError: The element in custom_list is not a class.
TypeError: The subclass of element in custom_list is not one of ['Cell', 'Primitive'].
"""
if not isinstance(custom_list, list):
raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")
for elem in custom_list:
if not isinstance(elem, type):
raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
if list_name == "white_list" and not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
f"but got {elem}")
if list_name == "black_list" and not issubclass(elem, nn.Cell):
raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {elem}")
if list_name == 'black_list':
for elem in AMP_BLACK_LIST:
if elem not in custom_list:
logger.warning(f"{elem} is removed from internal black list.")
def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None): # pylint: disable=unused-variable
"""Configure auto mixed precision."""
global MS_AMP_BY_REWRITE
global _amp_cast_op
if enable_rewrite is not None:
MS_AMP_BY_REWRITE = enable_rewrite
if cast_op is not None:
_amp_cast_op = cast_op