# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""comm_helper"""
import os
from mindspore import context
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _is_ps_mode, _get_ps_context
from mindspore import log as logger
from mindspore.communication._hccl_management import load_lib as hccl_load_lib
from mindspore._c_expression import get_rank_id, get_rank_size, CollectiveManager
_HCCL_AVAILABLE = False
_HCCL_TEST_AVAILABLE = False
_NCCL_AVAILABLE = False
_MPI_AVAILABLE = False
try:
import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True
except ImportError:
_NCCL_AVAILABLE = False
try:
hccl_load_lib()
_HCCL_AVAILABLE = True
except RuntimeError:
_HCCL_AVAILABLE = False
if _HCCL_AVAILABLE:
import mindspore.communication._hccl_management as hccl
try:
import mindspore._ascend_mpi as mpi
_MPI_AVAILABLE = True
except ImportError:
_MPI_AVAILABLE = False
else:
try:
import hccl_test.manage.api as hccl
_HCCL_AVAILABLE = True
_HCCL_TEST_AVAILABLE = True
except ImportError:
_HCCL_AVAILABLE = False
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MCCL_WORLD_COMM_GROUP = "mccl_world_group"
class Backend:
"""
Class for available backends.
Note:
The backends' value should be string, e.g., "hccl".
If backend is set to Backend.UNDEFINED, it will be seen as invaliad.
Args:
name (str): The name of backend.
Raises:
TypeError: If name is not a string.
ValueError: If backend is invalid.
Examples:
>>> Backend("abc")
>>> hccl = Backend("hccl")
"""
UNDEFINED = "undefined"
HCCL = "hccl"
NCCL = "nccl"
HCCL_MPI = "hccl_mpi"
MCCL = "mccl"
def __new__(cls, name):
"""Create instance object of Backend."""
if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name)))
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name))
return value
DEFAULT_BACKEND = Backend("hccl")
[文档]class GlobalComm:
"""
World communication information. The GlobalComm is a global class. The members contain:
- BACKEND: The communication library used, using HCCL/NCCL.
- WORLD_COMM_GROUP: Global communication domain.
"""
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
class _ExistingGroup:
"""
The communication groups which exist in the progress.
"""
ITEMS = {}
def is_hccl_available():
"""
Check HCCL api is available.
Returns:
Boolean. Return whether HCCL is available or not.
"""
return _HCCL_AVAILABLE
def is_mpi_available():
"""
Check HCCL & MPI api is available.
Returns:
Boolean. Return whether HCCL & MPI is available or not.
"""
return _MPI_AVAILABLE
def is_nccl_available():
"""
Check NCCL api is available.
Returns:
Boolean. Return whether NCCL is available or not.
"""
return _NCCL_AVAILABLE
def _check_mpi_envs():
"""
Check whether mpi environment variables have been exported or not.
return True if mpi environment variables have been exported, False otherwise.
"""
ompi_command_env = os.getenv("OMPI_COMMAND")
pmix_rank_env = os.getenv("PMIX_RANK")
if ompi_command_env and pmix_rank_env:
return True
return False
def _use_old_ps():
"""
Whether use old framework to launch Parameter Server training.
"""
return os.getenv("USE_OLD_PS") == "True"
def _not_require_collective_comm_lib():
'''
Whether collective communication library is required in this training mode.
For example, scheduler and server do not require actual collective communication in parameter server mode.
'''
# Environment variable USE_OLD_PS is set by user and used to run
# parameter server training with old framework.
if _use_old_ps() and (_is_role_sched() or _is_role_pserver()):
return True
return False
def _check_bypass_rank_id_and_size():
'''
Whether bypass calling c++ API to get rank id and size, instead, use fake rank id 0 and rank size 1.
This returns True when this process is Scheduler node or is Server node in old Parameter Server training mode.
'''
if _is_role_sched():
return True
if _use_old_ps() and _is_role_pserver():
return True
device_target = context.get_context("device_target")
if not _use_old_ps() and _is_ps_mode() and _get_ps_context("worker_num") == 1 and device_target == "Ascend":
return True
return False
def check_parameter_available(func):
"""
Check parameter is available. If not available, raise Error.
Args:
func (Function): The function to be run.
Raises:
RuntimeError.
Returns:
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
if _not_require_collective_comm_lib():
return func(*args, **kargs)
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
group = None
if "group" in kargs.keys():
group = kargs.get("group")
if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group)))
if "backend" in kargs.keys():
backend = kargs.get("backend")
if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available():
raise RuntimeError("Distributed Communication doesn't have MPI built in")
if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
if group is None:
if backend is Backend.HCCL or Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP
return func(*args, **kargs)
return wrapper
@check_parameter_available
def _get_rank_helper(group, backend):
"""
The Helper to do get_rank_id.
Args:
group (str): The communication group.
backend (str): The backend, like "hccl".
Raises:
ValueError: If backend is invalid.
Returns:
Integer. The local rank id of the calling process.
"""
if _check_bypass_rank_id_and_size():
rank_id = 0
return rank_id
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
rank_id = hccl.get_rank_id()
else:
rank_id = hccl.get_rank_id(group)
elif backend == Backend.NCCL:
rank_id = get_rank_id(group)
elif backend == Backend.MCCL:
# Call cluster getting rank function.
rank_id = CollectiveManager.get_instance().get_rank_id(group)
else:
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id
@check_parameter_available
def _get_local_rank_helper(group, backend):
"""
The Helper to do get_local_rank_id.
Args:
group (str): The communication group.
backend (str): The backend, like "hccl".
Raises:
ValueError: If backend is invalid.
Returns:
Integer. The local rank id of the calling process.
"""
rank_id = None
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
rank_id = hccl.get_local_rank_id()
else:
rank_id = hccl.get_local_rank_id(group)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
else:
raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi or hccl.".format(backend))
return rank_id
@check_parameter_available
def _get_size_helper(group, backend):
"""
The Helper to do get_rank_size.
Args:
group (str): The communication group.
backend (str): The backend, like "hccl".
Raises:
ValueError: If backend is invalid.
Returns:
Integer. The rank size of specified group.
"""
size = None
if _check_bypass_rank_id_and_size():
size = 1
return size
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_rank_size()
else:
size = hccl.get_rank_size(group)
elif backend == Backend.NCCL:
size = get_rank_size(group)
elif backend == Backend.MCCL:
# Call cluster getting group size function.
size = CollectiveManager.get_instance().get_group_size(group)
else:
raise ValueError("For '_get_size_helper', the argument 'backend' {} is not supported, "
"please use hccl or nccl.".format(backend))
return size
@check_parameter_available
def _get_local_size_helper(group, backend):
"""
The Helper to do get_local_rank_size.
Args:
group (str): The communication group.
backend (str): The backend, like "hccl".
Raises:
ValueError: If backend is invalid.
Returns:
Integer. The local rank size where the calling process is being within specified group.
"""
size = None
if backend == Backend.HCCL_MPI:
size = mpi.get_local_rank_size(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size()
else:
size = hccl.get_local_rank_size(group)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend))
return size
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
"""
The Helper to do get_world_rank_from_group_rank.
Args:
group (str): The user communication group.
group_rank_id (int): A rank id in user communication group.
backend (str): The backend, like "hccl".
Raises:
TypeError: If group_rank_id is not int.
ValueError: If group is "hccl_world_group" or backend is invalid.
Returns:
Integer. A rank id in world communication group.
"""
world_rank_id = None
if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
if backend == Backend.HCCL_MPI:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank', the argument 'group' "
"should not be 'HCCL_WORLD_COMM_GROUP'.")
world_rank_id = mpi.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
return world_rank_id
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
"""
The Helper to do get_group_rank_from_world_rank.
Args:
world_rank_id (int): A rank id in world communication group.
group (str): The user communication group.
backend (str): The backend, like "hccl".
Raises:
TypeError: If world_rank_id is not int.
ValueError: If group is 'hccl_world_group' or backend is invalid.
Returns:
Integer. A rank id in user communication group.
"""
group_rank_id = None
if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
if backend == Backend.HCCL_MPI:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank', the argument 'group' "
"should not be 'HCCL_WORLD_COMM_GROUP'.")
group_rank_id = mpi.get_group_rank_from_world_rank(world_rank_id, group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
return group_rank_id
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
"""
The Helper to do create_group.
Args:
group (str): The communication group.
rank_ids (list): Rank ids in the group.
backend (str): The backend, like "hccl".
Raises:
TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
"""
if group in _ExistingGroup.ITEMS.keys():
if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids))
logger.warning("%r group has existed.", group)
return
if backend == Backend.HCCL:
if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
rank_size = len(rank_ids)
if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids)))
if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI:
mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support create_group now.")
else:
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available
def _destroy_group_helper(group, backend):
"""
The Helper to do destroy_group.
Args:
group (str): The user communication group.
backend (str): The backend, like "hccl".
Raises:
ValueError: If group is "hccl_world_group" or backend is invalid.
"""
if backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.")
hccl.destroy_group(group)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support destroy_group now.")
else:
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))