# Source code for mindspore.train.amp

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
from __future__ import absolute_import
import inspect
import types
from typing import Any
import functools
import collections

import mindspore as ms
from mindspore import nn
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context
from mindspore.ops import operations as P
from mindspore.ops import Primitive
from mindspore.ops import auto_generate as gen
from mindspore import log as logger
from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel

# Cells and primitives converted to low precision (float16/bfloat16) when `amp_level` is "O1".
# get_white_list() returns a copy of this list, so its order is user visible.
AMP_WHITE_LIST = [
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose,
    nn.Dense,
    nn.LSTMCell,
    nn.RNNCell,
    nn.GRUCell,
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    P.Conv2DBackpropInput,
    P.MatMul,
    P.BatchMatMul,
    P.PReLU,
    P.ReLU,
    P.Ger,
]

# Cells kept in full precision (float32) when `amp_level` is "O2"; everything else is converted.
# get_black_list() returns a copy of this list.
AMP_BLACK_LIST = [
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
]

# Primitives run in low precision when `amp_level` is "auto". auto_mixed_precision passes
# (class name, AMP_PRIM_ARG_TABLE entry) pairs built from this list to create_amp_strategy.
AMP_AUTO_WHITE_LIST = [
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    gen.Convolution,
    P.MatMul,
    gen.MatMulExt,
    P.BatchMatMul,
    gen.BatchMatMulExt,
    gen.PReLU,
    P.Einsum,
    gen.Dense,
    gen.Addmm,
]

# Primitives run in full precision (float32) when `amp_level` is "auto"; passed to
# create_amp_strategy in the same (class name, arg indices) form as the white list.
AMP_AUTO_BLACK_LIST = [
    gen.Pow,
    gen.ACos,
    gen.Asin,
    gen.Cosh,
    P.Erfinv,
    P.Exp,
    P.Expm1,
    P.Log,
    P.Log1p,
    P.Reciprocal,
    P.Rsqrt,
    P.Sinh,
    P.Tan,
    P.Softplus,
    gen.SoftplusExt,
    P.LayerNorm,
    gen.LayerNormExt,
    P.BatchNorm,
    gen.GroupNorm,
    P.KLDivLoss,
    P.SmoothL1Loss,
    P.MultilabelMarginLoss,
    P.SoftMarginLoss,
    P.TripletMarginLoss,
    P.MultiMarginLoss,
    P.BCEWithLogitsLoss,
    P.Pdist,
    P.Cdist,
    P.Renorm,
]

# Indicates which inputs of primitives need to be converted
# (primitive class -> list of input indices). Being a defaultdict(list), a primitive without
# an entry maps to an empty list; how the strategy interprets an empty list is decided by
# create_amp_strategy / push_amp_strategy (TODO confirm in the C++ side).
AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})

# Primitives in inner amp black list will not be converted in O2/O3
# (rewrite path only: _operator_need_cast and _net_need_cast return False for types listed here).
_INNER_AMP_BLACK_LIST = []

# When True, "O2"/"O3" (and build_train_network) insert casts with the rewrite-based pass
# (_auto_mixed_precision_rewrite) instead of calling Cell.to_float.
MS_AMP_BY_REWRITE = False


def amp_cast(value, dtype):
    """
    Cast `value` to `dtype` if it is a floating point Tensor, otherwise return it unchanged.

    The rewrite-based mixed precision pass inserts calls to this function (through `_amp_cast_op`)
    before and after converted operators.

    Args:
        value: Any value. Only a ``Tensor`` whose dtype is in ``mstype.float_type`` is cast.
        dtype: Target type of the cast.

    Returns:
        The cast Tensor, or `value` itself when it is not a floating point Tensor.
    """
    if isinstance(value, ms.Tensor) and value.dtype in mstype.float_type:
        return P.Cast()(value, dtype)
    return value

# Callee of every cast node created by _insert_cast_for_operator; _need_removed_cast_pair also
# compares node instance types against it to recognise the cast nodes inserted by the pass.
_amp_cast_op = amp_cast


