# Source code for mindspore.ops.operations.inner_ops

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""inner_ops"""

import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
from .. import signature as sig


class ScalarCast(PrimitiveWithInfer):
    """
    Casts the input scalar to another type.

    Inputs:
        - **input_x** (scalar) - The input scalar. Only constant value is allowed.
        - **input_y** (mindspore.dtype) - The type to be cast. Only constant value is allowed.

    Outputs:
        Scalar. The type is the same as the python type corresponding to `input_y`.

    Raises:
        TypeError: If either `input_x` or `input_y` is not a constant value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> scalar_cast = ops.ScalarCast()
        >>> output = scalar_cast(255.0, mindspore.int32)
        >>> print(output)
        255
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, x, t):
        """Infer the output shape, dtype and, when `x` is a known constant, its cast value."""
        # The input must be a scalar: its shape has rank 0.
        validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
        value, to = x['value'], t['value']
        if value is not None:
            validator.check_value_type("value", value, [numbers.Number, bool], self.name)
            # A tensor type is cast through its element type.
            if isinstance(to, type(tensor)):
                to = to.element_type()
            np_type = dtype_to_pytype(to)
            value = np_type(value)
        out = {'shape': x['shape'],
               'dtype': t['value'],
               'value': value}
        return out
class Randperm(PrimitiveWithInfer):
    """
    Generates n random samples from 0 to n-1 without repeating. If `max_length` > n,
    the last `max_length-n` elements will be filled with `pad`.

    Args:
        max_length (int): Number of items expected to get and the number must be greater than 0. Default: 1.
        pad (int): The pad value to be filled. Default: -1.
        dtype (mindspore.dtype): The type of output. Default: mindspore.int32.

    Inputs:
        - **n** (Tensor[int32]) - The input tensor with shape: (1,) and the number must be in [0, `max_length`].

    Outputs:
        - **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.

    Raises:
        TypeError: If `max_length` or `pad` is not an int.
        TypeError: If `n` is not a Tensor.
        TypeError: If `n` has non-Int elements.
        TypeError: If `n` has negative elements.
        TypeError: If `dtype` is not one of int8, int16, int32, int64, uint8, uint16, uint32 or uint64.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # The result of every execution is different because this operator will generate n random samples.
        >>> randperm = ops.Randperm(max_length=30, pad=-1)
        >>> n = Tensor([20], dtype=mindspore.int32)
        >>> output = randperm(n)
        >>> print(output)
        [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
         -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
    """

    @prim_attr_register
    def __init__(self, max_length=1, pad=-1, dtype=mstype.int32):
        """Initialize Randperm"""
        validator.check_value_type("pad", pad, [int], self.name)
        validator.check_value_type("max_length", max_length, [int], self.name)
        validator.check_int(max_length, 1, Rel.GE, "max_length", self.name)
        self.dtype = dtype
        self.max_length = max_length
        # NOTE(review): `inputs` is empty although the operator takes `n`; kept unchanged because
        # backend kernel registration may rely on it -- confirm before changing.
        self.init_prim_io_names(inputs=[], outputs=['output'])

    def infer_shape(self, n_shape):
        # `n` must be a 1-D tensor holding exactly one element.
        validator.check_int(len(n_shape), 1, Rel.EQ, "rank_of_n", self.name)
        validator.check_int(n_shape[0], 1, Rel.EQ, "length_of_n", self.name)
        return [self.max_length]

    def infer_dtype(self, n_type):
        validator.check_type_name("n_type", n_type, mstype.int32, self.name)

        # The requested output dtype must be an integer type.
        valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64)
        validator.check_type_name("dtype", self.dtype, valid_values, self.name)
        return self.dtype
