# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
from typing import Iterable
import numpy as np
import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
# Public API of this module: only the Optimizer base class is exported.
__all__ = ['Optimizer']
class Optimizer(Cell):
    """
    Base class for all optimizers.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate` and `weight_decay`.

        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight_decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API
        will be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve parameter groups performance, the customized order of parameters can be supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is
            LearningRateSchedule, use dynamic learning rate, the i-th learning rate will be calculated during the
            process of training according to the formula of LearningRateSchedule. When the learning_rate is a float
            or a Tensor in a zero dimension, use fixed learning rate. Other cases are not supported. The float
            learning rate should be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
            updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of
            `dict`, the "params", "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value should be a list of `Parameter`.

            - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters
              and the order will be followed in optimizer. There are no other keys in the `dict` and the parameters
              which are in the value of 'order_params' should be in one of the group parameters.

        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

    Raises:
        ValueError: If the learning_rate is a Tensor, but the dimension of tensor is greater than 1.
        TypeError: If the learning_rate is not one of the supported types: int, float, Tensor, Iterable or
            LearningRateSchedule.
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        # Accept any iterable of parameters (e.g. a generator) by materializing it once.
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        # Only the first element is inspected to decide between "plain" and "grouped" mode.
        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)

        # These flags may be flipped by _preprocess_single_lr / _init_group_params below.
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self._init_group_params(parameters, learning_rate, weight_decay)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            # One learning rate entry per parameter: cells when any rate is dynamic, parameters otherwise.
            if self.dynamic_lr:
                self.learning_rate = CellList(self.group_lr)
            else:
                self.learning_rate = ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            # Grouped mode: decay applies to every parameter whose (scaled) decay is positive.
            self.decay_flags = tuple(decay > 0 for decay in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            # Plain mode: parameters with 'beta' or 'gamma' in their names are excluded from decay.
            self.decay_flags = tuple('beta' not in param.name and 'gamma' not in param.name
                                     for param in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0

        self.ps_parameters = tuple(param.is_param_ps for param in self.parameters)
        self.reciprocal_scale = 1.0 / loss_scale
        self.param_length = len(self.parameters)
        self.map_ = C.Map()

        use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
        self.use_parallel = use_parallel
        if use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
            if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
                                   (_get_parallel_mode()))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            # Query the rank once instead of once per parameter.
            global_rank = _get_global_rank()
            self.optim_filter = tuple(rank == global_rank for rank in self.param_rank)
            self.param_names = [param.name for param in self.parameters]
        else:
            self.optim_filter = (True,) * self.param_length

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                # Grouped mode maps a per-parameter decay value alongside each flag/param/gradient.
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
                                      params, gradients)
            else:
                # Plain mode binds the single shared decay value up front.
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
                                      params, gradients)

        return gradients

    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
        network.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after loss scale.
        """
        # Skip the map entirely when loss_scale is 1.0 (reciprocal is exactly 1.0).
        if self.reciprocal_scale != 1.0:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """
        Check weight decay, and convert int to float.

        Args:
            weight_decay (Union[int, float]): The weight decay value to validate.

        Returns:
            float, the validated weight decay (must be in [0, inf)).

        Raises:
            TypeError: If `weight_decay` is neither int nor float.
        """
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
            return weight_decay

        raise TypeError("Weight decay should be int or float.")

    def _preprocess_single_lr(self, learning_rate):
        """
        Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.

        Any input other than a number or a zero-dimension Tensor marks the optimizer as using a dynamic
        learning rate (`self.dynamic_lr = True`).

        Args:
            learning_rate (Union[int, float, Tensor, Iterable, LearningRateSchedule]): The raw learning rate.

        Returns:
            Union[float, Tensor, LearningRateSchedule], the normalized learning rate.

        Raises:
            ValueError: If `learning_rate` is a Tensor with more than one dimension.
            TypeError: If `learning_rate` is of an unsupported type.
        """
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            return learning_rate

        # Everything below this point is a per-step (dynamic) learning rate.
        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.dim() > 1:
                raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1, "
                                 f"but got {learning_rate.dim()}.")
            if learning_rate.dim() == 1 and learning_rate.size() < 2:
                logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number "
                               "of elements in the tensor passed is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate
        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")

    def _build_single_lr(self, learning_rate, name):
        """
        Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule.

        Args:
            learning_rate (Union[float, Tensor, LearningRateSchedule]): A learning rate already normalized by
                `_preprocess_single_lr`.
            name (str): The name given to the created Parameter.

        Returns:
            Union[Parameter, LearningRateSchedule], the built learning rate. Anything that is not a float or a
            Tensor of dimension 0 or 1 is returned unchanged.
        """
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            # With grouped dynamic rates every entry must be callable, so wrap the scalar in a cell.
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _check_group_params(self, parameters):
        """
        Check group params.

        Args:
            parameters (list[dict]): The group parameter dicts passed to the optimizer.

        Raises:
            KeyError: If a dict contains a key other than 'params', 'lr', 'weight_decay' or 'order_params'.
            ValueError: If an 'order_params' dict has other keys, or a group's 'params' is empty.
            TypeError: If 'order_params' is not Iterable, or an element of 'params' is not a Parameter.
        """
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            if not group_param['params']:
                raise ValueError("Optimizer got an empty group parameter list.")

            for param in group_param['params']:
                if not isinstance(param, Parameter):
                    raise TypeError("The group param should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """
        Parse group params.

        Validates the groups, records whether an explicit ordering and/or per-group learning rates are present,
        and checks that all 1-D Tensor learning rates (default and per-group) have the same number of elements.

        Args:
            parameters (list[dict]): The group parameter dicts.
            learning_rate (Union[float, Tensor, LearningRateSchedule]): The preprocessed default learning rate.

        Raises:
            ValueError: If 1-D Tensor learning rates differ in size.
        """
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            tensor_lr_length = learning_rate.size()
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                # The shape of the 'order_params' dict was already validated by _check_group_params above.
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
                    group_lr_length = group_lr.size()
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay):
        """
        Init learning rate or weight decay in group params.

        Fills `self.group_params`, `self.group_lr` and `self.group_weight_decay` with one entry per parameter,
        then reorders them if an 'order_params' group was supplied.

        Args:
            parameters (list[dict]): The group parameter dicts.
            learning_rate (Union[float, Tensor, LearningRateSchedule]): The preprocessed default learning rate.
            weight_decay (float): The preprocessed default weight decay.

        Raises:
            RuntimeError: If a parameter name appears more than once across the groups.
        """
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        # A set gives O(1) duplicate-name detection.
        seen_names = set()
        ordered_parameters = None
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']
            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            # Weight decay is pre-multiplied by the loss scale, matching the non-grouped path in __init__.
            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            # Defensive: _check_group_params already rejects unknown keys, so this only fires if that
            # check is bypassed.
            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay'):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in seen_names:
                    raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")

                seen_names.add(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)

    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameter, learning rate and weight decay in group params.

        Args:
            ordered_parameters (Iterable[Parameter]): The parameters in the order the optimizer should follow.

        Raises:
            ValueError: If the number of ordered parameters differs from the number of group parameters, or a
                group parameter is missing from `ordered_parameters`.
        """
        # 'order_params' is only checked to be Iterable, so materialize it before calling len().
        ordered_parameters = list(ordered_parameters)
        params_length = len(self.group_params)
        if len(ordered_parameters) != params_length:
            raise ValueError("The value of 'order_params' should be same with all group parameters.")

        # Name -> target position. setdefault keeps the first occurrence, like list.index would.
        position_of = {}
        for position, param in enumerate(ordered_parameters):
            position_of.setdefault(param.name, position)

        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length

        for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
            index = position_of.get(param.name)
            if index is None:
                raise ValueError(f"The parameter {param.name} in group parameters is not in 'order_params'.")
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay

    def get_lr(self):
        """
        Get the learning rate of current step.

        Returns:
            The stored learning rate when it is fixed. When it is dynamic, the result of calling the learning
            rate cell(s) with `self.global_step`; a tuple with one entry per parameter when group learning
            rates are used.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            # Order the global_step increment after the learning rate has been read.
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        return lr

    def get_lr_parameter(self, param):
        """
        Get the learning rate of parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, single `Parameter` or `list[Parameter]` according to the input type.

        Raises:
            TypeError: If `param` is neither a `Parameter` nor a list.
            ValueError: If a given parameter is not managed by this optimizer.
        """
        def get_lr_value(learning_rate):
            # Unwrap the internal cells so callers get the underlying Parameter.
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError("The parameter only support 'Parameter' or 'list' type.")

        # Identity (not equality) lookup; setdefault keeps the first index, like list.index would.
        index_by_id = {}
        for position, own_param in enumerate(self.parameters):
            index_by_id.setdefault(id(own_param), position)

        lr = []
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if id(p) not in index_by_id:
                raise ValueError(f"The parameter {p.name} is not in optimizer.")
            if self.is_group_lr:
                lr.append(get_lr_value(self.learning_rate[index_by_id[id(p)]]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]

    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Parameters are assigned to devices round-robin: 0, 1, ..., dev_num - 1, 0, 1, ...

        Returns:
            tuple, the group id tuple of parameters.
        """
        return tuple(index % self.dev_num for index in range(self.param_length))

    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Args:
            optim_result: The result of the optimizer update; the first broadcast is made to depend on it.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        # Bucket parameters (and their ref keys) by the device rank that owns them.
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            # Write the broadcast values back through the parameters' ref keys.
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        # Chain control dependencies so the broadcasts run after the update and in rank order.
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)

        return status

    def construct(self, *hyper_params):
        """Must be implemented by subclasses; the base class always raises NotImplementedError."""
        raise NotImplementedError