class _OutputTo16(nn.Cell):
    """
    Wrap cell for amp: run `backbone` and cast its output to a low precision type.

    Used around cells that were forced back to float32, so that downstream cells keep receiving
    `dtype` (float16 by default) tensors.

    Args:
        backbone (Cell): The wrapped cell.
        dtype: Type the output is cast to. Default: ``mstype.float16``.
    """
    def __init__(self, backbone, dtype=mstype.float16):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self.dtype = dtype
        # copy cell-level attributes of the wrapped cell onto the wrapper
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        out = self._backbone(*args, **kwargs)
        return F.cast(out, self.dtype)


class _OutputTo32(nn.Cell):
    """
    Wrap network for amp: run `backbone` and cast its output back to float32.

    The output is passed through ``F.mixed_precision_cast`` with ``mstype.float32``.

    Args:
        backbone (Cell): The wrapped network.
    """
    def __init__(self, backbone):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # copy cell-level attributes of the wrapped cell onto the wrapper
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        return F.mixed_precision_cast(mstype.float32, self._backbone(*args, **kwargs))


def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether `node` is an operator call that needs casts inserted around it.

    All of the following must hold:
        1) The node type is CallPrimitive and its instance type is a Primitive subclass.
        2) The instance type is not P.Cast and is not in _INNER_AMP_BLACK_LIST.
    and then at least one of:
        3) `force_cast` is True, meaning an enclosing cell is being converted;
        4) `white_list` is given and the instance type is in it;
        5) `black_list` is given and the instance type is not in it.
    """
    if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
        return False
    prim_type = node.get_instance_type()
    if not inspect.isclass(prim_type) or not issubclass(prim_type, Primitive):
        return False
    # never wrap existing casts, nor primitives that must stay untouched
    if issubclass(prim_type, P.Cast) or prim_type in _INNER_AMP_BLACK_LIST:
        return False
    if force_cast:
        return True
    return (white_list is not None and prim_type in white_list) or \
        (black_list is not None and prim_type not in black_list)


def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
    """
    Tell whether a precision flag is set on `cell_inst`.

    Returns True if any of the attributes ``fp32``, ``fp16`` or ``bf16`` exists on the cell and is
    truthy (presumably set through ``Cell.to_float``), False otherwise.
    """
    return any(getattr(cell_inst, flag, False) for flag in ("fp32", "fp16", "bf16"))


def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether `node` is a sub-network (Tree node) whose contents need to be converted.

    All of the following must hold:
        1) The node type is Tree and its instance type is a Cell subclass.
        2) The instance type is not in _INNER_AMP_BLACK_LIST.
        3) No precision was set on the cell by the user (e.g. via Cell.to_float).
    and then at least one of:
        4) `force_cast` is True, meaning an enclosing network is being converted;
        5) `white_list` is given and the instance type is in it;
        6) `black_list` is given and the instance type is not in it.
    """
    if node.get_node_type() != ms.rewrite.NodeType.Tree:
        return False
    cell_type = node.get_instance_type()
    if not inspect.isclass(cell_type) or not issubclass(cell_type, nn.Cell):
        return False
    if cell_type in _INNER_AMP_BLACK_LIST or _precision_set_by_user(node.get_instance()):
        return False
    if force_cast:
        return True
    return (white_list is not None and cell_type in white_list) or \
        (black_list is not None and cell_type not in black_list)


def _insert_cast_for_operator(node, dtype):
    """
    Insert a cast pair around the operator `node`.

    Each input of value type NamingValue is cast to float16/bfloat16 by an `_amp_cast_op` node
    inserted right before `node`. Each output of value type NamingValue is cast back to float32 by
    an `_amp_cast_op` node inserted right after `node` that reassigns the same name. Inputs and
    outputs of any other value type are left untouched.

    Args:
        node (ms.rewrite.Node): Operator node to wrap.
        dtype: ``mstype.bfloat16`` selects bfloat16 casts; any other value selects float16.
    """
    low_dtype_name = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    stree = node.get_symbol_tree()
    scoped = ms.rewrite.ScopedValue
    # Step 1: cast the named inputs down, right before the node.
    for arg_idx, arg in enumerate(node.get_args()):
        if arg.type != ms.rewrite.ValueType.NamingValue:
            continue
        cast_args = scoped.create_name_values([arg.value, low_dtype_name], [arg.scope, "mindspore"])
        # providers are re-queried per input because set_arg_by_node below rewires the node's args
        providers = node.get_arg_providers()
        needs_fresh_name = (not providers or arg_idx not in providers
                            or len(providers[arg_idx][0].get_target_users(providers[arg_idx][1])) > 1)
        if needs_fresh_name:
            # create new target names when argument is used by other node
            cast_targets = [stree.unique_name(f"{arg.value}_var")]
        else:
            cast_targets = scoped.create_name_values([arg.value], [arg.scope])
        cast_in = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=cast_targets, args=cast_args)
        stree.insert(stree.before(node), cast_in)
        node.set_arg_by_node(arg_idx, cast_in)
    # Step 2: cast the named outputs back to float32, right after the node.
    for target in node.get_targets():
        if target.type != ms.rewrite.ValueType.NamingValue:
            continue
        cast_args = scoped.create_name_values([target.value, "float32"], [target.scope, "mindspore"])
        cast_targets = scoped.create_name_values([target.value], [target.scope])
        cast_out = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=cast_targets, args=cast_args)
        stree.insert(stree.after(node), cast_out)


