# Source code for mindspore.nn.optim.thor

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context


# Enumerates types of Layer
# find_net_layertype_recur appends one of these integer tags to the layer-type map
# for each cell it recognises, and get_layer_counter branches on them.
Other = -1  # non-Thor layers listed in find_net_layertype_recur (plain conv/dense/embedding, other norms)
Conv = 1  # Conv2dThor
FC = 2  # DenseThor
Embedding = 3  # EmbeddingThor / EmbeddingLookupThor
LayerNorm = 4  # nn.LayerNorm
BatchNorm = 5  # nn.BatchNorm2d whose gamma requires grad

# AddN primitive; _tensor_apply_decay uses it to sum the decay term and the gradient.
op_add = P.AddN()
# Multitype function graphs; their overloads are registered by the decorated functions below.
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """
    Return the gradient, with the weight-decay term added when requested.

    Inputs:
        weight_decay (Number): Weight decay coefficient.
        if_apply (Bool): Whether the decay term is added for this parameter.
        weight (Tensor): The parameter value.
        gradient (Tensor): The gradient of the parameter.

    Outputs:
        Tensor, `weight * weight_decay + gradient` when `if_apply` is true, otherwise `gradient` unchanged.
    """
    if not if_apply:
        return gradient
    decay_term = weight * weight_decay
    return op_add((decay_term, gradient))


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """
    Run the momentum update `opt` on one weight and report success.

    Inputs:
        opt (Function): The update operator, called as `opt(weight, moment, learning_rate, gradient, momentum)`.
        momentum (Tensor): Momentum coefficient.
        learning_rate (Tensor): Learning rate.
        gradient (Tensor): Gradient of `weight`.
        weight (Tensor): The parameter to update.
        moment (Tensor): The moment accumulator paired with `weight`.

    Outputs:
        The constant `True`, routed through `F.depend` together with the result of the update call.
    """
    update = opt(weight, moment, learning_rate, gradient, momentum)
    return F.depend(True, update)

