# Source code for mindspore.nn.optim.optimizer

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer']


class Optimizer(Cell):
    """
    Base class for all optimizers.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate` and `weight_decay`.

        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight_decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API
        will be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve parameter groups performance, the customized order of parameters can be supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is
            LearningRateSchedule, use dynamic learning rate, the i-th learning rate will be calculated during the
            process of training according to the formula of LearningRateSchedule. When the learning_rate is a float
            or a Tensor in a zero dimension, use fixed learning rate. Other cases are not supported. The float
            learning rate should be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
            updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of
            `dict`, "params", "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value should be a list of `Parameter`.

            - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
              the order will be followed in optimizer. There are no other keys in that `dict`, and the parameters in
              the value of 'order_params' should each belong to one of the group parameters.

        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

    Raises:
        ValueError: If `parameters` is empty, or a group's "params" is empty.
        ValueError: If the learning_rate is a Tensor, but the dimension of tensor is greater than 1.
        ValueError: If "order_params" does not contain exactly the parameters of all groups.
        TypeError: If the learning_rate is not any of the types: int, float, Tensor, Iterable or
            LearningRateSchedule.
        TypeError: If the elements of `parameters` are neither `Parameter` nor `dict`.
        KeyError: If a group dict contains an unrecognized key or lacks the "params" key.
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)

        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self._init_group_params(parameters, learning_rate, weight_decay)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            if self.dynamic_lr:
                self.learning_rate = CellList(self.group_lr)
            else:
                self.learning_rate = ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            # In group mode, decay is applied to every parameter whose group weight decay is positive.
            self.decay_flags = tuple(wd > 0 for wd in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            # Without groups, parameters whose names contain 'beta' or 'gamma' are excluded from decay.
            self.decay_flags = tuple('beta' not in p.name and 'gamma' not in p.name for p in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0

        self.ps_parameters = tuple(p.is_param_ps for p in self.parameters)
        self.reciprocal_scale = 1.0 / loss_scale
        self.param_length = len(self.parameters)
        self.map_ = C.Map()

        use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
        self.use_parallel = use_parallel
        if use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
            if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
                                   (_get_parallel_mode()))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            global_rank = _get_global_rank()
            # True for the parameters this rank is responsible for updating.
            self.optim_filter = tuple(rank == global_rank for rank in self.param_rank)
            self.param_names = [param.name for param in self.parameters]
        else:
            self.optim_filter = (True,) * self.param_length

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                # Per-parameter weight decay values (self.weight_decay is a tuple in group mode).
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
                                      params, gradients)
            else:
                # A single weight decay value bound for all parameters.
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
                                      params, gradients)

        return gradients

    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
        network.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after loss scale.
        """
        if self.reciprocal_scale != 1.0:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
            return weight_decay
        raise TypeError("Weight decay should be int or float.")

    def _preprocess_single_lr(self, learning_rate):
        """
        Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.

        Sets `self.dynamic_lr` to True for any input that is not a scalar (int, float or 0-dim Tensor).
        """
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
                                         self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.dim() > 1:
                raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1, "
                                 f"but got {learning_rate.dim()}.")
            if learning_rate.dim() == 1 and learning_rate.size() < 2:
                logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number "
                               "of elements in the tensor passed is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate
        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")

    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            # With dynamic group learning rates, every entry must be callable on global_step.
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _check_group_params(self, parameters):
        """
        Check group params.

        A single-pass iterable given as a group's "params" value is materialized into a list in place, so that
        validating it here does not leave an exhausted iterator for the later group initialization.
        """
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
        for group_param in parameters:
            if not isinstance(group_param, dict):
                raise TypeError("Only a list of Parameter or dict can be supported.")

            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            if 'params' not in group_param:
                raise KeyError("The key 'params' is required in group params.")
            group_params = group_param['params']
            if not isinstance(group_params, (list, tuple)):
                # e.g. a `filter` object: iterate it once here and keep the materialized list.
                group_params = list(group_params)
                group_param['params'] = group_params
            if not group_params:
                raise ValueError("Optimizer got an empty group parameter list.")

            for param in group_params:
                if not isinstance(param, Parameter):
                    raise TypeError("The group param should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params: set the group flags and check that Tensor dynamic learning rates agree in size."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            tensor_lr_length = learning_rate.size()
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                # The order params dict has already been validated by _check_group_params.
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
                    group_lr_length = group_lr.size()
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay):
        """Init learning rate or weight decay in group params (one lr/weight decay entry per parameter)."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = set()
        ordered_parameters = None
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']
            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in params_store:
                    raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
                params_store.add(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)

    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameter, learning rate and weight decay in group params.

        Raises:
            ValueError: If `ordered_parameters` is not exactly a reordering of all group parameters.
        """
        # 'order_params' only has to be Iterable; materialize it so len() works for generators too.
        ordered_parameters = list(ordered_parameters)
        params_length = len(self.group_params)
        params_name = [param.name for param in ordered_parameters]
        group_names = {param.name for param in self.group_params}
        # Group parameter names are unique (checked in _init_group_params), so equal length plus equal name sets
        # means order_params is a permutation of the group parameters.
        if len(params_name) != params_length or set(params_name) != group_names:
            raise ValueError("The value of 'order_params' should be same with all group parameters.")

        name_to_index = {name: index for index, name in enumerate(params_name)}
        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length

        for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
            index = name_to_index[param.name]
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay

    def get_lr(self):
        """
        Get the learning rate of current step.

        When the learning rate is dynamic, this also increments `self.global_step` by 1.

        Returns:
            The learning rate of the current step: `self.learning_rate` itself for a fixed learning rate,
            the value computed from `self.global_step` for a single dynamic learning rate, or a tuple with one
            computed value per entry of `self.learning_rate` for dynamic group learning rates.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            # Make the global_step increment depend on the learning rate having been read.
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        return lr

    def get_lr_parameter(self, param):
        """
        Get the learning rate of parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, single `Parameter` or `list[Parameter]` according to the input type.

        Raises:
            TypeError: If `param` is neither a `Parameter` nor a list.
            ValueError: If a given parameter is not managed by this optimizer.
        """
        def get_lr_value(learning_rate):
            # Unwrap the internal schedule wrappers to expose the underlying learning rate Parameter.
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError("The parameter only support 'Parameter' or 'list' type.")

        lr = []
        ids = [id(p) for p in self.parameters]
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if id(p) not in ids:
                raise ValueError(f"The parameter {p.name} is not in optimizer.")
            if self.is_group_lr:
                index = ids.index(id(p))
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]

    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Parameters are assigned round-robin: parameter ``i`` belongs to group ``i % self.dev_num``.

        Returns:
            tuple, the group id tuple of parameters.
        """
        return tuple(index % self.dev_num for index in range(self.param_length))

    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Args:
            optim_result: The result of the optimizer update; the first broadcast result is made to depend on it.

        Returns:
             bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            # Rank `root` broadcasts the parameters it owns; every rank assigns the received values.
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)

        return status

    def construct(self, *hyper_params):
        raise NotImplementedError
