Source code for mindspore.parallel.dp_allreduce_fusion

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data parallel allreduce fusion"""

import ctypes

from mindspore import log as logger

# Upper bound on the length of a communication group name: the fusion strategy
# setters in this module raise ValueError for names longer than this (and for
# empty names).
_MAX_GROUP_NAME_LEN = 127
# File name of the HCCL shared library that _load_lib() hands to ctypes.CDLL.
_HCCL_LIB = 'libhccl.so'


def _load_lib():
    """
    Load the HCCL shared library named by `_HCCL_LIB` through ctypes.

    Returns:
        ctypes.CDLL, the handle of the loaded library.

    Raises:
        RuntimeError: If the shared library cannot be loaded.
    """
    try:
        hccl_lib = ctypes.CDLL(_HCCL_LIB)
    except (OSError, RuntimeError) as err:
        # ctypes.CDLL signals a missing or unloadable library with OSError, so
        # catching only RuntimeError never fired. Re-raise instead of falling
        # through: `hccl_lib` is unbound here and returning it would fail with
        # an unrelated UnboundLocalError. RuntimeError is the type the callers
        # in this module catch.
        logger.error('Get hccl lib error')
        raise RuntimeError('Get hccl lib error') from err

    return hccl_lib


def _c_str(string):
    """
    Wrap a python string in a ctypes `c_char_p`.

    Anything that is not a `str` is first decoded as ASCII; the resulting text
    is then encoded as UTF-8 before being handed to ctypes.
    """
    text = string if isinstance(string, str) else string.decode('ascii')
    encoded = text.encode('utf-8')
    return ctypes.c_char_p(encoded)


def _c_array(ctype, values):
    """
    Build a ctypes array of element type `ctype` holding the items of `values`.

    The array length equals `len(values)`.
    """
    array_type = ctype * len(values)
    return array_type(*values)


def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
    """
    Set the gradient segment strategy according to the index list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to the idxList,
        to achieve the effect of parallel between calculation and communication.

    Args:
        idxList (list): The index list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If IdxList is not a python list.
        TypeError: If type of idxList item is not int.
        ValueError: If group name length is out of range.
        ValueError: If idxList length is 0.
        ValueError: If idxList item is less than 0.
        RuntimeError: If `_load_lib` raises it while loading the HCCL library.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # Log and propagate: without the library handle nothing below can
        # work, and carrying on would only fail later on an unbound name.
        logger.error('Load HCCL lib failed')
        raise

    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    group_len = len(group)
    if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
        # The limit is formatted explicitly; the plain literal used to print
        # the placeholder text itself instead of the number.
        raise ValueError('Group name len is out of range {}'.format(_MAX_GROUP_NAME_LEN))

    if not isinstance(idxList, list):
        raise TypeError('IdxList must be a python list')
    if len(idxList) == 0:
        raise ValueError('IdxList length is 0')

    for idx in idxList:
        if not isinstance(idx, int):
            raise TypeError('Idx in idxList is invalid')
        if idx < 0:
            raise ValueError('Idx < 0')

    # Marshal the arguments into the C types passed to the HCCL entry point.
    c_array_idxList = _c_array(ctypes.c_uint, idxList)
    c_idx_num = ctypes.c_uint(len(idxList))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
    if ret != 0:
        raise RuntimeError('Allreduce split error')


def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
    """
    Set the gradient segment strategy according to the data size percentage list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to dataSizeList,
        to achieve the effect of parallel between calculation and communication.

    Args:
        dataSizeList (list): The data size percentage list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If dataSizeList is not a python list.
        TypeError: If type of dataSizeList item is not int or float.
        ValueError: If group name length is out of range.
        ValueError: If dataSizeList length is 0.
        ValueError: If dataSizeList item is less than 0.
        RuntimeError: If `_load_lib` raises it while loading the HCCL library.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # Log and propagate: without the library handle nothing below can
        # work, and carrying on would only fail later on an unbound name.
        logger.error('Load HCCL lib failed')
        raise

    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    group_len = len(group)
    if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
        # The limit is formatted explicitly; the plain literal used to print
        # the placeholder text itself instead of the number.
        raise ValueError('Group name is out of range {}'.format(_MAX_GROUP_NAME_LEN))

    if not isinstance(dataSizeList, list):
        raise TypeError('DataSizeList must be a python list')
    if len(dataSizeList) == 0:
        raise ValueError('DataSizeList length is 0')

    for dataSize in dataSizeList:
        if not isinstance(dataSize, (int, float)):
            raise TypeError('DataSize in dataSizeList is invalid')
        # Enforce the documented "item is less than 0" ValueError, mirroring
        # the check done for indices in set_fusion_strategy_by_idx.
        if dataSize < 0:
            raise ValueError('DataSize < 0')

    # Marshal the arguments into the C types passed to the HCCL entry point.
    c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
    c_size_num = ctypes.c_uint(len(dataSizeList))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
    if ret != 0:
        raise RuntimeError('Allreduce split error')