# Gradient-clipping configuration read by clip_gradient.
IS_ENABLE_GLOBAL_NORM = False  # when True, clip_gradient calls C.clip_by_global_norm instead of per-tensor clipping
GRADIENT_CLIP_TYPE = 1  # clip_type handed to _clip_grad: 0 clips by value, 1 clips by norm
GRADIENT_CLIP_VALUE = 1.0  # clip bound used by both clipping paths
# Multitype function graph; its (Number, Number, Tensor) overload is _clip_grad.
clip_grad = C.MultitypeFuncGraph("clip_grad")
# HyperMap instance clip_gradient uses to apply clip_grad to every gradient.
hyper_map_op = C.HyperMap()


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. Any other value disables clipping.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): The gradient to clip.

    Outputs:
        Tensor, the clipped gradient, or `grad` itself when `clip_type` is neither 0 nor 1.
    """
    if clip_type not in [0, 1]:
        return grad
    grad_dtype = F.dtype(grad)
    # The positive bound is needed by both modes; it is cast to the gradient's dtype.
    upper = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
    if clip_type == 0:
        lower = F.cast(F.tuple_to_array((-clip_value,)), grad_dtype)
        return C.clip_by_value(grad, lower, upper)
    return nn.ClipByNorm()(grad, upper)


def clip_gradient(enable_clip_grad, gradients):
    """
    Clip `gradients` when clipping is enabled.

    With `IS_ENABLE_GLOBAL_NORM` set, the whole collection is clipped by global norm using
    `GRADIENT_CLIP_VALUE`; otherwise `clip_grad` is mapped over every gradient with
    `GRADIENT_CLIP_TYPE` and `GRADIENT_CLIP_VALUE`.

    Inputs:
        enable_clip_grad (bool): Whether to clip at all.
        gradients (tuple[Tensor]): The gradients.

    Outputs:
        The clipped gradients, or `gradients` unchanged when `enable_clip_grad` is false.
    """
    if not enable_clip_grad:
        return gradients
    if IS_ENABLE_GLOBAL_NORM:
        return C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
    clip_one = F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    return hyper_map_op(clip_one, gradients)

# Block edge length used by caculate_device_shape, which expresses a dimension as
# (dim // C0, dim // C0, C0, C0).
# NOTE(review): presumably the 16x16 tile size of the Ascend device format -- confirm.
C0 = 16


def _check_param(momentum, frequency, lr, cls_name):
    """
    Validate the optimizer hyper-parameters.

    Args:
        momentum: Checked to be a float that is at least 0.0.
        frequency: Checked to be an int that is at least 2.
        lr: Checked to be a Tensor.
        cls_name (str): Name handed to the validator calls.

    Raises:
        ValueError: If `momentum` is a negative float or `frequency` is an int below 2.
        Type mismatches are reported by `Validator.check_value_type`.
    """
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    momentum_is_negative = isinstance(momentum, float) and momentum < 0.0
    if momentum_is_negative:
        raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))

    Validator.check_value_type("frequency", frequency, [int], cls_name)
    frequency_too_small = isinstance(frequency, int) and frequency < 2
    if frequency_too_small:
        raise ValueError("frequency should be at least 2, but got frequency {}".format(frequency))

    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)


def caculate_device_shape(matrix_dim, channel, is_a):
    """
    Build the blocked 4-D shape for a square matrix of size `matrix_dim`.

    Args:
        matrix_dim (int): Edge length of the matrix.
        channel (int): Channel count; only consulted when `is_a` is true.
        is_a (bool): Whether the matrix is an A matrix. If so and `channel // C0 == 0`,
            `matrix_dim` is first rescaled to `(matrix_dim / channel) * C0`.

    Returns:
        tuple, `((n, n, C0, C0), int(matrix_dim))` where `n = int(matrix_dim // C0)`
        and `matrix_dim` is the possibly rescaled size.
    """
    if is_a and channel // C0 == 0:
        matrix_dim = (matrix_dim / channel) * C0
    blocks = int(matrix_dim // C0)
    return (blocks, blocks, C0, C0), int(matrix_dim)


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """
    Tell whether the (G shape, A shape) pair is one of the supported conv matmul shapes.

    Every supported shape has the form `(n, n, 16, 16)`, so the table below only stores
    the block count `n` of the G matrix followed by that of the A matrix.

    Args:
        matrix_a_shape (tuple): 4-D shape of matrix A.
        matrix_g_shape (tuple): 4-D shape of matrix G.

    Returns:
        bool, True if the pair is supported, False otherwise.
    """
    # (G block count, A block count)
    block_pairs = ((4, 49), (4, 4), (4, 36), (16, 4),
                   (4, 16), (8, 16), (8, 72), (32, 8),
                   (32, 16), (8, 32), (16, 32), (16, 144),
                   (64, 16), (64, 32), (16, 64), (32, 64),
                   (32, 288), (128, 32), (128, 64), (32, 128))
    supported = [((g, g, 16, 16), (a, a, 16, 16)) for g, a in block_pairs]
    return (matrix_g_shape, matrix_a_shape) in supported


def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """
    Compute the batched matmul shapes for the A and G matrices.

    Each dimension is cut into pieces of `split_dim`. A dimension that divides evenly gives
    `dim // split_dim` pieces; a dimension smaller than `split_dim` gives a single piece of
    its own size; anything else is rounded up to `dim // split_dim + 1` pieces.

    Args:
        matrix_a_dim (int): Edge length of matrix A.
        matrix_g_dim (int): Edge length of matrix G.
        split_dim (int): Piece size.

    Returns:
        tuple, `(matrix_a_shape, matrix_g_shape)` with
        `matrix_a_shape = (batch_h, batch_w, split_a, split_a)` and
        `matrix_g_shape = (batch_h, split_g, split_g)`.
    """
    def _pieces(dim):
        # Returns (number of pieces, piece size) for one dimension.
        if dim % split_dim == 0:
            return dim // split_dim, split_dim
        if dim < split_dim:
            return 1, dim
        return dim // split_dim + 1, split_dim

    batch_w, split_dima = _pieces(matrix_a_dim)
    batch_h, split_dimg = _pieces(matrix_g_dim)
    return (batch_h, batch_w, split_dima, split_dima), (batch_h, split_dimg, split_dimg)


def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """
    Record a plain dense/conv cell as `Other` in `layertype_map` when it is trainable.

    Nothing is appended when the cell's weight does not require grad. Cells whose
    lower-cased prefix contains "rpn_with_loss.rpn_convs_list." are also skipped, unless
    the prefix contains "rpn_with_loss.rpn_convs_list.0.".

    Args:
        subcell (Cell): The dense or conv cell; its `weight.requires_grad` is read.
        prefix (str): Parameter prefix of the cell.
        layertype_map (list): Layer-type list, modified in place.
    """
    if not subcell.weight.requires_grad:
        return
    lowered = prefix.lower()
    in_rpn_convs = "rpn_with_loss.rpn_convs_list." in lowered
    is_first_rpn_conv = "rpn_with_loss.rpn_convs_list.0." in lowered
    if is_first_rpn_conv or not in_rpn_convs:
        layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """
    Walk the cells of `net` and append a layer-type tag for each recognised layer.

    Thor conv/dense/embedding cells, `nn.LayerNorm` and trainable `nn.BatchNorm2d` get their
    own tag. The other listed `nn` layers are tagged `Other`, with plain dense/conv delegated
    to `get_layer_type_for_dense_and_conv`. Any cell that matches none of these is searched
    recursively.

    Args:
        net (Cell): The network to inspect.
        layertype_map (list): Layer-type list, extended in place.
    """
    for _, subcell in net.name_cells().items():
        prefix = subcell.param_prefix
        if subcell == net:
            continue
        if isinstance(subcell, Conv2dThor):
            layertype_map.append(Conv)
            continue
        if isinstance(subcell, DenseThor):
            layertype_map.append(FC)
            continue
        if isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
            continue
        if isinstance(subcell, nn.LayerNorm):
            layertype_map.append(LayerNorm)
            continue
        if isinstance(subcell, nn.BatchNorm2d):
            # A BatchNorm2d with frozen gamma contributes no entry at all.
            if subcell.gamma.requires_grad:
                layertype_map.append(BatchNorm)
            continue
        if isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d,
                                nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            if isinstance(subcell, (nn.Dense, nn.Conv2d)):
                get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
            else:
                layertype_map.append(Other)
            continue
        # Not a leaf layer we know about: descend into it.
        find_net_layertype_recur(subcell, layertype_map)


def get_net_layertype_mask(net):
    """
    Return the list of layer-type tags collected from `net` by `find_net_layertype_recur`.

    Args:
        net (Cell): The network to inspect.

    Returns:
        list, the layer-type tags in traversal order.
    """
    mask = []
    find_net_layertype_recur(net, mask)
    return mask


def get_layer_counter(layer_type, layer_counter, params, idx):
    """
    Advance the layer counter when `params[idx]` is the last parameter of its layer.

    The decision is made on lower-cased parameter names:

    - Conv / FC: advance on a "bias" parameter, or on any other parameter that is not the
      last one and is not followed by a "bias" parameter.
    - LayerNorm / BatchNorm: advance on a "beta" parameter only.
    - Anything else: advance on "bias"; on "weight" apply the same look-ahead rule as for
      Conv / FC; advance on every other name.

    Args:
        layer_type (int): Layer-type tag of the current layer.
        layer_counter (int): Current counter value.
        params (Sequence): Parameters; each item exposes a `name` string.
        idx (int): Index of the current parameter in `params`.

    Returns:
        int, `layer_counter + 1` if the counter advances, otherwise `layer_counter`.
    """
    name = params[idx].name.lower()

    def _next_is_not_bias():
        # True only when a following parameter exists and it is not a bias.
        return idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower()

    if layer_type in (Conv, FC):
        advance = "bias" in name or _next_is_not_bias()
    elif layer_type in (LayerNorm, BatchNorm):
        advance = "beta" in name
    elif "bias" in name:
        advance = True
    elif "weight" in name:
        advance = _next_is_not_bias()
    else:
        advance = True
    return layer_counter + 1 if advance else layer_counter


def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by second-order algorithm--THOR.

    Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:

    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            A_i = a_i{a_i}^T \\
            G_i = D_{s_i}{ D_{s_i}}^T \\
            m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
            w_i = w_i - \alpha * m_i \\
        \end{array}

    :math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer,
    :math:`a_{i-1}` represents the input of i-th layer, and which is the activations of previous layer,
    :math:`\beta` represents momentum, :math:`I` represents the identity matrix,
    :math:`\overline A` represents the transpose of matrix A,
    :math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
    :math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
        but the gradient centralization can only be applied to the parameters of the convolution layer.
        If the parameters of the non convolution layer are set to True, an error will be reported.

        To improve parameter groups performance, the customized order of parameters can be supported.

    Args:
        net (Cell): The training network.

        learning_rate (Tensor): A value for the learning rate.

        damping (Tensor): A value for the damping.

        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.

        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.

        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: 1.0.

        batch_size (int): The size of a batch. Default: 32

        use_nesterov (bool): Enable Nesterov momentum. It is only forwarded to the optimizer that is created
            when the device target is not "Ascend". Default: False.

        decay_filter (function): A function to determine which layers the weight decay applied to. And it
            only works when the weight_decay > 0. Default: lambda x: x.name not in []

        split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
            computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
            to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
            is 27~53. Default: None

        enable_clip_grad (bool): Whether to clip the gradients. Default: False

        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
            and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `frequency` is not int.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn import thor
        >>> from mindspore import Model
        >>> from mindspore import FixedLossScaleManager
        >>> from mindspore.train.callback import LossMonitor
        >>> from mindspore.train.train_thor import ConvertModelUtils
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>>
        >>> net = Net()
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
        >>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)
        >>> loss_cb = LossMonitor()
        >>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)

    """
    context.set_context(max_call_depth=10000)
    # The network is converted to its Thor form before the optimizer is created from it.
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        # The Ascend optimizer is called without `use_nesterov`.
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)