def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
    """
    Recursively insert cast pairs for the operators of `stree` selected by the lists.

    Operator nodes accepted by `_operator_need_cast` are wrapped by `_insert_cast_for_operator`.
    For sub-network (Tree) nodes, the recursion continues into the subtree unless the user set the
    cell's precision; `force_cast` becomes True for the subtree when the sub-network itself needs
    casting (see `_net_need_cast`). Nodes without targets are skipped.
    """
    # all_nodes(False) yields this tree's nodes only; subtrees are handled by the recursion below.
    for node in stree.all_nodes(False):
        if not node.get_targets():
            continue
        if _operator_need_cast(node, force_cast, white_list, black_list):
            _insert_cast_for_operator(node, dtype)
            continue
        if node.get_node_type() != ms.rewrite.NodeType.Tree:
            continue
        # NOTE(review): in black_list mode a cell listed in black_list still gets descended into with
        # force_cast=False, and its primitives (not themselves in black_list) then satisfy
        # _operator_need_cast — confirm this is intended.
        child_force_cast = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
        if _precision_set_by_user(node.get_instance()):
            continue
        _insert_cast_for_operators(node.get_sub_tree(), dtype, child_force_cast,
                                   white_list=white_list, black_list=black_list)


def _need_removed_cast_pair(node, dtype):
    """
    Check whether `node` is the float32 half of removable cast pairs.

    A pair is removable when `node` is an `_amp_cast_op(x, float32)` call and every user of it is an
    `_amp_cast_op(..., float16/bfloat16)` call, with no ControlFlow node (if/for/while) between
    `node` and that user. Casting up to float32 and straight back down is redundant.

    Args:
        node (ms.rewrite.Node): Candidate float32 cast node.
        dtype: ``mstype.bfloat16`` selects bfloat16 casts; any other value selects float16.

    Returns:
        bool, True if the cast pairs formed by `node` and its users can be removed.
    """
    dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    cast_dtype_f16, cast_dtype_f32 = ms.rewrite.ScopedValue.create_name_values(
        [dtype_str, "float32"], ["mindspore", "mindspore"])
    # current node should be cast fp32
    if node.get_instance_type() != _amp_cast_op:
        return False
    if node.get_args()[1] != cast_dtype_f32:
        return False
    users = node.get_users()
    if not users:
        return False
    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
    node_idx = all_nodes.index(node)
    for user in users:
        # all user nodes should be cast fp16/bf16 (cheap checks first)
        if user.get_instance_type() != _amp_cast_op:
            return False
        if user.get_args()[1] != cast_dtype_f16:
            return False
        # A user that is not in this node manager (e.g. nested in another block) cannot be
        # positioned relative to `node`; keep the cast pair instead of raising ValueError.
        try:
            user_idx = all_nodes.index(user)
        except ValueError:
            return False
        # If ControlFlow node(e.g. if, for, while) exists between current node and user node,
        # cast pair should not be removed.
        if any(n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in all_nodes[node_idx:user_idx]):
            return False
    return True


