# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration of parameters for strategy-searching algorithm in auto_parallel"""
import threading
from mindspore._c_expression import CostModelContext
from mindspore._extends.pynative_helper import args_type_check
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
class _AlgoParameterConfig():
"""
_AlgoParameterConfig is the configuration of setting parameters used in th algorithm.
Note:
Creating a config through instantiating _AlgoParameterConfig object is not recommended.
Use algo_parameter_config() to get the configuration since _AlgoParameterConfig is singleton.
"""
_instance = None
_instance_lock = threading.Lock()
def __init__(self):
self._config_handle = CostModelContext.get_instance()
def check_config_handle(self):
"""
Check config handle.
Raises:
ValueError: If the config handle is none.
"""
if self._config_handle is None:
raise ValueError("Config handle is none!!!")
def set_simplify_cal(self, simplify_cal):
self.check_config_handle()
self._config_handle.set_simplify_cal(simplify_cal)
def get_simplify_cal(self):
self.check_config_handle()
return self._config_handle.get_simplify_cal()
def set_not_fully_use_devices(self, not_fully):
self.check_config_handle()
self._config_handle.set_not_fully_use_devices(not_fully)
def get_not_fully_use_devices(self):
self.check_config_handle()
return self._config_handle.get_not_fully_use_devices()
def set_elementwise_op_strategy_follow(self, element_strategy_follow):
self.check_config_handle()
self._config_handle.set_elementwise_op_strategy_follow(element_strategy_follow)
def get_elementwise_op_strategy_follow(self):
self.check_config_handle()
return self._config_handle.get_elementwise_op_strategy_follow()
def set_tensor_slice_align_enable(self, align_enable):
self.check_config_handle()
self._config_handle.set_tensor_slice_align_enable(align_enable)
def get_tensor_slice_align_enable(self):
self.check_config_handle()
return self._config_handle.get_tensor_slice_align_enable()
def set_tensor_slice_align_size(self, align_size):
"""
Set tensor slice align size.
Args:
align_size (int): The minimum tensor slice shape.
Raises:
ValueError: If align_size is not in [1, 1024].
"""
self.check_config_handle()
if align_size < 1 or align_size > 1024:
raise ValueError('Align_size must be in [1, 1024], but got {}'.format(align_size))
self._config_handle.set_tensor_slice_align_size(align_size)
def get_tensor_slice_align_size(self):
self.check_config_handle()
return self._config_handle.get_tensor_slice_align_size()
def reset_algo_parameters(self):
self.check_config_handle()
self._config_handle.reset_algo_parameters()
_g_algo_parameter_config = None
def _algo_parameter_config():
"""
Get the global _g_algo_parameter_config. If it is not created, create a new one.
Returns:
The global _g_algo_parameter_config.
"""
global _g_algo_parameter_config
if _g_algo_parameter_config is None:
_g_algo_parameter_config = _AlgoParameterConfig()
return _g_algo_parameter_config
set_algo_parameters_config_func_map = {
"simplify_cal": _algo_parameter_config().set_simplify_cal,
"not_fully_use_devices": _algo_parameter_config().set_not_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
"tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}
get_algo_parameters_config_func_map = {
"simplify_cal": _algo_parameter_config().get_simplify_cal,
"not_fully_use_devices": _algo_parameter_config().get_not_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
"tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
[docs]@args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int,
not_fully_use_devices=bool, elementwise_op_strategy_follow=bool)
def set_algo_parameters(**kwargs):
"""
Set algo parameter config.
Note:
Attribute name is needed.
Args:
simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True
tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False
tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16
not_fully_use_devices (bool): Whether generating strategies that not fully use devices. Default: False
elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
subsequent operators. Default: False
Raises:
ValueError: If context keyword is not recognized.
"""
for key, value in kwargs.items():
if key not in set_algo_parameters_config_func_map:
raise ValueError("Set context keyword %s is not recognized!" % key)
set_func = set_algo_parameters_config_func_map[key]
set_func(value)
[docs]def get_algo_parameters(attr_key):
"""
Get algo parameter config attributes.
Note:
Return value according to the attribute value.
Args:
attr_key (str): The key of the attribute.
Raises:
ValueError: If context keyword is not recognized.
"""
if attr_key not in get_algo_parameters_config_func_map:
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
get_func = get_algo_parameters_config_func_map[attr_key]
return get_func()
[docs]def reset_algo_parameters():
"""Reset algo parameter attributes."""
_algo_parameter_config().reset_algo_parameters()