class ThorGpu(Optimizer):
    """
    THOR optimizer variant that `thor` returns when the device target is not "Ascend".

    The optimizer keeps, for every Conv/FC/Embedding weight, a pair of stored matrices
    (`self.matrix_a`, `self.matrix_g`). While `self.thor` is True, `construct` recomputes
    them from the covariance parameters found on the network (`matrix_a_cov`,
    `matrix_g_cov`) and stores them; while it is False, the stored (stale) values are
    reused. The preconditioned gradients are then applied with `P.ApplyMomentum`.

    Args:
        net (Cell): The training network; its parameters are scanned by name.
        learning_rate: Forwarded to the `Optimizer` base class.
        damping (Tensor): Indexed by `cov_step` in `construct` to get the damping of the step.
        momentum (float): Momentum value, wrapped in a float32 `Parameter`.
        weight_decay (int, float): Weight decay. Default: 0.0.
        loss_scale (float): Loss scale. Default: 1.0.
        batch_size (int): Batch size. Default: 32.
        use_nesterov (bool): Passed to `P.ApplyMomentum`. Default: False.
        decay_filter (function): Called on every parameter to build `decay_flags`.
        split_indices (list): Allreduce fusion split indices, only read when distributed. Default: None.
        enable_clip_grad (bool): Forwarded to `clip_gradient` in `construct`. Default: False.
        frequency (int): Stored and exposed through `get_frequency`. Default: 100.
    """

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        """Collect the trainable parameters, build operators, index maps and (if distributed) reducers."""
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        # Parameters are selected purely by substring match on their names.
        # NOTE(review): presumably these are created by ConvertNetUtils.convert_to_thor_net -- confirm.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        # Reciprocal of the squared loss scale; multiplied into matrix_g in _get_ainv_ginv_list.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        # Squared batch size; also multiplied into matrix_g in _get_ainv_ginv_list.
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        # Must run before _process_matrix_init_and_weight_idx_map: that method uses
        # self.shape and self.split_dim, both of which are defined here.
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        # Selects the branch taken in construct (recompute vs. reuse stored matrices).
        # NOTE(review): nothing in this class resets it; presumably toggled by the training wrapper -- confirm.
        self.thor = True
        # The tuples and counters below are filled by _process_matrix_init_and_weight_idx_map.
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """Return the `frequency` value given at construction time."""
        return self.frequency

    def _define_gpu_operator(self):
        """Instantiate the primitives and constants used by the other methods of this class."""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.GatherV2()
        self.one = Tensor(1, mstype.int32)
        # Default multiplier for the computed matrices; replaced per Conv layer in _get_ainv_ginv_list.
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        # Step counter used in construct to index self.damping; incremented by one per construct call.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        # Block size handed to caculate_matmul_shape, CholeskyTrsm and UpdateThorGradient.
        self.split_dim = 128
        # Embedding layers use CholeskyTrsm without split_dim; Conv/FC layers use the split version.
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """
        Record the parallel mode and, when not stand-alone, create the two allreduce reducers.

        Args:
            split_indices (list): Fusion split indices; when falsy, a single group ending at the
                last `matrix_a_cov` index is used.
        """
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            # NOTE(review): the G reducer is also built from matrix_a_cov (not matrix_g_cov) -- confirm intended.
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """
        Create the stored `matrix_a`/`matrix_g` parameters and the per-parameter index maps.

        For each trainable parameter (in `self.params` order) one entry is appended to each of:
          - `weight_fim_idx_map`: index into `matrix_a`/`matrix_g`, or -1 when the parameter has none;
          - `weight_conv_idx_map`: index among Conv layers, or -1 when the parameter is not a Conv weight;
          - `weight_layertype_idx_map`: the layer type for Conv/FC/Embedding weights, `LayerNorm`
            for LayerNorm parameters, `Other` for everything else.

        Args:
            net (Cell): Network passed to `get_net_layertype_mask` to obtain the layer types.
        """
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    # Conv weights also contribute the two trailing dims of the weight shape.
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                # For Embedding, matrix_a is stored as a 1-D vector of length vocab_size.
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """
        Compute the per-layer A and G matrices to be stored and append them to the given tuples.

        Args:
            gradients (tuple[Tensor]): Gradients in `self.params` order; used only as `F.depend` targets.
            damping_step (Tensor): Damping value of the current step.
            matrix_a_allreduce (tuple): Tuple the computed A matrices are appended to.
            matrix_g_allreduce (tuple): Tuple the computed G matrices are appended to.

        Returns:
            tuple, the two extended tuples `(matrix_a_allreduce, matrix_g_allreduce)`; one entry per
            Conv/FC/Embedding weight, in thor-layer order.
        """
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                # Tie the covariance reads to the gradient so they are ordered after it in the graph.
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    # Conv layers rescale damping and the output multiplier by their normalizers.
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    # matrix_a is handled element-wise here: damped with ones, then reciprocal.
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    # CholeskyTrsm output multiplied with itself (transpose_a=True);
                    # presumably this yields the inverse of the damped matrix -- confirm against the op docs.
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                # NOTE(review): nesting of the next two lines was reconstructed from whitespace-collapsed
                # text; feature_map is 1.0 unless the layer is Conv -- confirm placement against upstream.
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """
        Scale a LayerNorm gradient element-wise by 1 / (gradient^2 / batch_size + sqrt(damping_step)).

        Args:
            damping_step (Tensor): Damping value of the current step.
            gradient (Tensor): Gradient of a LayerNorm parameter.

        Returns:
            Tensor, the rescaled gradient.
        """
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """Reshape `g` to `g_shape` when `conv_layer_count` is not -1 (a Conv weight); otherwise return `g` as is."""
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        """
        Precondition the gradients per layer type and apply the momentum update.

        Args:
            gradients (tuple[Tensor]): Gradients in `self.params` order.

        Returns:
            The result of mapping `_momentum_opt` over (gradients, params, moments).
        """
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        # Pick the damping value for the current step from the damping tensor.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            # Recompute the A/G matrices, optionally allreduce them, and store them for later steps.
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    # Flatten to 2-D for UpdateThorGradient; Conv gradients are reshaped back below.
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    # matrix_a is a vector: expand to a column and scale the gradient rows.
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            # Reuse the matrices stored by an earlier step; nothing is recomputed or assigned.
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success