# Shared operator instances used by the weight-decay overloads registered on `_apply_decay`.
op_add = P.AddN()
op_gather = P.GatherV2()

# Multi-type dispatch graph; its overloads are selected by the argument type signatures registered on it.
_apply_decay = C.MultitypeFuncGraph("apply_decay")
@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """
    Get grad with weight_decay.

    Overload for a `RowTensor` gradient: when `if_apply` is true, returns a new `RowTensor` whose values are
    the gathered weight entries times `weight_decay` plus the gradient values; otherwise returns `gradient`
    unchanged.
    """
    if if_apply:
        indices = gradient.indices
        # Only the weight entries selected by `indices` are decayed (gather with last argument 0).
        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """
    Get grad with weight_decay.

    Overload for a dense `Tensor` gradient: returns `weight * weight_decay + gradient` when `if_apply` is
    true, otherwise returns `gradient` unchanged.
    """
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient
# Multi-type dispatch graph for gradient scaling; dense and RowTensor overloads are registered on it.
_grad_scale = C.MultitypeFuncGraph("grad_scale")
@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """
    Get grad with scale.

    Overload for a dense `Tensor` gradient: returns `grad * scale`, or `grad` itself when `scale` is 1.0.
    """
    if scale == 1.0:
        return grad
    return grad * scale
