# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data parallel allreduce fusion."""
import ctypes
from mindspore import log as logger
# Upper bound on the length of an HCCL communication group name accepted by
# the fusion-strategy setters in this module.
_MAX_GROUP_NAME_LEN: int = 127
# Name of the shared library loaded through ctypes to reach the HCCL C API.
_HCCL_LIB: str = "libhccl.so"
def _load_lib():
    """Load the HCCL shared library through ctypes.

    Returns:
        ctypes.CDLL: Handle to the library named by ``_HCCL_LIB``.

    Raises:
        OSError: If the shared library cannot be found or loaded. The error
            is logged and then re-raised unchanged.
    """
    try:
        return ctypes.CDLL(_HCCL_LIB)
    except OSError:
        # ctypes reports a missing/unloadable shared object as OSError (not
        # RuntimeError). Log it for diagnostics, then propagate the original
        # error instead of falling through to an unbound local.
        logger.error('Get hccl lib error')
        raise
def _c_str(string):
    """Wrap ``string`` as a ``ctypes.c_char_p`` holding its UTF-8 bytes.

    Input that is not a ``str`` is assumed to be bytes-like and is first
    decoded as ASCII before being re-encoded.
    """
    text = string if isinstance(string, str) else string.decode('ascii')
    encoded = text.encode('utf-8')
    return ctypes.c_char_p(encoded)
def _c_array(ctype, values):
    """Return a fixed-size ctypes array of element type ``ctype`` filled with ``values``."""
    count = len(values)
    array_cls = ctype * count
    return array_cls(*values)
def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
    """
    Set the gradient segmentation strategy for allreduce fusion by gradient index.

    Note:
        In back propagation, allreduce operators whose fusion attribute equals 1
        are fused according to `idxList`, so that computation and communication
        can overlap.

    Args:
        idxList (list): The index list of the gradient. Each item must be a
            non-negative int; items are passed to HCCL as C unsigned ints.
        group (str): The hccl communication group. Default: "hccl_world_group".

    Raises:
        TypeError: If group is not a python str.
        TypeError: If idxList is not a python list.
        TypeError: If type of idxList item is not int.
        ValueError: If group name length is out of range.
        ValueError: If idxList length is 0.
        ValueError: If idxList item is less than 0.
        OSError: If the HCCL shared library cannot be loaded.
        RuntimeError: If allreduce split failed.
    """
    # Validate every argument before touching the native library, so bad
    # input is reported as TypeError/ValueError even when HCCL is unavailable.
    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    if not group or len(group) > _MAX_GROUP_NAME_LEN:
        raise ValueError('Group name len is out of range {}'.format(_MAX_GROUP_NAME_LEN))
    if not isinstance(idxList, list):
        raise TypeError('IdxList must be a python list')
    if not idxList:
        raise ValueError('IdxList length is 0')
    for idx in idxList:
        if not isinstance(idx, int):
            raise TypeError('Idx in idxList is invalid')
        if idx < 0:
            raise ValueError('Idx < 0')

    try:
        lib_ctype = _load_lib()
    except OSError:
        # Previously this failure was logged and swallowed, leaving lib_ctype
        # unbound; propagate it so the caller sees the real cause.
        logger.error('Load HCCL lib failed')
        raise

    c_array_idx_list = _c_array(ctypes.c_uint, idxList)
    c_idx_num = ctypes.c_uint(len(idxList))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
    # HCCL signals failure with a non-zero return code.
    if ret != 0:
        raise RuntimeError('Allreduce split error')
def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
    """
    Set the gradient segmentation strategy for allreduce fusion by data size.

    Note:
        In back propagation, allreduce operators whose fusion attribute equals 1
        are fused according to `dataSizeList`, so that computation and
        communication can overlap.

    Args:
        dataSizeList (list): The data size percentage list of the gradient.
            Each item must be a non-negative int or float; items are passed
            to HCCL as C floats.
        group (str): The hccl communication group. Default: "hccl_world_group".

    Raises:
        TypeError: If group is not a python str.
        TypeError: If dataSizeList is not a python list.
        TypeError: If type of dataSizeList item is not int or float.
        ValueError: If group name length is out of range.
        ValueError: If dataSizeList length is 0.
        ValueError: If dataSizeList item is less than 0.
        OSError: If the HCCL shared library cannot be loaded.
        RuntimeError: If allreduce split failed.
    """
    # Validate every argument before touching the native library, so bad
    # input is reported as TypeError/ValueError even when HCCL is unavailable.
    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    if not group or len(group) > _MAX_GROUP_NAME_LEN:
        raise ValueError('Group name is out of range {}'.format(_MAX_GROUP_NAME_LEN))
    if not isinstance(dataSizeList, list):
        raise TypeError('DataSizeList must be a python list')
    if not dataSizeList:
        raise ValueError('DataSizeList length is 0')
    # Type-check every item first so a TypeError still takes precedence over
    # the range check, as it did before the range check existed.
    for dataSize in dataSizeList:
        if not isinstance(dataSize, (int, float)):
            raise TypeError('DataSize in dataSizeList is invalid')
    # Documented in Raises but previously never enforced.
    if any(dataSize < 0 for dataSize in dataSizeList):
        raise ValueError('DataSize < 0')

    try:
        lib_ctype = _load_lib()
    except OSError:
        # Previously this failure was logged and swallowed, leaving lib_ctype
        # unbound; propagate it so the caller sees the real cause.
        logger.error('Load HCCL lib failed')
        raise

    c_array_size_list = _c_array(ctypes.c_float, dataSizeList)
    c_size_num = ctypes.c_uint(len(dataSizeList))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list)
    # HCCL signals failure with a non-zero return code.
    if ret != 0:
        raise RuntimeError('Allreduce split error')