class ThorAscend(Optimizer): """ThorAscend""" def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100): params = filter(lambda x: x.requires_grad, net.get_parameters()) super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale) _check_param(momentum, frequency, learning_rate, self.__class__.__name__) self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = self.parameters self.moments = self.params.clone(prefix="moments", init='zeros') self.hyper_map = C.HyperMap() self.opt = P.ApplyMomentum() self.net = net self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters())) self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters())) self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters())) self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters())) logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov))) self._define_ascend_operator() self.C0 = 16 self.device_shape_pad_flag = () self.diag_block_dim = 128 self.matrix_a = () self.matrix_g = () self.thor_layer_count = 0 self.conv_layer_count = 0 self.weight_conv_idx_map = () self.weight_fim_idx_map = () self.weight_layertype_idx_map = () self.a_split_pad_dim_map = () self.g_split_pad_dim_map = () self.conv_matmul_support_map = () self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36] self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32] self._process_matrix_init_and_weight_idx_map(self.net) self.matrix_a = ParameterTuple(self.matrix_a) self.matrix_g = ParameterTuple(self.matrix_g) self.matrix_max_inv = () for i in
range(len(self.matrix_a)): self.matrix_max_inv = self.matrix_max_inv + ( Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) self.thor = True self.weight_decay = weight_decay self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.damping = damping self.batch_size = Tensor(batch_size, mstype.float32) self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32) self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32) self.enable_clip_grad = enable_clip_grad self.frequency = frequency self._define_ascend_reducer(split_indices) def get_frequency(self): """get thor frequency""" return self.frequency def _get_pad_dim(self, matrix_dim): """get diag split pad dim """ split_pad_dim = 0 if matrix_dim == 64: return split_pad_dim res = matrix_dim % self.diag_block_dim if res != 0: split_pad_dim = self.diag_block_dim - res return split_pad_dim def _define_ascend_operator(self): """define ascend operator""" self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() self.transpose = P.Transpose() self.shape = P.Shape() self.reshape = P.Reshape() self.mul = P.Mul() self.log = P.Log() self.exp = P.Exp() self.sqrt = P.Sqrt() self.gather = P.GatherV2() self.assign = P.Assign() self.cast = P.Cast() self.eye = P.Eye() self.concat = P.Concat(0) self.cholesky = P.CusCholeskyTrsm() self.vector_matmul = P.CusBatchMatMul() self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True) self.fused_abs_max2 = P.CusFusedAbsMax1() self.matrix_combine = P.CusMatrixCombine() self.slice = P.Slice() self.expand = P.ExpandDims() self.reduce_sum = P.ReduceSum(keep_dims=False) self.square = P.Square() self.inv = P.Inv() self.matmul = P.MatMul() self.axis = 0 self.one = Tensor(1, mstype.int32) 
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) def _define_ascend_reducer(self, split_indices): """define ascend reducer""" self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) if self.is_distributed: mean = _get_gradients_mean() degree = _get_device_num() if not split_indices: self.split_indices = [len(self.matrix_a_cov) - 1] else: self.split_indices = split_indices if self.conv_layer_count > 0: auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2") auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4") self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2) self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4) auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6") auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8") self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6) self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8) def _get_weight_idx_map(self, layer_type, idx, weight_shape): """for Ascend, get weight idx map""" if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower(): self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,) self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,) if layer_type == Embedding: a_pad_dim = 0 g_pad_dim = 0 self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,) self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,) else: out_channels = weight_shape[0] g_pad_dim = self._get_pad_dim(out_channels) self.g_split_pad_dim_map = 
self.g_split_pad_dim_map + (g_pad_dim,) matrix_a_dim = weight_shape[1] if layer_type == Conv: matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3] a_pad_dim = self._get_pad_dim(matrix_a_dim) self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,) self.thor_layer_count = self.thor_layer_count + 1 if layer_type == Conv: self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,) self.conv_layer_count = self.conv_layer_count + 1 else: self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,) else: self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,) self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,) if layer_type == LayerNorm: self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,) else: self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,) def _get_fc_matrix(self, weight_shape): """for Ascend, get fc matrix_a and matrix_g""" out_channels = weight_shape[0] in_channels = weight_shape[1] if self.conv_layer_count > 0: if out_channels == 1001: fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) else: fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.matrix_a = self.matrix_a + (fc_matrix_a,) self.matrix_g = self.matrix_g + (fc_matrix_g,) def _process_matrix_init_and_weight_idx_map(self, net): """for Ascend, process matrix init shape, and get weight idx map""" layer_counter = 0 layer_type_map = get_net_layertype_mask(net) for idx in range(len(self.params)): layer_type = 
layer_type_map[layer_counter] weight = self.params[idx] weight_shape = self.shape(weight) if layer_type == Conv and "bias" not in self.params[idx].name.lower(): in_channels = weight_shape[1] out_channels = weight_shape[0] matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3] matrix_g_dim = out_channels matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True) matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False) ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape) if ret: matrix_a_inv = Parameter( Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) matrix_g_inv = Parameter( Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.conv_matmul_support_map = self.conv_matmul_support_map + (1,) else: matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.conv_matmul_support_map = self.conv_matmul_support_map + (0,) self.matrix_a = self.matrix_a + (matrix_a_inv,) self.matrix_g = self.matrix_g + (matrix_g_inv,) device_shape_pad_flag = False if matrix_a_dim != matrix_a_device_dim: device_shape_pad_flag = True self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,) elif layer_type == FC and "bias" not in self.params[idx].name.lower(): self._get_fc_matrix(weight_shape) self._get_weight_idx_map(layer_type, idx, weight_shape) # bert.cls1.output_bias: not a network layer, only a trainable param if "output_bias" not in self.params[idx].name.lower(): layer_counter 
= get_layer_counter(layer_type, layer_counter, self.params, idx) def _process_batch_matmul(self, input_matrix): """process batch matmul""" input_matrix_shape = self.shape(input_matrix) if input_matrix_shape[0] in self.batch_matmul_support_list: input_matrix = self.vector_matmul(input_matrix, input_matrix) else: input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix) return input_matrix def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0): """process cholesky pad""" if pad_dim > 0: matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32) matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup) input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix) input_matrix = self.concat((input_matrix, matrix_sup)) return input_matrix def _get_abs_max(self, matrix_inv, origin_dim): """get matrix abs max""" cholesky_shape = self.shape(matrix_inv) if cholesky_shape[0] in self.abs_max_support_list: matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv) matrix_max = self.fused_abs_max2(matrix_inv_max) matrix_inv = self.matrix_combine(matrix_inv) else: matrix_inv = self.matrix_combine(matrix_inv) matrix_abs = P.Abs()(matrix_inv) matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs) return matrix_max, matrix_inv def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce): """get fc layer ainv and ginv""" thor_layer_count = self.weight_fim_idx_map[index] g = gradients[index] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) a_shape = self.shape(matrix_a) a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) damping = self.sqrt(damping_step) matrix_a = matrix_a + damping * a_eye a_pad_dim = self.a_split_pad_dim_map[thor_layer_count] matrix_a = 
self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0]) matrix_a_inv = self.cholesky(matrix_a) matrix_a_inv = self._process_batch_matmul(matrix_a_inv) weight_shape = self.shape(self.params[index]) out_channels = weight_shape[0] in_channels = weight_shape[1] if out_channels == 2: matrix_a_inv = self.matrix_combine(matrix_a_inv) matrix_g_inv = g_eye else: matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping * g_eye g_pad_dim = self.g_split_pad_dim_map[thor_layer_count] matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0]) matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) if self.conv_layer_count > 0: a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels) g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels) a_max = F.depend(a_max, g) g_max = F.depend(g_max, g) matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,) matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,) else: matrix_a_inv = self.matrix_combine(matrix_a_inv) matrix_g_inv = self.matrix_combine(matrix_g_inv) if a_pad_dim > 0: matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels)) if g_pad_dim > 0: matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels)) matrix_a_inv_shape = self.shape(matrix_a_inv) matrix_g_combine_shape = self.shape(matrix_g_inv) if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001: matrix_a_inv = self.reshape(matrix_a_inv, (matrix_a_inv_shape[0] / 16, 16, matrix_a_inv_shape[0] / 16, 16)) matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3)) matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv) matrix_g_inv_shape = self.shape(matrix_g_inv) matrix_g_inv = self.reshape(matrix_g_inv, (matrix_g_inv_shape[0] / 16, 16, matrix_g_inv_shape[0] / 16, 16)) matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3)) matrix_a_allreduce = matrix_a_allreduce + 
(matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv): """process conv matmul device pad""" if self.device_shape_pad_flag[conv_layer_count]: kernel_hw = weight_shape[2] * weight_shape[3] in_channels = weight_shape[1] matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels)) matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0), (0, self.C0 - in_channels)))(matrix_a_inv) return matrix_a_inv def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce): """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list""" for i in range(len(self.params)): thor_layer_count = self.weight_fim_idx_map[i] conv_layer_count = self.weight_conv_idx_map[i] layer_type = self.weight_layertype_idx_map[i] weight_shape = self.shape(self.params[i]) out_channels = weight_shape[0] if layer_type == Conv: g = gradients[i] matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3] matmul_support_flag = self.conv_matmul_support_map[conv_layer_count] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) a_shape = self.shape(matrix_a) a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) a_normalizer = self.a_normalizer[conv_layer_count] g_normalizer = self.g_normalizer[conv_layer_count] a_normalizer = F.depend(a_normalizer, g) g_normalizer = F.depend(g_normalizer, g) damping_a = self.mul(damping_step, self.batch_size / a_normalizer) damping_g = self.mul(damping_step, self.batch_size / g_normalizer) damping_a = self.sqrt(damping_a) 
matrix_a = matrix_a + damping_a * a_eye a_pad_dim = self.a_split_pad_dim_map[thor_layer_count] matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0]) matrix_a_inv = self.cholesky(matrix_a) matrix_a_inv = self._process_batch_matmul(matrix_a_inv) a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim) damping_g = self.sqrt(damping_g) matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping_g * g_eye g_pad_dim = self.g_split_pad_dim_map[thor_layer_count] matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0]) matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels) if a_pad_dim > 0: matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim)) if g_pad_dim > 0: matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels)) if matmul_support_flag == 1: matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv) matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count]) matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2], matrix_a_inv_shape[1], matrix_a_inv_shape[3]) matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape) matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3)) matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count]) matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2], matrix_g_inv_shape[1], matrix_g_inv_shape[3]) matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape) matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3)) a_max = F.depend(a_max, g) g_max = F.depend(g_max, g) matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,) matrix_g_max_allreduce = 
matrix_g_max_allreduce + (g_max,) elif layer_type == FC: matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \ self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce) elif layer_type == Embedding: g = gradients[i] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) damping = self.sqrt(damping_step) a_eye = P.OnesLike()(matrix_a) matrix_a = self.mul(matrix_a, 1.0 / self.batch_size) matrix_a = matrix_a + damping * a_eye matrix_a_inv = self.inv(matrix_a) matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping * g_eye matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) matrix_g_inv = self.matrix_combine(matrix_g_inv) matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce def _process_layernorm(self, damping_step, gradient): """process layernorm layer for thor""" damping = self.sqrt(damping_step) normalizer = self.cast(self.batch_size, mstype.float32) fim_cov = self.square(gradient) fim_cov = self.mul(fim_cov, 1.0 / normalizer) fim_cov = fim_cov + damping fim_inv = self.inv(fim_cov) gradient = self.mul(fim_inv, gradient) return gradient def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g): """process thor graph fc layer""" temp_a = matrix_a_allreduce[thor_layer_count] temp_g = matrix_g_allreduce[thor_layer_count] self.assign(self.matrix_a_cov[thor_layer_count], temp_a) self.assign(self.matrix_g_cov[thor_layer_count], temp_g) temp_a = self.cast(temp_a, 
mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) return g

    # NOTE(review): the methods below were recovered from a whitespace-collapsed
    # listing; line breaks and indentation are reconstructed from the statement
    # semantics -- confirm the nesting against the upstream file.

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """get second gradients one"""
        # Transforms every gradient with the factors already stored on the
        # optimizer (self.matrix_a, self.matrix_g, self.matrix_max_inv). Nothing
        # is written back to those attributes here.
        #
        # Args:
        #     params_len: number of leading entries of `gradients` to process.
        #     gradients: indexable collection of gradients, read by position.
        #     new_grads: tuple that the processed gradients are appended to.
        #
        # Returns:
        #     `new_grads` extended by one entry per processed gradient, in order.
        for i in range(params_len):
            g = gradients[i]
            # Per-parameter lookup tables (built outside this chunk) give the
            # index into the factor lists, the index into the conv-only lists,
            # and the layer type constant (FC, Conv, ...).
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
            if layer_type == FC:
                # NOTE(review): 1001 is a hard-coded first dimension, presumably
                # the classifier output size of the reference network -- confirm.
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    # Generic path: both products run in float16, the result is
                    # cast back to float32 and then scaled by matrix_max.
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    # Flatten the 4-D gradient to (dim0, dim1 * dim2 * dim3) for
                    # the two matmuls, then restore the original shape.
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            # Any other layer type reaches this point with g unchanged.
            new_grads = new_grads + (g,)
        return new_grads

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """get second gradients for thor"""
        # Transforms the gradients using factors that are already stored on the
        # optimizer; this method does not assign to any of them.
        #
        # Args:
        #     new_grads: tuple that the processed gradients are appended to.
        #     damping_step: value forwarded to self._process_layernorm.
        #     gradients: indexable collection of gradients, one per entry of
        #         self.params.
        #
        # Returns:
        #     `new_grads` extended by one entry per parameter.
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            # With conv layers present, the matrix_a / matrix_g / matrix_max_inv
            # based routine handles every parameter.
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            # Otherwise the factors are read from matrix_a_cov / matrix_g_cov.
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type == Embedding:
                    temp_a_ori = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    # The A factor is applied with an element-wise mul after
                    # self.expand(..., 1); only the G factor goes through matmul,
                    # and it multiplies from the right.
                    # NOTE(review): self.expand presumably inserts an axis at
                    # position 1 -- confirm where it is defined.
                    temp_a = self.expand(temp_a_ori, 1)
                    g = self.mul(temp_a, g)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(g, temp_g)
                    g = self.cast(g, mstype.float32)
                elif layer_type == FC:
                    # Two float16 matmuls (G factor on the left, A factor on the
                    # right), result cast back to float32.
                    temp_a = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                # Any other layer type is appended unchanged.
                new_grads = new_grads + (g,)
        return new_grads

    def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
        """get second gradient by matmul"""
        # Transforms one gradient with caller-supplied factors, for FC and Conv
        # layers.
        #
        # Args:
        #     index: position of the parameter; used to look up its conv index
        #         and layer type.
        #     temp_a: A-side factor for this parameter.
        #     temp_g: G-side factor for this parameter.
        #     g: gradient of the parameter.
        #     temp_max: scaling factor multiplied into the result.
        #
        # Returns:
        #     Tuple (g, temp_max). temp_max is the rescaled value for Conv
        #     layers and the value passed in for every other layer type.
        conv_layer_count = self.weight_conv_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        grad_shape = self.shape(g)
        if layer_type == FC:
            # NOTE(review): 1001 is a hard-coded first dimension, presumably the
            # classifier output size of the reference network -- confirm.
            if grad_shape[0] == 1001:
                g = self.cube_matmul_left_fc(temp_g, g)
                g = self.cube_matmul_right_fc(g, temp_a, temp_max)
            else:
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
        elif layer_type == Conv:
            a_normalizer = self.a_normalizer[conv_layer_count]
            # NOTE(review): F.depend presumably orders the read of a_normalizer
            # after g has been produced -- confirm against the framework docs.
            a_normalizer = F.depend(a_normalizer, g)
            # Rescale temp_max by batch_size / a_normalizer; this rescaled value
            # is what gets returned to the caller.
            temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
            matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
            if matmul_support_flag == 1:
                g = self.cube_matmul_left(temp_g, g)
                g = self.cube_matmul_right_mul(g, temp_a, temp_max)
            else:
                # Flatten the 4-D gradient to (dim0, dim1 * dim2 * dim3) for the
                # two matmuls, then restore the original shape.
                g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
                g = self.reshape(g, grad_shape)
        return g, temp_max

    def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
        """get second gradient by layertype"""
        # Transforms one gradient according to its layer type (Embedding, FC or
        # LayerNorm); any other layer type is returned unchanged.
        #
        # Args:
        #     index: position of the parameter; used to look up its factor index
        #         and layer type.
        #     matrix_a_allreduce: indexable collection of A-side factors.
        #     matrix_g_allreduce: indexable collection of G-side factors.
        #     g: gradient of the parameter.
        #     damping_step: value forwarded to self._process_layernorm.
        #
        # Returns:
        #     The processed gradient.
        thor_layer_count = self.weight_fim_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        if layer_type == Embedding:
            temp_a_ori = matrix_a_allreduce[thor_layer_count]
            temp_g = matrix_g_allreduce[thor_layer_count]
            # Store the factors in matrix_a_cov / matrix_g_cov, the attributes
            # that _get_second_gradients reads on its non-conv path.
            self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
            self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
            # Same arithmetic as the Embedding branch of _get_second_gradients.
            temp_a = self.expand(temp_a_ori, 1)
            g = self.mul(temp_a, g)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(g, temp_g)
            g = self.cast(g, mstype.float32)
        elif layer_type == FC:
            g = self._process_thor_fc(thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g)
        elif layer_type == LayerNorm:
            g = self._process_layernorm(damping_step, g)
        return g

    def construct(self, gradients):
        # Runs one optimizer step: scale the gradients, transform them with the
        # factor matrices, then apply weight decay, clipping and the momentum
        # update.
        #
        # Args:
        #     gradients: gradients of self.params, indexable by parameter
        #         position.
        #
        # Returns:
        #     The result of mapping _momentum_opt over (gradients, params,
        #     moments) with self.hyper_map.
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        # Select the damping entry for the current cov_step and use it as float32.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            # Refresh path: obtain new factor lists, reduce them when
            # distributed, use them for this step and store them via self.assign.
            matrix_a_allreduce = ()
            matrix_g_allreduce = ()
            matrix_a_max_allreduce = ()
            matrix_g_max_allreduce = ()
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                                   matrix_a_max_allreduce, matrix_g_max_allreduce)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
                # The max lists are only reduced when conv layers exist.
                if self.conv_layer_count > 0:
                    matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                    matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

            new_grads = ()
            if self.conv_layer_count > 0:
                for i in range(len(self.params)):
                    g = gradients[i]
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
                    # exp(-log(m)) is 1 / m for positive m, so temp_a is scaled
                    # by the reciprocal of its reduced max value.
                    matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    # Same reciprocal scaling for temp_g.
                    matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    # NOTE(review): both factors are matrix_g_max_allreduce and
                    # matrix_a_max_allreduce is not used in this product --
                    # confirm this is intended.
                    temp_max = self.mul(matrix_g_max_allreduce[thor_layer_count],
                                        matrix_g_max_allreduce[thor_layer_count])
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    # Store the float16 factors and the temp_max returned by the
                    # helper; _get_second_gradients_one reads these attributes.
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                    new_grads = new_grads + (g,)
                gradients = new_grads
            else:
                for i in range(len(self.params)):
                    g = gradients[i]
                    g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                    new_grads = new_grads + (g,)
                gradients = new_grads
        else:
            # Reuse path: transform the gradients with the stored factors.
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        # Advance the step counter that indexes self.damping above.
        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success