op_add = P.AddN()
op_gather = P.GatherV2()

_apply_decay = C.MultitypeFuncGraph("apply_decay")


@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay for a sparse (RowTensor) gradient.

    Only the rows of `weight` selected by `gradient.indices` receive the decay term.
    """
    if if_apply:
        indices = gradient.indices
        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient


@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay: returns `weight * weight_decay + gradient` when `if_apply` is True."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


_grad_scale = C.MultitypeFuncGraph("grad_scale")


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return grad * scale


@_grad_scale.register("Number", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale for a sparse (RowTensor) gradient; only the values are scaled."""
    if scale == 1.0:
        return grad
    return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)


class _ConvertToCell(LearningRateSchedule):
    """Inner api, convert learning rate of scalar to LearningRateSchedule."""
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        # `global_step` is unused: the learning rate is constant. NOTE(review): the `+ 1.0 - 1.0` presumably
        # makes construct return a computed value rather than the Parameter object itself — confirm.
        return self.learning_rate + 1.0 - 1.0


class _IteratorLearningRate(LearningRateSchedule):
    """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule."""
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        if not isinstance(learning_rate, Tensor):
            raise TypeError("Learning rate should be Tensor.")
        if learning_rate.dim() != 1:
            raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1, "
                             f"but got {learning_rate.dim()}.")

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.GatherV2()

    def construct(self, global_step):
        # Select the learning rate entry indexed by the current global step.
        return self.gather(self.learning_rate, global_step, 0)