@_grad_scale.register("Number", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """
    Get grad with scale.

    Overload for a `RowTensor` gradient: returns a new `RowTensor` with the same indices and dense shape and
    with values multiplied by `scale`, or `grad` itself when `scale` is 1.0.
    """
    if scale == 1.0:
        return grad
    return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
class _ConvertToCell(LearningRateSchedule):
    """
    Inner api, convert learning rate of scalar to LearningRateSchedule.

    Wraps a fixed learning rate `Parameter` so it can be called with `global_step` like any other schedule.

    Args:
        learning_rate (Parameter): The fixed learning rate.

    Raises:
        TypeError: If `learning_rate` is not a `Parameter`.
    """
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        """Return the wrapped learning rate; `global_step` is ignored."""
        # NOTE(review): the `+ 1.0 - 1.0` round-trip presumably exists to yield a computed Tensor instead of
        # the Parameter itself — confirm. In float32 it may lose precision for very small rates.
        return self.learning_rate + 1.0 - 1.0
class _IteratorLearningRate(LearningRateSchedule):
    """
    Inner api, convert learning rate of Tensor(list) to LearningRateSchedule.

    Holds a 1-D learning rate table as a `Parameter` and looks up the entry at `global_step` when called.

    Args:
        learning_rate (Tensor): A Tensor with exactly one dimension holding the learning rate values.
        name (str): The name given to the created Parameter.

    Raises:
        TypeError: If `learning_rate` is not a Tensor.
        ValueError: If `learning_rate` does not have exactly one dimension.
    """
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        # Guard clauses: reject a wrong type first, then a wrong rank.
        if not isinstance(learning_rate, Tensor):
            raise TypeError("Learning rate should be Tensor.")
        rank = learning_rate.dim()
        if rank != 1:
            raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
                             f"but got {rank}.")

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.GatherV2()

    def construct(self, global_step):
        """Gather the learning rate entry at index `global_step` (last argument 0 passed to GatherV2)."""
        return self.gather(self.learning_rate, global_step, 0)