# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""boost"""
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess

__all__ = ["AutoBoost"]

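# Acceleration techniques enabled at each boost level: "less_bn" (Less
# BatchNorm), "grad_freeze" (gradient freezing) and "adasum" (adaptive
# summation for distributed gradient aggregation).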
_boost_config_level = {
"O0": {
"less_bn": False,
"grad_freeze": False,
"adasum": False},
"O1": {
"less_bn": True,
"grad_freeze": True,
"adasum": False},
"O2": {
"less_bn": True,
"grad_freeze": True,
"adasum": True}}


class AutoBoost:
"""
Provide auto accelerating for network.
Args:
level (str): boost config level.
kwargs (any): Additional configuration parameters related to boost.
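
    Examples:
        >>> # A minimal usage sketch; `net` and `opt` are hypothetical
        >>> # placeholders for a MindSpore network and its optimizer.
        >>> boost = AutoBoost("O1", {"gc_flag": False})
        >>> net, opt = boost.network_auto_process_train(net, opt)
        >>> eval_net = boost.network_auto_process_eval(net)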
"""

    def __init__(self, level, kwargs):
        if level not in _boost_config_level:
            level = 'O0'
        self.level = level
        self._boost_config = _boost_config_level[level]
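        # Defaults below can be overridden through `kwargs`;
        # see `_boost_config_func_map` at the bottom of this class.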
self._fn_flag = True
self._gc_flag = True
self._param_groups = 10
self._freeze_type = 1
self._freeze_p = 0.7
self._total_steps = 65536
self._gradient_groups = None
self._get_configuration(kwargs)
self._param_processer = ParameterProcess()

    def _get_configuration(self, kwargs):
        """Apply the user configuration; unrecognized keys are silently ignored."""
        for key, val in kwargs.items():
            if key not in self._boost_config_func_map:
                continue
            self._boost_config_func_map[key](self, val)

    def network_auto_process_train(self, network, optimizer):
        """Apply boost optimizations to the network and optimizer for training."""
if self._boost_config["less_bn"]:
network = LessBN(network, fn_flag=self._fn_flag)
optimizer_process = OptimizerProcess(optimizer)
group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
self._gradient_groups)
optimizer_process.origin_params = \
self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
if self._gc_flag:
optimizer_process.add_grad_centralization(network)
optimizer = optimizer_process.generate_new_optimizer()
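        # When "grad_freeze" is enabled, wrap the network and optimizer so that
        # groups of parameters can be frozen as training progresses.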
if self._boost_config["grad_freeze"]:
freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
self._freeze_p, self._total_steps)
network, optimizer = freeze_processer.freeze_generate(network, optimizer)
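        # When "adasum" is enabled, flag the optimizer to aggregate gradients
        # with adaptive summation in distributed training.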
if self._boost_config["adasum"]:
setattr(optimizer, "adasum", True)
return network, optimizer

    def network_auto_process_eval(self, network):
        """Apply boost optimizations to the network for evaluation."""
if self._boost_config["less_bn"]:
network = LessBN(network)
return network
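
    # The setters below are dispatched through `_boost_config_func_map`;
    # each one overrides a default set in `__init__`.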
    def set_fn_flag(self, fn_flag):
        self._fn_flag = fn_flag

    def set_gc_flag(self, gc_flag):
        self._gc_flag = gc_flag

    def set_param_groups(self, param_groups):
        self._param_groups = param_groups

    def set_freeze_type(self, freeze_type):
        self._freeze_type = freeze_type

    def set_freeze_p(self, freeze_p):
        self._freeze_p = freeze_p

    def set_total_steps(self, total_steps):
        self._total_steps = total_steps

    def set_gradient_groups(self, gradient_groups):
        if not isinstance(gradient_groups, (list, int)):
            raise ValueError(f"gradient_groups `{gradient_groups}` must be a list or an int, "
                             f"but got {type(gradient_groups).__name__}")
        if isinstance(gradient_groups, int):
            gradient_groups = [gradient_groups]
        self._gradient_groups = gradient_groups

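    # Map recognized configuration keys to their setters; `_get_configuration`
    # uses this table to apply the entries of `kwargs`.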
_boost_config_func_map = {
"fn_flag": set_fn_flag,
"gc_flag": set_gc_flag,
"param_groups": set_param_groups,
"freeze_type": set_freeze_type,
"freeze_p": set_freeze_p,
"total_steps": set_total_steps,
"gradient_groups": set_gradient_groups
}