# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel configuration classes for parallel training.
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore._checkparam import Validator
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore import log as logger
# Public API of this module.
__all__ = ["OpParallelConfig"]
class _Config:
r""" A basic class of the configure"""
def __str__(self):
info = "[ParallelConfig]" + '\n'
for k, v in self.__dict__.items():
var_info = "{}:{}\n".format(k, v)
info += var_info
return info
class OpParallelConfig(_Config):
    r"""
    OpParallelConfig for setting the data parallel and model parallel ways of an operator.

    Both values are checked with ``Validator.check_positive_int`` on construction
    and again on every later assignment, so an invalid (non positive int) value
    is rejected whenever it is set.

    Args:
        data_parallel (int): The data parallel way. Default: 1
        model_parallel (int): The model parallel way. Default: 1

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.parallel.nn import OpParallelConfig
        >>> config=OpParallelConfig(data_parallel=1, model_parallel=1)
    """

    def __init__(self, data_parallel=1, model_parallel=1):
        # Validate both arguments before assigning either one, so a bad
        # model_parallel is reported without data_parallel having been stored.
        Validator.check_positive_int(data_parallel, "data_parallel")
        Validator.check_positive_int(model_parallel, "model_parallel")
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel

    @property
    def data_parallel(self):
        """int: The data parallel way."""
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        Validator.check_positive_int(value, "data_parallel")
        self._data_parallel = value

    @property
    def model_parallel(self):
        """int: The model parallel way."""
        return self._model_parallel

    @model_parallel.setter
    def model_parallel(self, value):
        Validator.check_positive_int(value, "model_parallel")
        self._model_parallel = value
class _PipeLineConfig(_Config):
    r"""
    _PipeLineConfig for setting the pipeline parallel options: the number of
    pipeline stages and the number of micro batches.

    Both values are checked with ``Validator.check_positive_int`` on construction
    and on every later assignment.

    Note:
        Assigning ``pipeline_stage`` (including the assignment made in
        ``__init__``) also forwards the value to the auto-parallel context via
        ``context.set_auto_parallel_context(pipeline_stages=...)``.

    Args:
        pipeline_stage (int): The number of the pipeline stages. Default: 1
        micro_batch_num (int): The number of micro batches. Default: 1

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
    """

    def __init__(self, pipeline_stage=1, micro_batch_num=1):
        # Validate both arguments up front: the pipeline_stage setter updates the
        # auto-parallel context, which must not happen when micro_batch_num is
        # invalid and construction is going to fail anyway.
        Validator.check_positive_int(pipeline_stage, "pipeline_stage")
        Validator.check_positive_int(micro_batch_num, "micro_batch_num")
        self.pipeline_stage = pipeline_stage
        self.micro_batch_num = micro_batch_num

    @property
    def pipeline_stage(self):
        """int: The number of the pipeline stages."""
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        Validator.check_positive_int(value, "pipeline_stage")
        self._pipeline_stage = value
        # Keep the auto-parallel context in sync with this config.
        context.set_auto_parallel_context(pipeline_stages=value)

    @property
    def micro_batch_num(self):
        """int: The number of micro batches."""
        return self._micro_batch_num

    @micro_batch_num.setter
    def micro_batch_num(self, value):
        Validator.check_positive_int(value, "micro_batch_num")
        self._micro_batch_num = value
# Fallback operator parallel config (data_parallel=1, model_parallel=1) for
# callers that do not pass a config explicitly.
default_dpmp_config = OpParallelConfig(data_parallel=1, model_parallel=1)
def _check_config(config):
    """
    Check that a parallel config is consistent with the auto-parallel context.

    Checks performed:

    * If ``config`` has a ``pipeline_stage`` attribute, it must equal the
      ``pipeline_stages`` value stored in the auto-parallel context.
    * Only in ``SEMI_AUTO_PARALLEL`` / ``AUTO_PARALLEL`` mode:
      ``data_parallel * model_parallel * pipeline_stages`` must not exceed the
      communication group size (equality is allowed).
    * Only in those modes: if the context enables the parallel optimizer and
      ``config`` has an ``optimizer_shard`` attribute with a different value,
      a warning is logged (no error is raised).

    ``micro_batch_num`` is not validated here.

    Args:
        config: A parallel config object. ``data_parallel`` and
            ``model_parallel`` are read in (semi-)auto parallel mode;
            ``pipeline_stage`` and ``optimizer_shard`` are consulted only
            when present.

    Raises:
        ValueError: If the config's pipeline stage differs from the context's,
            or if the product of the parallel ways exceeds the device number.
    """
    # The config's pipeline_stage must agree with context.pipeline_stages.
    pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
    if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
        raise ValueError(
            f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
            f"{config.pipeline_stage}"
            f" in the config.")

    # The remaining checks only apply in (semi-)auto parallel mode.
    is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if not is_auto_parallel:
        return

    device_num = D.get_group_size()
    optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")

    # Using exactly all devices is allowed; only exceeding device_num is an error.
    if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
        raise ValueError(f"The product of the data parallel {config.data_parallel}, "
                         f"model parallel {config.model_parallel} and "
                         f"pipeline stages {pipeline_stage} "
                         f"should be less than or equal to device_num {device_num}.")

    # The config's optimizer_shard should match context.enable_parallel_optimizer;
    # a mismatch is only reported when the context has it enabled.
    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
                       f" optimizer_shard {config.optimizer_shard} in the config. Please check the "
                       f"optimizer_shard to make them consistent.")