def _remove_duplicated_cast(stree, dtype):
    """
    Erase redundant float32 -> float16/bfloat16 cast pairs from `stree`.

    For every float32 cast node accepted by `_need_removed_cast_pair`, each consuming low precision
    cast is bypassed (its users are re-pointed to the cast's first argument) and erased, then the
    float32 cast itself is erased.
    """
    # snapshot the node list because nodes are erased while iterating
    for cast_fp32 in list(stree.nodes(all_nodes=True)):
        if not _need_removed_cast_pair(cast_fp32, dtype):
            continue
        for cast_low in cast_fp32.get_users():
            # get_target_users() return {target0: [(user0, arg_idx), ...], ...}
            per_target_users = list(cast_low.get_target_users().values())
            if not per_target_users or not per_target_users[0]:
                continue
            for consumer, arg_idx in per_target_users[0]:
                consumer.set_arg(arg_idx, cast_low.get_args()[0])
            stree.erase(cast_low)
        stree.erase(cast_fp32)


def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
    """Implement auto mixed precision by rewrite"""
    if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
        raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
    # enable rewrite configs for amp
    ms.rewrite.common.namespace._ms_cells_to_subtree = True
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
    # insert casts by rewrite
    stree = ms.rewrite.SymbolTree.create(network)
    _insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
    _remove_duplicated_cast(stree, dtype)
    new_net = stree.get_network()
    # disable rewrite configs
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
    ms.rewrite.common.namespace._ms_cells_to_subtree = False
    ms.rewrite.common.config.clear_caches()
    return new_net


def _auto_black_list(network, black_list, dtype):
    """
    Cast `network` to `dtype` while keeping cells whose type is in `black_list` in float32.

    Each black-listed sub-cell is cast back to float32 and wrapped with `_OutputTo16` so its output
    returns to `dtype`; other sub-cells are processed recursively. The network is modified in place
    and returned.
    """
    network.to_float(dtype)
    black_types = tuple(black_list)
    replaced = False
    for cell_name, child in network.name_cells().items():
        if child == network:
            continue
        if not isinstance(child, black_types):
            _auto_black_list(child, black_list, dtype)
            continue
        network._cells[cell_name] = _OutputTo16(child.to_float(mstype.float32), dtype)
        replaced = True
    # SequentialCell keeps a separate cell_list that must follow the replaced children
    if replaced and isinstance(network, nn.SequentialCell):
        network.cell_list = list(network.cells())
    return network


class amp_decorator:
    """
    Context manager that activates an auto mixed precision strategy for the enclosed code.

    Entering pushes the strategy with `push_amp_strategy`; leaving pops it with `pop_amp_strategy`.
    Exceptions raised inside the block are not suppressed.

    Type of lists: List[Tuple[str, List[int]]]
    """
    def __init__(self, amp_level, amp_dtype, white_list, black_list):
        self.amp_level, self.amp_dtype = amp_level, amp_dtype
        self.white_list, self.black_list = white_list, black_list

    def __enter__(self):
        push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        pop_amp_strategy()