class NoRepeatNGram(PrimitiveWithInfer):
    """
    Updates log_probs with repeat n-grams.

    During beam search, if consecutive `ngram_size` words exist in the generated word sequence,
    the consecutive `ngram_size` words will be avoided during subsequent prediction.
    For example, when `ngram_size` is 3 and the generated word sequence is [1, 2, 3, 2, 3],
    the next predicted word will not be 2 and the value of `log_probs` for word 2 will be replaced
    with -FLOAT_MAX, because predicting 2 would make the 3 consecutive words [2, 3, 2] appear twice
    in the word sequence.

    Args:
        ngram_size (int): Size of n-grams, must be greater than 0. Default: 1.

    Inputs:
        - **state_seq** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m).
          `ngram_size` must not be greater than m + 1.
        - **log_probs** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size).
          The value of log_probs will be replaced with -FLOAT_MAX when n-grams repeated.

    Outputs:
        - **log_probs** (Tensor) - The output Tensor with same shape and type as original `log_probs`.

    Raises:
        TypeError: If `ngram_size` is not an int.
        TypeError: If neither `state_seq` nor `log_probs` is a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
        >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
        ...                      [9, 3, 9, 5, 4, 1, 5]],
        ...                     [[4, 8, 6, 4, 5, 6, 4],
        ...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
        >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
        ...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
        ...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
        ...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
        >>> output = no_repeat_ngram(state_seq, log_probs)
        >>> print(output)
        [[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01
            2.0000000e-01 -3.4028235e+38  4.0000001e-01  6.0000002e-01
            2.0000000e-01  6.9999999e-01]
          [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01
            8.0000001e-01  1.0000000e-01  8.9999998e-01  8.0000001e-01
            6.9999999e-01  1.0000000e-01]]
         [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01
            5.0000000e-01 -3.4028235e+38  5.0000000e-01  4.0000001e-01
            8.0000001e-01  6.0000002e-01]
          [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01
            6.9999999e-01  8.0000001e-01  2.0000000e-01  6.9999999e-01
           -3.4028235e+38  6.9999999e-01]]]
    """

    @prim_attr_register
    def __init__(self, ngram_size=1):
        """Initialize NoRepeatNGram"""
        validator.check_value_type("ngram_size", ngram_size, [int], self.name)
        validator.check_int(ngram_size, 1, Rel.GE, "ngram_size", self.name)
        self.ngram_size = ngram_size
        self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs'])

    def infer_shape(self, seq_shape, log_shape):
        validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name)
        validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name)
        # Batch and beam dimensions of both inputs must agree.
        validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name)
        validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name)
        # An n-gram can be at most one word longer than the generated sequence.
        validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name)
        return log_shape

    def infer_dtype(self, seq_type, log_type):
        validator.check_type_name("seq_type", seq_type, mstype.int32, self.name)
        valid_values = (mstype.float16, mstype.float32, mstype.float64)
        validator.check_type_name("log_type", log_type, valid_values, self.name)
        return log_type
class LambApplyOptimizerAssign(PrimitiveWithInfer):
    r"""
    Updates gradients by LAMB optimizer algorithm. Get the compute ratio.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            m = \frac{m}{1 - \beta_1^t} \\
            v = \frac{v}{1 - \beta_2^t} \\
            r = \frac{m}{\sqrt{v} + \epsilon} \\
            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step `steps`, :math:`\lambda` represents `weight_decay`, :math:`w` represents
    `var`, :math:`\epsilon` represents `epsilon`.

    This operator does not take the learning rate :math:`l` as an input; it is an input of
    `LambApplyWeightAssign`.

    Inputs:
        - **gradient** (Tensor) - Gradient of parameters, float32/float16.
        - **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`.
        - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`.
        - **var** (Tensor) - Weights to be updated, has the same type as `gradient`.
        - **beta1** (Tensor) - :math:`\beta_1` in the updating formula, float32/float16.
        - **sub1** (Tensor) - :math:`1-\beta_1` in the updating formula, has the same type as `beta1`.
        - **beta2** (Tensor) - :math:`\beta_2` in the updating formula, has the same type as `beta1`.
        - **sub2** (Tensor) - :math:`1-\beta_2` in the updating formula, has the same type as `beta1`.
        - **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`.
        - **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`.
        - **decay_flag** (Tensor) - Specify whether param update with weight decay, has the same type as `beta1`.
        - **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`.

    Outputs:
        Tuple of 3 Tensor: the compute ratio `update`, and the updated `v` and `m`.

        - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`.
        - **v** (Tensor) - the 2nd moment vector in the updating formula after updated inplace,
          has the same type as `gradient`.
        - **m** (Tensor) - The 1st moment vector in the updating formula after updated inplace,
          has the same type as `gradient`.

    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyOptimizerAssign"""
        # `v` and `m` are updated in place, so the primitive is marked as having a memory side effect.
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                    beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return m_shape, v_shape, m_shape

    def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
                    beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype,
                "eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype,
                "weight_decay": weight_decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
        # m and v were checked to share one dtype above; mirror the (update, v, m) shape tuple.
        return m_dtype, v_dtype, m_dtype


