Source code for mindelec.common.lr_scheduler

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#pylint: disable=C1801
"""lr scheduler"""
import numpy as np

from mindspore.ops import constexpr
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from ..architecture.util import check_mode


@constexpr
def _check_tensor(is_tensor):
    """
    If the input is not a tensor raise TypeError.
    """
    if not is_tensor:
        raise TypeError("The type of global_step should be Tensor")


@constexpr
def _check_dimension(shape):
    """check dimension of input"""
    if len(shape) != 0:
        raise ValueError("The shape of global_step should be `()`, but got shape {} with length {}".format(
            shape, len(shape)))


[docs]class LearningRate(LearningRateSchedule): r""" Warmup-decay learning rate. Args: learning_rate (float): positive float type number of basic learning rate. end_learning_rate (float): non-negtive float type number of end learning rate. warmup_steps (int): non-negtive int type number of warmup steps. decay_steps (int): A positive int value used to calculate decayed learning rate. power (float): A positive float value used to calculate decayed learning rate. Inputs: - **global_steps** (Tensor) - The current step number with shape :math:`()`. Returns: Tensor. The learning rate value for the current step with shape :math:`()`. Supported Platforms: ``Ascend`` Examples: >>> from mindelec.common import LearningRate >>> from mindspore.common.tensor import Tensor >>> from mindspore.common import dtype as mstype >>> lr = LearningRate(0.1, 0.001, 0, 10, 0.5) >>> print(lr(Tensor(1000, mstype.int32))) 0.001 """ def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): super(LearningRate, self).__init__() check_mode("LearningRate") _check_type(learning_rate, "learning_rate", float, thresh_hold=0.0, restrict=True) _check_type(end_learning_rate, "end_learning_rate", float, thresh_hold=0.0, restrict=False) _check_type(warmup_steps, "warmup_steps", int, thresh_hold=0, restrict=False, exclude=bool) _check_type(decay_steps, "decay_steps", int, thresh_hold=0, restrict=True, exclude=bool) _check_type(power, "power", float, thresh_hold=0.0, restrict=True) self.warmup_flag = False if warmup_steps > 0: self.warmup_flag = True self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) self.greater = P.Greater() self.one = Tensor(np.array([1.0]).astype(np.float32)) self.cast = P.Cast()
[docs] def construct(self, global_step): """get learning rate of current step""" _check_tensor(isinstance(global_step, Tensor)) _check_dimension(global_step.shape) decay_lr = self.decay_lr(global_step) if self.warmup_flag: is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) warmup_lr = self.warmup_lr(global_step) lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr else: lr = decay_lr return lr
[docs]def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power): """ generate learning rate array Args: global_step(int): current step number, non-negtive int value. lr_init(float): init learning rate, positive float value. lr_end(float): end learning rate, non-negtive float value. lr_max(float): max learning rate, positive float value. warmup_steps(int): number of warmup epochs, non-negtive int value. total_steps(int): total epoch of training, positive int value. poly_power(float): poly learning rate power, positive float value. Returns: Numpy.array, learning rate array Supported Platforms: ``Ascend`` Examples: >>> from mindelec.common import get_poly_lr >>> learning_rate = get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, 0.5) >>> print(learning_rate.shape) (9900,) """ _check_type(global_step, "global_step", int, thresh_hold=0, restrict=False, exclude=bool) _check_type(lr_init, "lr_init", float, thresh_hold=0.0, restrict=True) _check_type(lr_end, "lr_end", float, thresh_hold=0.0, restrict=False) _check_type(lr_max, "lr_max", float, thresh_hold=0.0, restrict=True) _check_type(warmup_steps, "warmup_steps", int, thresh_hold=0, restrict=False, exclude=bool) _check_type(total_steps, "total_steps", int, thresh_hold=0, restrict=True, exclude=bool) _check_type(poly_power, "poly_power", float, thresh_hold=0.0, restrict=True) lr_each_step = [] if warmup_steps != 0: inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) else: inc_each_step = 0 for i in range(total_steps): if i < warmup_steps: lr = float(lr_init) + inc_each_step * float(i) else: base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) lr = float(lr_max - lr_end) * (base ** poly_power) lr = lr + lr_end if lr < 0.0: lr = 0.0 lr_each_step.append(lr) learning_rate = np.array(lr_each_step).astype(np.float32) current_step = global_step learning_rate = learning_rate[current_step:] return learning_rate
def _check_type(param, param_name, param_type, thresh_hold=0, restrict=False, exclude=None): if (exclude and isinstance(param, exclude)) or not isinstance(param, param_type): raise TypeError("the type of {} should be {}, but got {}".format(param_name, param_type, type(param))) if restrict: if param <= thresh_hold: raise ValueError("the value of {} should be > {}, but got: {}".format(param_name, thresh_hold, param)) else: if param < thresh_hold: raise ValueError("the value of {} should be >= {}, but got: {}".format(param_name, thresh_hold, param))