def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
    """
    Set auto mixed precision context decorator for object.
    Type of lists: List[Tuple[str, List[int]]]
    """
    if inspect.isfunction(obj) or inspect.ismethod(obj):
        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            with amp_decorator(amp_level, amp_dtype, white_list, black_list):
                return obj(*args, **kwargs)
        return wrapper
    if isinstance(obj, nn.Cell):
        obj.construct = types.MethodType(
            _set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
        return obj
    raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")


def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
    """
    Returns a network processed with auto mixed precision.

    This interface will automatically perform mixed-precision processing on the input network, and the cells
    and operators in the processed network will add precision conversion operations to calculate with lower
    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
    converted to lower precision float, and calculation results are converted back to full precision float,
    i.e. ``mstype.float32`` .

    The `amp_level` and its corresponding lists determine which cells and operators are converted.

    When `amp_level` is set to ``O0``, no cells and operators are converted.

    When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
    operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.

    When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
    list will be converted to low precision. For details on blacklist, refer to
    :func:`mindspore.amp.get_black_list`.

    When `amp_level` is set to ``O3``, all cells will be converted to low precision.

    When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
    operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
    `promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
    not listed will run in the type defined by their inputs.

    Operators in `auto_whitelist` are:

    ``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``,
    ``MatMulExt``, ``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``

    Operators in `auto_blacklist` are:

    ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
    ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``,
    ``BatchNorm``, ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``,
    ``SoftMarginLoss``, ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``,
    ``Cdist``, ``Renorm``, ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``,
    ``CumsumExt``, ``ProdExt``, ``SumExt``, ``Norm``

    Operators in `promote_list` are:

    ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``,
    ``GridSampler3D``, ``BiasAdd``

    For details on automatic mixed precision, refer to
    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.4.0/beginner/mixed_precision.html>`_ .

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and
          `auto_mixed_precision`, can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train the network which is converted
          by mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          needs to be configured to ``O0`` to avoid the duplicated accuracy conversion.
        - When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case,
          you may need to manually convert the type to avoid type inconsistency errors of the loss function.
        - When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the
          accuracy specified by `to_float` takes effect first.

    .. warning::
        ``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.

    Args:
        network (Union[Cell, function]): Definition of the network. Function (or method) type is supported
            only when `amp_level` is set to ``auto`` .
        amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

            - "O0": Do not change.
            - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
              precision operations for the rest.
            - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
              to lower precision operations.
            - "O3": Cast network to lower precision.
            - "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators
              in `auto_blacklist` will be converted to full precision, operators in `promote_list` will be
              converted to the higher accuracy float type of the operator inputs, and operators not listed
              will run in the type defined by their inputs.

        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
            ``mstype.bfloat16`` , default: ``mstype.float16`` .

    Returns:
        Union[Cell, function], the network processed with mixed precision.

    Raises:
        TypeError: If `network` is not a Cell or a function.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        ValueError: If `amp_level` is not within the supported range.

    Examples:
        >>> from mindspore import amp
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.4.0/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> amp_level = "O1"
        >>> net = amp.auto_mixed_precision(network, amp_level)
    """
    if not isinstance(network, nn.Cell):
        if amp_level == "auto":
            if not inspect.isfunction(network) and not inspect.ismethod(network):
                raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
            # function is supported for amp_level 'auto'
        else:
            raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if amp_level == "O0":
        return network

    # Warn (but still proceed) when mixed precision has already been applied to this network.
    if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
        logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
                       f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
                       f"degradation.")

    if amp_level == "O1":
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
    elif amp_level == "O2":
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
        else:
            network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
            network = _OutputTo32(network)
    elif amp_level == "O3":
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
        else:
            network.to_float(dtype)
            network = _OutputTo32(network)
    elif amp_level == "auto":
        white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
        black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
        # set amp_strategy attribute for the object
        amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
        if not inspect.ismethod(network):
            # bound method objects reject attribute assignment (AttributeError)
            setattr(network, "amp_strategy", amp_strategy)
        # set amp_strategy context decorator for the object
        network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
        # make sure the returned object (a new wrapper for functions/methods) carries the strategy
        setattr(network, "amp_strategy", amp_strategy)
    else:
        raise ValueError(f"The amp level {amp_level} is not supported")

    setattr(network, "_amp_level", amp_level)
    return network


def _do_keep_batchnorm_fp32(network):
    """
    Do keep batchnorm fp32.

    Each sub-cell whose type is in AMP_BLACK_LIST is cast back to float32 and wrapped with `_OutputTo16`
    so its output returns to float16; other sub-cells are processed recursively. Modifies `network`
    in place.
    """
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        if isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _do_keep_batchnorm_fp32(subcell)
    # SequentialCell keeps a separate cell_list that must follow the replaced children
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())


# Default mixed precision settings per `level` used by build_train_network; entries may be
# overridden through its kwargs.
# NOTE: the "O2" DynamicLossScaleManager instance is created once at import time, so shallow copies
# of this preset (dict(...)) all share that one instance.
_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O1": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None},
    "auto": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None}}


def _check_kwargs(key_words):
    """
    Check kwargs.

    Only 'cast_model_type' (float16/float32), 'keep_batchnorm_fp32' (bool) and 'loss_scale_manager'
    (None, LossScaleManager or boost.GroupLossScaleManager) are accepted.

    Raises:
        ValueError: If an unsupported key is given.
    """
    for arg in key_words:
        if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
            raise ValueError(f"Unsupported arg '{arg}'")

    if 'cast_model_type' in key_words:
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
    if 'loss_scale_manager' in key_words:
        loss_scale_manager = key_words['loss_scale_manager']
        if loss_scale_manager:
            validator.check_value_type('loss_scale_manager', loss_scale_manager,
                                       [LossScaleManager, boost.GroupLossScaleManager])