class LambApplyWeightAssign(PrimitiveWithInfer):
    r"""
    Updates gradients by LAMB optimizer algorithm. The weight update part.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            m = \frac{m}{1 - \beta_1^t} \\
            v = \frac{v}{1 - \beta_2^t} \\
            r = \frac{m}{\sqrt{v} + \epsilon} \\
            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`,
    :math:`\epsilon` represents `epsilon`.

    Inputs:
        - **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16.
        - **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`.
        - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, has the same type as `w_norm`.
        - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula, float32/float16.
        - **var** (Tensor) - Weights to be updated, the same shape and type as `update`.

    Outputs:
        - **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs.

    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyWeightAssign"""
        # `var` is updated in place, so the primitive is marked as having a memory side effect.
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
        validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
        return var_shape

    def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype):
        args = {"var": var_dtype, "update": update_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
        return var_dtype


class MakeRefKey(Primitive):
    """
    Makes a RefKey instance by string. RefKey stores the name of Parameter, can be passed through the functions,
    and used for Assign target.

    Args:
        tag (str): Parameter name to make the RefKey.

    Inputs:
        No inputs.

    Outputs:
        RefKeyType, made from the Parameter name.

    Raises:
        TypeError: If `tag` is not a str.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
        ...         self.make_ref_key = ops.MakeRefKey("y")
        ...
        ...     def construct(self, x):
        ...         key = self.make_ref_key()
        ...         ref = ops.make_ref(key, x, self.y)
        ...         return ref * x
        ...
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mstype.int32)
        >>> net = Net()
        >>> output = net(x)
        >>> print(output)
        [[ 1  4  9]
         [16 25 36]]
    """

    @prim_attr_register
    def __init__(self, tag):
        validator.check_value_type('tag', tag, (str,), self.name)

    def __call__(self):
        # Calling the primitive directly from Python does nothing and returns None.
        pass


class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
    """
    Optimizer that implements the Momentum algorithm with weight decay and loss scale.

    Refer to the paper `On the importance of initialization and momentum in deep
    learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.

    Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.

    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
    to make the data types consistent.
    If they have different data types, the lower priority data type will be converted to
    relatively highest priority data type.
    Data type conversion of Parameter is not supported. RuntimeError exception will be thrown.

    Inputs:
        - **weight_decay** (Tensor) - The weight decay value, must be a scalar tensor with float data type.
          Default: 0.0.
        - **loss_scale** (Tensor) - The loss scale value, must be a scalar tensor with float data type.
          Default: 1.0.
        - **variable** (Parameter) - Weights to be updated. data type must be float.
        - **accumulation** (Parameter) - Accumulated gradient value by moment weight.
          Has the same data type with `variable`.
        - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
          a scalar tensor with float data type.
        - **gradient** (Tensor) - Gradient, has the same data type as `variable`.
        - **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
          a scalar tensor with float data type.

    Outputs:
        Tensor, parameters to be updated.

    Supported Platforms:
        ``GPU``

    Examples:
        Please refer to the usage in :class:`mindspore.nn.Momentum`, and add weight_decay and loss_scale as inputs.
    """
    __mindspore_signature__ = (
        sig.make_sig('weight_decay', dtype=sig.sig_dtype.T3),
        sig.make_sig('loss_scale', dtype=sig.sig_dtype.T3),
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
        sig.make_sig('gradient', dtype=sig.sig_dtype.T),
        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
    )

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['weight_decay', 'loss_scale', 'variable', 'accumulation', 'learning_rate',
                                        'gradient', 'momentum'], outputs=['output'])

    def infer_shape(self, d_shape, s_shape, v_shape, a_shape, l_shape, g_shape, m_shape):
        return v_shape

    def infer_dtype(self, d_dtype, s_dtype, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
        valid_dtypes = [mstype.float16, mstype.float32]
        # Ref-key inputs carry no tensor dtype, so variable/accumulation are only checked otherwise.
        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
            validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
            validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
        return v_dtype


class FusedCastAdamWeightDecay(PrimitiveWithInfer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay. This operator
    incorporates type conversion when parameters are initialized with dtype of float16.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
    The AdamWeightDecay variant was proposed in `Decoupled Weight Decay Regularization
    <https://arxiv.org/abs/1711.05101>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            update = \frac{m}{\sqrt{v} + \epsilon} \\
            update =
            \begin{cases}
                update + weight\_decay * w
                    & \text{ if } weight\_decay > 0 \\
                update
                    & \text{ otherwise }
            \end{cases} \\
            w = w - lr * update
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`lr` represents `lr`,
    :math:`w` represents `var`, :math:`weight\_decay` represents `decay`, :math:`\epsilon` represents `epsilon`.

    Args:
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.

    Inputs:
        - **var** (Tensor) - Weights to be updated with the type float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula with the type float32.
        - **v** (Tensor) - the 2nd moment vector in the updating formula with the type float32.
        - **lr** (float) - :math:`lr` in the updating formula.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
        - **gradient** (Tensor) - Gradient, has the type float16.
        - **global_norm** (Tensor) - The global norm of the gradients. Its shape and dtype are not checked
          during shape/type inference.

    Outputs:
        Tuple of 3 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.context as context
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.opt = ops.FusedCastAdamWeightDecay()
        ...         self.var = Parameter(Tensor(np.ones([2, 2]), mstype.float16), name="var")
        ...         self.m = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="m")
        ...         self.v = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="v")
        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad, norm):
        ...         out = self.opt(self.var, self.m, self.v, lr, beta1, beta2, epsilon, decay, grad, norm)
        ...         return out
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        >>> net = Net()
        >>> gradient = Tensor(np.ones([2, 2]), mstype.float16)
        >>> global_norm = Tensor(1.0, mstype.float32)
        >>> output = net(0.001, 0.9, 0.999, 1e-8, 0.0, gradient, global_norm)
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        # var, m and v are updated in place.
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, grad_shape, global_norm):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, grad_dtype, global_norm):
        args = {"m": m_dtype, "v": v_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
        validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)

        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
                "decay": decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype


class FusedAdaFactor(PrimitiveWithInfer):
    r"""
    Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.

    The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
    Cost <https://arxiv.org/abs/1804.04235>`_.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Adafactor for weight vector are as follows,

    .. math::
        \begin{array}{l} \\
        \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
        G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
        \hat{V}_{t}=\hat{\beta}_{2 t} \hat{V}_{t-1}+\left(1-\hat{\beta}_{2 t}\right)
            \left(G_{t}^{2}+\epsilon_{1} 1_{n}\right) \\
        U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
        \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
        X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Adafactor for weight matrices are as follows,

    .. math::
        \begin{array}{l} \\
        \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
        G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
        R_{t}=\hat{\beta}_{2 t} R_{t-1}+\left(1-\hat{\beta}_{2 t}\right)
            \left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) 1_{m} \\
        C_{t}=\hat{\beta}_{2 t} C_{t-1}+\left(1-\hat{\beta}_{2 t}\right) 1_{n}^{\top}
            \left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) \\
        \hat{V}_{t}=R_{t} C_{t} / 1_{n}^{\top} R_{t} \\
        U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
        \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
        X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Where RMS is:

    .. math::
        \operatorname{RMS}\left(U_{t}\right)=\operatorname{RMS}_{x \in X}\left(u_{x t}\right)=
        \sqrt{\operatorname{Mean}_{x \in X}\left(\frac{\left(g_{x t}\right)^{2}}{\hat{v}_{x t}}\right)}

    :math:`x` is each individual parameter,
    :math:`t` is assumed to be the current number of steps,
    :math:`\alpha_{t}` is the learning rate,
    :math:`f(X)` is the loss function,
    :math:`\epsilon_{1}` and :math:`\epsilon_{2}` are small positive numbers to prevent errors,
    :math:`d` is the clipping threshold,
    :math:`\beta_{2}` is the moment decay,
    :math:`\rho` is the relative step size,
    :math:`R` is the running averages of the row sums of the squared gradient,
    :math:`C` is the running averages of the column sums of the squared gradient.

    Args:
        enable_scale_parameter (bool): If True, enable scale learning rate using parameter. Default: False.
        enable_first_moment (bool): If True, enable first moment. Default: False.
        enable_weight_decay (bool): If True, enable weight decay. Default: False.

    Inputs:
        - **epsilon** (Tensor) - The pair of :math:`(\epsilon_{1}, \epsilon_{2})` values.
        - **clip_threshold** (float) - The threshold of root mean square of final gradient update.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
        - **weight_decay** (float) - The weight decay value, must be a scalar tensor with float data type.
        - **learning_rate** (float) - The learning rate value.
        - **gradient** (Tensor) - Gradient, the same shape as `param`.
        - **param** (Tensor) - Weights to be updated.
        - **exp_avg** (Tensor) - The exponential moving average of 1st moment optimizer state.
        - **exp_avg_sq_row** (Tensor) - The exponential moving average of square of gradient square row factor.
        - **exp_avg_sq_col** (Tensor) - The exponential moving average of square of gradient square col factor.
        - **exp_avg_sq** (Tensor) - The exponential moving average of square of gradient square.

    Outputs:
        - **dummy_param** (Tensor) - The same shape and data type as `param`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.context as context
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> param_shape = [2, 3, 2]
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.opt = ops.FusedAdaFactor()
        ...         self.param = Parameter(Tensor(np.ones(param_shape), mstype.float32), name="param")
        ...         self.exp_avg = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg")
        ...         self.exp_avg_sq = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg_sq")
        ...         self.exp_avg_sq_row = Parameter(Tensor(np.zeros([2, 3]), mstype.float32), name="exp_avg_sq_row")
        ...         self.exp_avg_sq_col = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="exp_avg_sq_col")
        ...
        ...     def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad):
        ...         out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param,
        ...                        self.exp_avg, self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq)
        ...         return out
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        >>> net = Net()
        >>> gradient = Tensor(np.ones(param_shape), mstype.float32)
        >>> net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
    """

    @prim_attr_register
    def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False):
        # param and the optimizer states are updated in place.
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("enable_scale_parameter", enable_scale_parameter, [bool], self.name)
        validator.check_value_type("enable_first_moment", enable_first_moment, [bool], self.name)
        validator.check_value_type("enable_weight_decay", enable_weight_decay, [bool], self.name)

    def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape,
                    learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape,
                    exp_avg_sq_col_shape, exp_avg_sq_shape):
        validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ, self.name)
        return param_shape

    def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type,
                    learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
                    exp_avg_sq_col_type, exp_avg_sq_type):
        return param_type


class FusedAdaFactorWithGlobalNorm(FusedAdaFactor):
    r"""
    FusedAdaFactor that divides the gradient by a global norm.

    It takes the same inputs as `FusedAdaFactor` followed by one extra input, `global_norm`.
    Refer to `FusedAdaFactor` for the algorithm, arguments and the other inputs.
    """

    @prim_attr_register
    def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False):
        super(FusedAdaFactorWithGlobalNorm, self).__init__(enable_scale_parameter, enable_first_moment,
                                                           enable_weight_decay)

    def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape,
                    learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape,
                    exp_avg_sq_col_shape, exp_avg_sq_shape, global_norm_shape):
        validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ, self.name)
        return param_shape

    def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type,
                    learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
                    exp_avg_sq_col_type, exp_avg_sq_type, global_norm_type):
        return param_type