def _check_level(level, boost_level):
    """
    Check level.

    Returns:
        tuple, (`level` unchanged, True if `boost_level` is "O1" or "O2").

    Raises:
        TypeError: If `level` is not a string.
    """
    if not isinstance(level, str):
        raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], "
                        f"but got type {type(level)}.")
    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)

    enable_boost = False
    if boost_level in ["O1", "O2"]:
        enable_boost = True

    return level, enable_boost


def _add_loss_network(network, loss_fn, cast_model_type):
    """
    Add loss network.

    When the network runs in low precision (cast type float16/bfloat16, or `_amp_level` O2/O3/auto),
    a wrapper that casts the network output and the label to float32 before the loss is used;
    otherwise nn.WithLossCell is used.
    """

    class WithLossCell(nn.Cell):
        """Wrap loss for amp. Cast network output back to float32."""
        def __init__(self, backbone, loss_fn):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn
            self._get_attr_from_cell(backbone)

        def construct(self, data, label):
            out = self._backbone(data)
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    if cast_model_type in (mstype.float16, mstype.bfloat16) or \
            (hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network


def _is_grad_accumulation(mcell):
    """Return True if `mcell` or any nested sub-cell is a GradAccumulationCell."""
    if mcell.cls_name == "GradAccumulationCell":
        return True
    return any(_is_grad_accumulation(cell) for cell in mcell.cells())


def _auto_mixed_precision_process(network, config, level):
    """
    Auto mixed precision process.

    Applies the precision settings in `config` (already merged with user kwargs) to `network`,
    either through auto_mixed_precision (rewrite mode) or through Cell.to_float.
    """
    if MS_AMP_BY_REWRITE:
        if config["cast_model_type"] == mstype.float16 or level == "O2":
            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            # cast_model_type set by kwargs
            level = "O0"
        network = auto_mixed_precision(network, level)
    else:
        if config["cast_model_type"] == mstype.float16:
            network.to_float(mstype.float16)

            if config["keep_batchnorm_fp32"]:
                _do_keep_batchnorm_fp32(network)
        elif not config["keep_batchnorm_fp32"] and level == "O2":
            network.to_float(mstype.float16)
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            pass
        else:
            network = auto_mixed_precision(network, level)
    return network


def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Note:
        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not
          supported to perform the precision conversion again. If `build_train_network` is used to train a
          converted network, `level` needs to be configured to ``O0`` to avoid the duplicated accuracy
          conversion.

    Args:
        network (Cell): Definition of the network.
        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
        loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss
            inside. Default: ``None`` .
        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

            For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.

            Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
            setting may be overwritten by settings in `kwargs`.

        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
            - 'O2': Enable the boost mode, the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.

            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be
            casted to the type determined by `level` setting.
        keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32` when the network is set to cast to
            `float16` . If set, the `level` setting will take no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
            :class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
            take no effect on this property.

    Returns:
        Cell, the training network wrapped with a one-step training cell and set to train mode.

    Raises:
        ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
            :class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.4.0/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> amp_level="O3"
        >>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
    """
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
                                                        nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

    level, enable_boost = _check_level(level, boost_level)

    _check_kwargs(kwargs)
    # level preset, overridden by user kwargs
    config = dict(_config_level.get(level), **kwargs)

    network = _auto_mixed_precision_process(network, config, level)

    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])

    loss_scale = None
    if config["loss_scale_manager"] is not None:
        loss_scale_manager = config["loss_scale_manager"]
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        if update_cell is not None:
            # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)` "
                                 "are supported on device `CPU`.")
            if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
                network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                          scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
                                                                   scale_sense=update_cell).set_train()
            else:
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
        network = _TrainGradAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network


def get_white_list():
    """
    Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.Conv2DBackpropInput`,
    :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`, :class:`mindspore.ops.PReLU`,
    :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    Returns:
        list, A copy of internal white list.

    Examples:
        >>> from mindspore import amp
        >>> white_list = amp.get_white_list()
        >>> print(white_list)
        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
         <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
         <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
         <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
         <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
         <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
         <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
         <class 'mindspore.ops.operations.math_ops.Ger'>]
    """
    white_list = AMP_WHITE_LIST.copy()
    return white_list


def get_black_list():
    """
    Return a copy of the internal black list used by auto mixed precision when `amp_level` is ``O2``.

    The built-in blacklist currently contains:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    Returns:
        list, A copy of internal black list.

    Examples:
        >>> from mindspore import amp
        >>> black_list = amp.get_black_list()
        >>> print(black_list)
        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
         <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
    """
    # A fresh copy keeps callers that mutate the result from touching AMP_BLACK_LIST itself.
    return AMP_BLACK_LIST.copy()
def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
    """
    When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
    When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
    Only one of `white_list` and `black_list` should be provided.

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train a network that has been converted
          by mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          needs to be configured to ``O0`` to avoid duplicated precision conversion.
        - Primitives are not supported in `black_list` yet.

    Args:
        network (Cell): Definition of the network.
        white_list (list[Primitive, Cell], optional): White list of custom mixed precision.
            Default: ``None`` , means white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision.
            Default: ``None`` , means black list is not used.
        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
            ``mstype.bfloat16`` . Default: ``mstype.float16`` .

    Returns:
        network (Cell), A network supporting mixed precision.

    Raises:
        TypeError: The network type is not Cell.
        ValueError: Neither `white_list` nor `black_list` is provided.
        ValueError: Both `white_list` and `black_list` are provided.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        TypeError: The provided `white_list` / `black_list` is not a list, contains an element that is not a
            class, or contains a class that is not a subclass of `Cell` (or, for `white_list`, `Primitive`).

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.4.0/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if white_list is None and black_list is None:
        raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")

    if white_list is not None and black_list is not None:
        raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                         "at the same time, please provide one or the other.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if white_list is not None:
        _list_check(white_list, "white_list")
        # White lists always go through the rewrite-based implementation.
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
    else:
        _list_check(black_list, "black_list")
        # Black lists honour the module-level MS_AMP_BY_REWRITE switch (settable via _config_amp).
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
        else:
            network = _auto_black_list(network, black_list, dtype)
            # Wrap the converted network so its outputs are cast back (to float32, per the helper's name).
            network = _OutputTo32(network)
    return network
def _list_check(custom_list: list, list_name: str):
    """
    Validate a user-supplied white list or black list for custom mixed precision.

    Every element must be a class. When `list_name` is ``"white_list"`` each class must derive from `Cell`
    or `Primitive`; when it is ``"black_list"`` each class must derive from `Cell`. Any other `list_name`
    skips the subclass rule. For a black list, a warning is logged for every entry of the built-in
    AMP_BLACK_LIST that the caller's list does not contain.

    Args:
        custom_list (list): The list supplied by the user.
        list_name (str): Name of the list; used in error messages and to select the subclass rule.

    Raises:
        TypeError: `custom_list` is not a list.
        TypeError: An element of `custom_list` is not a class.
        TypeError: An element is not a subclass of the base class(es) allowed for `list_name`.
    """
    if not isinstance(custom_list, list):
        raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")

    checking_white = list_name == "white_list"
    checking_black = list_name == "black_list"

    for candidate in custom_list:
        if not isinstance(candidate, type):
            raise TypeError(f"The element in {list_name} should be a class, but got {candidate}")
        if checking_white and not issubclass(candidate, (nn.Cell, Primitive)):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                            f"but got {candidate}")
        if checking_black and not issubclass(candidate, nn.Cell):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {candidate}")

    if not checking_black:
        return

    # Tell the user which built-in black-list entries their custom black list drops.
    for builtin_entry in AMP_BLACK_LIST:
        if builtin_entry in custom_list:
            continue
        logger.warning(f"{builtin_entry} is removed from internal black list.")


def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None):  # pylint: disable=unused-variable
    """
    Configure auto mixed precision.

    Each keyword overrides a module-level setting only when it is not ``None``:
    `enable_rewrite` sets MS_AMP_BY_REWRITE and `cast_op` sets _amp_cast_op.
    """
    global MS_AMP_BY_REWRITE, _amp_cast_op
    if cast_op is not None:
        _amp_cast_op = cast_op
    if enable_rewrite is not None:
        MS_AMP_BY_REWRITE = enable_rewrite