# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Enumerates types of Layer
# Integer tags appended to the layer-type map built by find_net_layertype_recur;
# the optimizers compare against them to decide how each parameter is handled.
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5
# Shared AddN primitive; _tensor_apply_decay uses it to add the decay term to a gradient.
op_add = P.AddN()
# Type-dispatched graph functions; their implementations are attached with `.register(...)`.
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """
    Add the weight-decay term to a gradient when decay is enabled for this weight.

    Returns `weight * weight_decay + gradient` if `if_apply` is true, otherwise the
    gradient unchanged.
    """
    if not if_apply:
        return gradient
    decay_term = weight * weight_decay
    return op_add((decay_term, gradient))
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """
    Run the momentum update `opt` on one weight and return True.

    The returned flag is tied to the update through `F.depend`, so the update is
    part of the value that callers consume.
    """
    update = opt(weight, moment, learning_rate, gradient, momentum)
    return F.depend(True, update)
# Gradient clipping configuration read by clip_gradient.
# When True, clip_gradient clips by global norm instead of clipping each tensor on its own.
IS_ENABLE_GLOBAL_NORM = False
# Per-tensor clip mode handed to clip_grad: 0 clips by value, 1 clips by norm.
GRADIENT_CLIP_TYPE = 1
# Clip threshold used by both clipping paths.
GRADIENT_CLIP_VALUE = 1.0
# Type-dispatched per-tensor clipping function and the HyperMap used to apply it to every gradient.
clip_grad = C.MultitypeFuncGraph("clip_grad")
hyper_map_op = C.HyperMap()
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): The gradient to clip.

    Outputs:
        Tensor, the clipped gradient. Any other `clip_type` returns `grad` untouched.
    """
    if clip_type not in [0, 1]:
        return grad

    grad_dtype = F.dtype(grad)
    if clip_type == 0:
        # Clip element-wise into [-clip_value, clip_value], expressed in the gradient's dtype.
        lower = F.cast(F.tuple_to_array((-clip_value,)), grad_dtype)
        upper = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
        return C.clip_by_value(grad, lower, upper)

    max_norm = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
    return nn.ClipByNorm()(grad, max_norm)
def clip_gradient(enable_clip_grad, gradients):
    """
    Clip `gradients` according to the module-level clipping configuration.

    Returns `gradients` unchanged when `enable_clip_grad` is false. Otherwise clips by
    global norm if IS_ENABLE_GLOBAL_NORM is set, else clips every tensor individually
    with GRADIENT_CLIP_TYPE / GRADIENT_CLIP_VALUE.
    """
    if not enable_clip_grad:
        return gradients
    if IS_ENABLE_GLOBAL_NORM:
        return C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
    per_tensor_clip = F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    return hyper_map_op(per_tensor_clip, gradients)
# Block edge length: caculate_device_shape expresses a matrix as a grid of C0 x C0 blocks.
C0 = 16
def _check_param(momentum, frequency, lr, cls_name):
    """
    Validate the THOR-specific constructor arguments.

    `momentum` must be a float that is at least 0.0, `frequency` an int that is at
    least 2, and `lr` a Tensor. Type checks are delegated to `Validator`; range
    violations raise ValueError.
    """
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    momentum_is_negative = isinstance(momentum, float) and momentum < 0.0
    if momentum_is_negative:
        message = ("For 'thor', the argument 'momentum' should be at least 0.0, "
                   "but got 'momentum' {}.").format(momentum)
        raise ValueError(message)

    Validator.check_value_type("frequency", frequency, [int], cls_name)
    frequency_too_small = isinstance(frequency, int) and frequency < 2
    if frequency_too_small:
        message = ("For 'thor', the argument 'frequency' should be at least 2, "
                   "but got 'frequency' {}.").format(frequency)
        raise ValueError(message)

    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)
def caculate_device_shape(matrix_dim, channel, is_a):
    """
    Return the blocked shape and the (possibly adjusted) dimension of a square matrix.

    For an A matrix (`is_a` true) whose `channel` is smaller than C0, the dimension is
    first rescaled to `(matrix_dim / channel) * C0`. The result is the pair
    `((blocks, blocks, C0, C0), dim)` with `blocks = dim // C0`, both as ints.
    """
    if is_a and channel // C0 == 0:
        matrix_dim = (matrix_dim / channel) * C0
    blocks = int(matrix_dim // C0)
    return (blocks, blocks, C0, C0), int(matrix_dim)
def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """
    Tell whether a conv layer's (G shape, A shape) pair is in the supported matmul table.

    Every supported shape has the form (k, k, 16, 16); the table below lists the
    (G block count, A block count) pairs and is expanded to full shapes for the lookup.
    """
    block_count_pairs = [(4, 49), (4, 4), (4, 36), (16, 4), (4, 16),
                         (8, 16), (8, 72), (32, 8), (32, 16), (8, 32),
                         (16, 32), (16, 144), (64, 16), (64, 32), (16, 64),
                         (32, 64), (32, 288), (128, 32), (128, 64), (32, 128)]
    supported = [((g_blocks, g_blocks, 16, 16), (a_blocks, a_blocks, 16, 16))
                 for g_blocks, a_blocks in block_count_pairs]
    return (matrix_g_shape, matrix_a_shape) in supported
def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """
    Compute the blocked shapes used to store the A and G matrices.

    Each dimension is cut into blocks of `split_dim`: an exact multiple yields
    `dim // split_dim` blocks, a dimension smaller than `split_dim` yields one block
    of its own size, and anything else is rounded up to one extra block.

    Returns:
        tuple, `((batch_h, batch_w, a_block, a_block), (batch_h, g_block, g_block))`
        where `batch_w`/`a_block` come from `matrix_a_dim` and `batch_h`/`g_block`
        from `matrix_g_dim`.
    """

    def _blocks_and_size(dim):
        whole, rest = divmod(dim, split_dim)
        if rest == 0:
            return whole, split_dim
        if dim < split_dim:
            return 1, dim
        return whole + 1, split_dim

    batch_w, a_block = _blocks_and_size(matrix_a_dim)
    batch_h, g_block = _blocks_and_size(matrix_g_dim)
    return (batch_h, batch_w, a_block, a_block), (batch_h, g_block, g_block)
def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """
    Append `Other` for a plain dense/conv cell whose weight is trainable.

    Nothing is appended when the weight does not require grad, or when the prefix
    (compared case-insensitively) belongs to "rpn_with_loss.rpn_convs_list." without
    being the ".0." entry of that list.
    """
    if not subcell.weight.requires_grad:
        return
    lowered = prefix.lower()
    in_rpn_conv_list = "rpn_with_loss.rpn_convs_list." in lowered
    is_first_rpn_conv = "rpn_with_loss.rpn_convs_list.0." in lowered
    if not in_rpn_conv_list or is_first_rpn_conv:
        layertype_map.append(Other)
def find_net_layertype_recur(net, layertype_map):
    """
    Walk the named cells of `net` and append one layer-type tag per recognised layer.

    THOR layers map to Conv / FC / Embedding, LayerNorm to LayerNorm, trainable
    BatchNorm2d to BatchNorm, and the remaining standard layers to Other (dense and
    conv cells go through get_layer_type_for_dense_and_conv). Any other cell is
    descended into recursively. The isinstance checks are order-sensitive.
    """
    for subcell in net.name_cells().values():
        prefix = subcell.param_prefix
        if subcell == net:
            continue

        if isinstance(subcell, Conv2dThor):
            layertype_map.append(Conv)
        elif isinstance(subcell, DenseThor):
            layertype_map.append(FC)
        elif isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
        elif isinstance(subcell, nn.LayerNorm):
            layertype_map.append(LayerNorm)
        elif isinstance(subcell, nn.BatchNorm2d):
            # A BatchNorm2d with a frozen gamma contributes no entry.
            if subcell.gamma.requires_grad:
                layertype_map.append(BatchNorm)
        elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d,
                                  nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            if isinstance(subcell, (nn.Dense, nn.Conv2d)):
                get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
            else:
                layertype_map.append(Other)
        else:
            find_net_layertype_recur(subcell, layertype_map)
def get_net_layertype_mask(net):
    """Return the list of layer-type tags collected from `net` by find_net_layertype_recur."""
    collected = []
    find_net_layertype_recur(net, collected)
    return collected
def get_layer_counter(layer_type, layer_counter, params, idx):
    """
    Return `layer_counter`, advanced by one when `params[idx]` is the last parameter of its layer.

    Rules, based on the lower-cased parameter names:
      * Conv / FC: advance on a "bias" parameter, or on a non-bias parameter that is
        followed by a parameter which is not a "bias".
      * LayerNorm / BatchNorm: advance on a "beta" parameter.
      * anything else: advance on "bias"; on "weight" only if a following parameter
        exists and is not a "bias"; and on every other parameter name.
    """
    name = params[idx].name.lower()

    def _followed_by_non_bias():
        # Only looks at params[idx + 1] when it exists.
        return idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower()

    if layer_type in [Conv, FC]:
        advance = "bias" in name or _followed_by_non_bias()
    elif layer_type in [LayerNorm, BatchNorm]:
        advance = "beta" in name
    elif "bias" in name:
        advance = True
    elif "weight" in name:
        advance = _followed_by_non_bias()
    else:
        advance = True

    if advance:
        layer_counter = layer_counter + 1
    return layer_counter
def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by second-order algorithm--THOR.

    Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:

    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            A_i = a_i{a_i}^T \\
            G_i = D_{s_i}{ D_{s_i}}^T \\
            m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
            w_i = w_i - \alpha * m_i \\
        \end{array}

    :math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer,
    :math:`a_{i-1}` represents the input of i-th layer, which is the activations of the previous layer,
    :math:`\beta` represents momentum, :math:`I` represents the identity matrix,
    :math:`\overline A` represents the transpose of matrix A,
    :math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
    :math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'

    Note:
        Calling this function sets the context option `max_call_depth` to 10000 and converts the
        layers of `net` in place through `ConvertNetUtils().convert_to_thor_net(net)` before the
        optimizer is built.

    Args:
        net (Cell): The training network.

        learning_rate (Tensor): A value for the learning rate.

        damping (Tensor): A value for the damping.

        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.

        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.

        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: 1.0.

        batch_size (int): The size of a batch. Default: 32

        use_nesterov (bool): Enable Nesterov momentum. It is only forwarded to the optimizer built for
            non-Ascend device targets; it is ignored when the device target is Ascend. Default: False.

        decay_filter (function): A function to determine which layers the weight decay applied to. And it
            only works when the weight_decay > 0. Default: lambda x: x.name not in []

        split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
            computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
            to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
            is 27~53. Default: None

        enable_clip_grad (bool): Whether to clip the gradients. Default: False

        frequency(int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps, and other
            steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `frequency` is not int.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore.nn import thor
        >>> from mindspore import Model
        >>> from mindspore import FixedLossScaleManager
        >>> from mindspore.train.callback import LossMonitor
        >>> from mindspore.train.train_thor import ConvertModelUtils
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>>
        >>> net = Net()
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
        >>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)
        >>> loss_cb = LossMonitor()
        >>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)

    """
    context.set_context(max_call_depth=10000)
    # Swap the supported layers of `net` for their THOR counterparts before collecting parameters.
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        # The Ascend implementation takes no `use_nesterov` argument.
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)
class ThorGpu(Optimizer):
    """
    THOR optimizer that `thor` returns when the device target is not "Ascend".

    For every Conv / FC / Embedding weight it keeps a pair of non-trainable inverse
    factors (`matrix_a`, `matrix_g`), uses them to precondition the gradient, and then
    applies a momentum update through `P.ApplyMomentum`. LayerNorm gradients are
    rescaled by `_process_layernorm`; all other gradients pass through unchanged.
    """

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        # Only trainable parameters are handed to the base optimizer.
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        # One zero-initialised momentum buffer per trainable parameter.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        # Statistics parameters owned by the network, selected purely by substring of their names.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        # Note: this attribute holds 1 / loss_scale**2, not the loss scale itself; it multiplies matrix_g.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        # Must run before _process_matrix_init_and_weight_idx_map, which needs self.shape and self.split_dim.
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        # Selects the branch of construct that recomputes the inverse factors.
        self.thor = True
        # The tuples below are filled by _process_matrix_init_and_weight_idx_map, one entry per parameter
        # for the *_idx_map tuples and one entry per THOR layer for the matrix tuples.
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """Create the primitives, constants and the step counter used by this optimizer."""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.GatherV2()
        self.one = Tensor(1, mstype.int32)
        # Default scaling factor for non-conv layers; conv layers replace it in _get_ainv_ginv_list.
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        # Incremented on every construct call and used as the index into self.damping.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        # Block size passed to caculate_matmul_shape, CholeskyTrsm and UpdateThorGradient.
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """Record whether training is distributed and, if so, build the reducers for the A and G factors."""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            # Without explicit indices, a single split index at the last A matrix is used.
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            # NOTE(review): the G reducer is also built from matrix_a_cov - confirm this is intended.
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map"""
        layer_type_map = get_net_layertype_mask(net)
        # Index into layer_type_map; advanced by get_layer_counter when a layer's parameters are exhausted.
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    # For conv weights the A dimension also covers weight_shape[2] and weight_shape[3].
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                # Embedding keeps a 1-D A factor of length vocab_size and a square G factor.
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            # Per-parameter bookkeeping: index of the THOR layer, index of the conv layer, and layer type.
            # Parameters that are not THOR weights get -1 for both indices.
            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """
        Compute the damped inverse factors for every Conv / FC / Embedding weight.

        The results are appended to the `matrix_a_allreduce` / `matrix_g_allreduce` tuples passed in,
        in parameter order, and both tuples are returned.
        """
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                # Attach a dependency on this layer's gradient to the statistics that are read.
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    # Conv layers divide the damping by their normalizers and derive feature_map from a_normalizer.
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                # self.loss_scale holds 1 / loss_scale**2 and self.batch_size_scale holds batch_size**2.
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    # The embedding A factor is inverted element-wise, so the damping is added through a
                    # ones tensor of the same shape rather than an identity matrix.
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    # Expand to the blocked A shape recorded in _process_matrix_init_and_weight_idx_map.
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """Rescale a LayerNorm gradient by the reciprocal of (gradient**2 / batch_size + sqrt(damping_step))."""
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """Reshape `g` back to `g_shape` for conv weights (conv_layer_count != -1); return it as is otherwise."""
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        """
        Precondition the gradients with the THOR factors and apply the momentum update.

        Returns the result of mapping `_momentum_opt` over (gradients, params, moments).
        """
        params = self.params
        moments = self.moments
        # scale_grad and get_lr are inherited from Optimizer and not defined in this file.
        gradients = self.scale_grad(gradients)
        # Select the damping value for the current step from the damping tensor.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            # Recompute the inverse factors, optionally all-reduce them, and store them for later reuse.
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    # Flatten to 2-D (first dimension kept) before applying the factors.
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    # Scale each row by the 1-D A factor, then multiply by the G factor.
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            # Reuse the factors stored in self.matrix_a / self.matrix_g without recomputing them.
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
class ThorAscend(Optimizer):
"""ThorAscend"""
def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
             decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
    """
    Build the Ascend THOR optimizer for `net`.

    Validates `momentum`, `frequency` and `learning_rate` with _check_param, collects the
    statistics parameters of the network by name, creates the per-layer inverse-factor
    parameters and index maps, and finally sets up the distributed reducers.
    """
    # Only trainable parameters are handed to the base optimizer.
    params = filter(lambda x: x.requires_grad, net.get_parameters())
    super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
    _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
    self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
    self.params = self.parameters
    # One zero-initialised momentum buffer per trainable parameter.
    self.moments = self.params.clone(prefix="moments", init='zeros')
    self.hyper_map = C.HyperMap()
    self.opt = P.ApplyMomentum()
    self.net = net
    # Statistics parameters owned by the network, selected purely by substring of their names.
    self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
    self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
    self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
    self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
    logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
    # Must run before _process_matrix_init_and_weight_idx_map, which uses self.shape.
    self._define_ascend_operator()
    self.C0 = 16
    self.device_shape_pad_flag = ()
    # Block size that _get_pad_dim pads matrix dimensions up to a multiple of.
    self.diag_block_dim = 128
    # The tuples and counters below are filled while the parameters are scanned.
    self.matrix_a = ()
    self.matrix_g = ()
    self.thor_layer_count = 0
    self.conv_layer_count = 0
    self.weight_conv_idx_map = ()
    self.weight_fim_idx_map = ()
    self.weight_layertype_idx_map = ()
    self.a_split_pad_dim_map = ()
    self.g_split_pad_dim_map = ()
    self.conv_matmul_support_map = ()
    self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
    self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
    self._process_matrix_init_and_weight_idx_map(self.net)
    self.matrix_a = ParameterTuple(self.matrix_a)
    self.matrix_g = ParameterTuple(self.matrix_g)
    # One scalar parameter (initialised to 1) per entry of self.matrix_a.
    self.matrix_max_inv = ()
    for i in range(len(self.matrix_a)):
        self.matrix_max_inv = self.matrix_max_inv + (
            Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
    self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
    self.thor = True
    self.weight_decay = weight_decay
    self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
    self.damping = damping
    self.batch_size = Tensor(batch_size, mstype.float32)
    # Note: this attribute holds 1 / loss_scale**2, not the loss scale itself.
    self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
    self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
    self.enable_clip_grad = enable_clip_grad
    self.frequency = frequency
    # Runs last: it reads self.conv_layer_count and self.matrix_a_cov.
    self._define_ascend_reducer(split_indices)
def get_frequency(self):
    """Return the `frequency` value stored on this optimizer."""
    return self.frequency
def _get_pad_dim(self, matrix_dim):
"""get diag split pad dim """
split_pad_dim = 0
if matrix_dim == 64:
return split_pad_dim
res = matrix_dim % self.diag_block_dim
if res != 0:
split_pad_dim = self.diag_block_dim - res
return split_pad_dim
def _define_ascend_operator(self):
    """Create the primitives, constants and the step counter used by the Ascend optimizer."""
    # Custom ("Cus*") primitives.
    self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
    self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
    self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
    self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
    self.cholesky = P.CusCholeskyTrsm()
    self.vector_matmul = P.CusBatchMatMul()
    self.fused_abs_max2 = P.CusFusedAbsMax1()
    self.matrix_combine = P.CusMatrixCombine()

    # Standard matrix products.
    self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
    self.matmul = P.MatMul()

    # Shape, layout and indexing helpers.
    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.transpose = P.Transpose()
    self.concat = P.Concat(0)
    self.slice = P.Slice()
    self.expand = P.ExpandDims()
    self.gather = P.GatherV2()
    self.eye = P.Eye()
    self.cast = P.Cast()

    # Element-wise math and reductions.
    self.mul = P.Mul()
    self.log = P.Log()
    self.exp = P.Exp()
    self.sqrt = P.Sqrt()
    self.square = P.Square()
    self.inv = P.Inv()
    self.reduce_sum = P.ReduceSum(keep_dims=False)

    # State handling: in-place assignment, constants and the step counter.
    self.assign = P.Assign()
    self.axis = 0
    self.one = Tensor(1, mstype.int32)
    self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
def _define_ascend_reducer(self, split_indices):
    """
    Record whether training is distributed and, if so, build the all-reduce helpers.

    The "amax"/"gmax" reducers are only created when the network has at least one conv
    layer; the A and G reducers are always created in distributed mode. When
    `split_indices` is empty or None, a single split index at the last A matrix is used.
    """
    self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
    self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
    if not self.is_distributed:
        return

    mean = _get_gradients_mean()
    degree = _get_device_num()
    self.split_indices = split_indices if split_indices else [len(self.matrix_a_cov) - 1]

    if self.conv_layer_count > 0:
        for group in ("hccl_world_groupsum2", "hccl_world_groupsum4"):
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, group)
        self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
        self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)

    for group in ("hccl_world_groupsum6", "hccl_world_groupsum8"):
        auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, group)
    self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
    self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
def _get_weight_idx_map(self, layer_type, idx, weight_shape):
    """
    Append the bookkeeping entries for parameter `idx` to the per-parameter index maps.

    A non-bias Conv / FC / Embedding parameter gets its THOR layer index, its layer type,
    its A/G padding sizes (0 for Embedding) and, for Conv, its conv layer index. Every
    other parameter gets -1 for both indices and a layer type of LayerNorm or Other.
    """
    is_thor_weight = (layer_type in [Conv, FC, Embedding]
                      and "bias" not in self.params[idx].name.lower())
    if not is_thor_weight:
        self.weight_fim_idx_map += (-1,)
        self.weight_conv_idx_map += (-1,)
        fallback_type = LayerNorm if layer_type == LayerNorm else Other
        self.weight_layertype_idx_map += (fallback_type,)
        return

    self.weight_fim_idx_map += (self.thor_layer_count,)
    self.weight_layertype_idx_map += (layer_type,)
    if layer_type == Embedding:
        self.a_split_pad_dim_map += (0,)
        self.g_split_pad_dim_map += (0,)
    else:
        # G is sized by weight_shape[0]; A by weight_shape[1], times the last two dims for conv weights.
        self.g_split_pad_dim_map += (self._get_pad_dim(weight_shape[0]),)
        matrix_a_dim = weight_shape[1]
        if layer_type == Conv:
            matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
        self.a_split_pad_dim_map += (self._get_pad_dim(matrix_a_dim),)

    self.thor_layer_count += 1
    if layer_type == Conv:
        self.weight_conv_idx_map += (self.conv_layer_count,)
        self.conv_layer_count += 1
    else:
        self.weight_conv_idx_map += (-1,)
def _get_fc_matrix(self, weight_shape):
    """Create and register the initial A/G inverse parameters of an FC layer (Ascend).

    Nothing is registered unless ``self.conv_layer_count`` is positive. A layer
    with 1001 output channels gets zero-filled 4-D float16 buffers of shape
    (128, 128, 16, 16) and (63, 63, 16, 16); every other layer gets float16
    identity matrices sized by its input and output channels.

    Args:
        weight_shape: shape of the FC weight, ``(out_channels, in_channels)``.
    """
    out_channels, in_channels = weight_shape[0], weight_shape[1]
    if not self.conv_layer_count > 0:
        return

    if out_channels == 1001:
        init_a = np.zeros([128, 128, 16, 16])
        init_g = np.zeros([63, 63, 16, 16])
    else:
        init_a = np.eye(in_channels)
        init_g = np.eye(out_channels)

    layer_tag = str(self.thor_layer_count)
    fc_matrix_a = Parameter(Tensor(init_a.astype(np.float16)),
                            name='matrix_a_inv_' + layer_tag,
                            requires_grad=False)
    fc_matrix_g = Parameter(Tensor(init_g.astype(np.float16)),
                            name="matrix_g_inv_" + layer_tag,
                            requires_grad=False)
    self.matrix_a += (fc_matrix_a,)
    self.matrix_g += (fc_matrix_g,)
def _process_matrix_init_and_weight_idx_map(self, net):
    """Initialise the A/G inverse parameters and build the index maps (Ascend).

    Iterates over ``self.params``. Non-bias Conv weights get their inverse
    parameters created here (in the device layout when
    ``is_conv_matmul_support_shape`` accepts it, as plain identities
    otherwise); non-bias FC weights are delegated to ``_get_fc_matrix``.
    Every parameter is then passed to ``_get_weight_idx_map``.

    Args:
        net: network whose layer types are obtained via
            ``get_net_layertype_mask``.
    """
    layer_type_map = get_net_layertype_mask(net)
    layer_counter = 0
    for idx, weight in enumerate(self.params):
        layer_type = layer_type_map[layer_counter]
        weight_shape = self.shape(weight)
        param_name = weight.name.lower()
        is_bias = "bias" in param_name

        if layer_type == Conv and not is_bias:
            out_channels, in_channels = weight_shape[0], weight_shape[1]
            matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
            matrix_g_dim = out_channels
            a_dev_shape, a_dev_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
            g_dev_shape, g_dev_dim = caculate_device_shape(matrix_g_dim, in_channels, False)

            if is_conv_matmul_support_shape(a_dev_shape, g_dev_shape):
                # Identity matrices reshaped to the device layout.
                init_a = np.reshape(np.identity(a_dev_dim).astype(np.float16), a_dev_shape)
                init_g = np.reshape(np.identity(g_dev_dim).astype(np.float16), g_dev_shape)
                support_flag = 1
            else:
                init_a = np.eye(matrix_a_dim).astype(np.float16)
                init_g = np.eye(matrix_g_dim).astype(np.float16)
                support_flag = 0

            layer_tag = str(self.thor_layer_count)
            matrix_a_inv = Parameter(Tensor(init_a), name='matrix_a_inv_' + layer_tag, requires_grad=False)
            matrix_g_inv = Parameter(Tensor(init_g), name="matrix_g_inv_" + layer_tag, requires_grad=False)
            self.conv_matmul_support_map += (support_flag,)
            self.matrix_a += (matrix_a_inv,)
            self.matrix_g += (matrix_g_inv,)
            # True when the device layout is larger than the logical A dimension.
            self.device_shape_pad_flag += (bool(matrix_a_dim != a_dev_dim),)
        elif layer_type == FC and not is_bias:
            self._get_fc_matrix(weight_shape)

        self._get_weight_idx_map(layer_type, idx, weight_shape)
        # bert.cls1.output_bias: not a network layer, only a trainable param
        if "output_bias" not in param_name:
            layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
def _process_batch_matmul(self, input_matrix):
    """Multiply ``input_matrix`` with itself using the matching batch-matmul op.

    ``self.vector_matmul`` is used when the leading dimension is listed in
    ``self.batch_matmul_support_list``; ``self.tbe_batch_matmul`` otherwise.

    Args:
        input_matrix: tensor passed as both operands of the product.

    Returns:
        Result of the selected matmul op.
    """
    leading_dim = self.shape(input_matrix)[0]
    if leading_dim not in self.batch_matmul_support_list:
        return self.tbe_batch_matmul(input_matrix, input_matrix)
    return self.vector_matmul(input_matrix, input_matrix)
def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
    """Extend ``input_matrix`` with an identity block when padding is required.

    With ``pad_dim <= 0`` the input is returned untouched. Otherwise a
    ``pad_dim`` x ``pad_dim`` float32 identity is padded with ``matrix_shape0``
    leading entries on axis 1, the input is padded with ``pad_dim`` trailing
    entries on axis 1, and both are joined through ``self.concat``
    (input first; the concat axis is configured where ``self.concat`` is built).

    Args:
        pad_dim: number of rows/columns to add.
        input_matrix: matrix to pad.
        matrix_shape0: first dimension of ``input_matrix``.

    Returns:
        The padded matrix, or ``input_matrix`` itself when no padding applies.
    """
    if pad_dim <= 0:
        return input_matrix
    identity_block = self.eye(pad_dim, pad_dim, mstype.float32)
    identity_block = P.Pad(((0, 0), (matrix_shape0, 0)))(identity_block)
    widened = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
    return self.concat((widened, identity_block))
def _get_abs_max(self, matrix_inv, origin_dim):
    """Return the absolute maximum of ``matrix_inv`` and its combined form.

    When the leading dimension is in ``self.abs_max_support_list`` the maximum
    is taken with the fused ops on the uncombined input; otherwise it is the
    ``ReduceMax`` of ``Abs`` over the combined matrix.

    Args:
        matrix_inv: matrix before ``self.matrix_combine``.
        origin_dim: dimension handed to ``CusFusedAbsMax1`` as ``[origin_dim, origin_dim]``.

    Returns:
        Tuple ``(abs_max, combined_matrix)``.
    """
    leading_dim = self.shape(matrix_inv)[0]
    if leading_dim in self.abs_max_support_list:
        partial_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
        abs_max = self.fused_abs_max2(partial_max)
        combined = self.matrix_combine(matrix_inv)
    else:
        combined = self.matrix_combine(matrix_inv)
        abs_max = P.ReduceMax(keep_dims=False)(P.Abs()(combined))
    return abs_max, combined
def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                      matrix_a_max_allreduce, matrix_g_max_allreduce):
    """Compute the processed A/G matrices of one FC layer and append them.

    The damped covariance matrices are padded, passed through
    ``self.cholesky`` and ``_process_batch_matmul``, then combined (and, when
    the network has Conv layers, measured with ``_get_abs_max``).

    Args:
        index: position of the FC weight in ``self.params``.
        damping_step: damping value for the current step.
        gradients: per-parameter gradients, indexed like ``self.params``.
        matrix_a_allreduce: tuple accumulating the processed A matrices.
        matrix_g_allreduce: tuple accumulating the processed G matrices.
        matrix_a_max_allreduce: tuple accumulating the A abs-max values.
        matrix_g_max_allreduce: tuple accumulating the G abs-max values.

    Returns:
        The four accumulator tuples, in the same order as the arguments.
    """
    thor_layer_count = self.weight_fim_idx_map[index]
    g = gradients[index]
    matrix_a = self.matrix_a_cov[thor_layer_count]
    matrix_g = self.matrix_g_cov[thor_layer_count]
    # Attach a dependency on the gradient to both covariance reads.
    matrix_a = F.depend(matrix_a, g)
    matrix_g = F.depend(matrix_g, g)
    a_shape = self.shape(matrix_a)
    a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
    g_shape = self.shape(matrix_g)
    g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
    damping = self.sqrt(damping_step)
    matrix_a = matrix_a + damping * a_eye
    a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
    matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
    matrix_a_inv = self.cholesky(matrix_a)
    matrix_a_inv = self._process_batch_matmul(matrix_a_inv)

    weight_shape = self.shape(self.params[index])
    out_channels = weight_shape[0]
    in_channels = weight_shape[1]
    if out_channels == 2:
        # Two-output layers keep an identity G factor.
        matrix_a_inv = self.matrix_combine(matrix_a_inv)
        matrix_g_inv = g_eye
    else:
        matrix_g = self.mul(matrix_g, self.loss_scale)
        matrix_g = self.mul(matrix_g, self.batch_size_scale)
        matrix_g = matrix_g + damping * g_eye
        g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
        matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
        matrix_g_inv = self.cholesky(matrix_g)
        matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
        if self.conv_layer_count > 0:
            a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
            g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
            a_max = F.depend(a_max, g)
            g_max = F.depend(g_max, g)
            matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
            matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
        else:
            matrix_a_inv = self.matrix_combine(matrix_a_inv)
            matrix_g_inv = self.matrix_combine(matrix_g_inv)

        # Drop the rows/columns that were added for the padded decomposition.
        if a_pad_dim > 0:
            matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
        if g_pad_dim > 0:
            matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
        matrix_a_inv_shape = self.shape(matrix_a_inv)
        matrix_g_combine_shape = self.shape(matrix_g_inv)
        if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
            # Reshape dims must be integers, so use floor division: true
            # division would produce floats such as 128.0.
            a_blocks = matrix_a_inv_shape[0] // 16
            matrix_a_inv = self.reshape(matrix_a_inv, (a_blocks, 16, a_blocks, 16))
            matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
            # Pad G by 7 on each trailing side (1001 -> 1008) so it splits into 16-wide blocks.
            matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)

            matrix_g_inv_shape = self.shape(matrix_g_inv)
            g_blocks = matrix_g_inv_shape[0] // 16
            matrix_g_inv = self.reshape(matrix_g_inv, (g_blocks, 16, g_blocks, 16))
            matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

    matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
    matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
    return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
    """Pad the A matrix of a Conv layer up to ``self.C0`` channels if flagged.

    When ``self.device_shape_pad_flag[conv_layer_count]`` is set, the matrix is
    reshaped to ``(kernel_hw, in_channels, kernel_hw, in_channels)`` and both
    channel axes receive ``self.C0 - in_channels`` trailing entries. Otherwise
    the input is returned as is.

    Args:
        conv_layer_count: index of the Conv layer.
        weight_shape: shape of the Conv weight.
        matrix_a_inv: A matrix to pad.

    Returns:
        The (possibly padded) A matrix.
    """
    if not self.device_shape_pad_flag[conv_layer_count]:
        return matrix_a_inv
    in_channels = weight_shape[1]
    kernel_hw = weight_shape[2] * weight_shape[3]
    blocked = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
    channel_gap = self.C0 - in_channels
    paddings = ((0, 0), (0, channel_gap), (0, 0), (0, channel_gap))
    return P.Pad(paddings)(blocked)
def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                  matrix_a_max_allreduce, matrix_g_max_allreduce):
    """Build the per-layer lists of processed A/G matrices and their abs-max values.

    Walks every entry of ``self.params`` and, depending on the layer type
    recorded in ``self.weight_layertype_idx_map``, appends that layer's
    processed A and G matrices to the accumulator tuples. Conv layers always
    append their abs-max values too; FC layers are delegated to
    ``_get_fc_ainv_ginv``; Embedding layers append only the A/G matrices.
    Parameters of any other type contribute nothing.

    Args:
        gradients: per-parameter gradients, indexed like ``self.params``.
        damping_step: damping value for the current step.
        matrix_a_allreduce: tuple accumulating the processed A matrices.
        matrix_g_allreduce: tuple accumulating the processed G matrices.
        matrix_a_max_allreduce: tuple accumulating the A abs-max values.
        matrix_g_max_allreduce: tuple accumulating the G abs-max values.

    Returns:
        The four accumulator tuples, in the same order as the arguments.
    """
    for i in range(len(self.params)):
        thor_layer_count = self.weight_fim_idx_map[i]
        conv_layer_count = self.weight_conv_idx_map[i]
        layer_type = self.weight_layertype_idx_map[i]
        weight_shape = self.shape(self.params[i])
        out_channels = weight_shape[0]
        if layer_type == Conv:
            g = gradients[i]
            matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
            matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
            matrix_a = self.matrix_a_cov[thor_layer_count]
            matrix_g = self.matrix_g_cov[thor_layer_count]
            # Attach a dependency on the gradient to the covariance reads
            # (presumably to order them after the backward pass -- TODO confirm).
            matrix_a = F.depend(matrix_a, g)
            matrix_g = F.depend(matrix_g, g)
            a_shape = self.shape(matrix_a)
            a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
            g_shape = self.shape(matrix_g)
            g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
            a_normalizer = self.a_normalizer[conv_layer_count]
            g_normalizer = self.g_normalizer[conv_layer_count]
            a_normalizer = F.depend(a_normalizer, g)
            g_normalizer = F.depend(g_normalizer, g)
            # Conv damping is rescaled by batch_size / normalizer before the square root.
            damping_a = self.mul(damping_step, self.batch_size / a_normalizer)
            damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
            damping_a = self.sqrt(damping_a)
            matrix_a = matrix_a + damping_a * a_eye
            a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
            matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
            matrix_a_inv = self.cholesky(matrix_a)
            matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
            a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)

            damping_g = self.sqrt(damping_g)
            matrix_g = self.mul(matrix_g, self.loss_scale)
            matrix_g = self.mul(matrix_g, self.batch_size_scale)
            matrix_g = matrix_g + damping_g * g_eye
            g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
            matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
            matrix_g_inv = self.cholesky(matrix_g)
            matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
            g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)

            # Drop the rows/columns that were added by _process_cholesky_pad.
            if a_pad_dim > 0:
                matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
            if g_pad_dim > 0:
                matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))

            if matmul_support_flag == 1:
                # Rearrange both matrices to match the 4-D shape of the stored
                # self.matrix_a / self.matrix_g parameters.
                matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
                matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
                matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                              matrix_a_inv_shape[1], matrix_a_inv_shape[3])
                matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
                matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
                matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
                                              matrix_g_inv_shape[1], matrix_g_inv_shape[3])
                matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
                matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

            a_max = F.depend(a_max, g)
            g_max = F.depend(g_max, g)
            matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
            matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
            matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
            matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
        elif layer_type == FC:
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                       matrix_a_max_allreduce, matrix_g_max_allreduce)
        elif layer_type == Embedding:
            g = gradients[i]
            matrix_a = self.matrix_a_cov[thor_layer_count]
            matrix_g = self.matrix_g_cov[thor_layer_count]
            matrix_a = F.depend(matrix_a, g)
            matrix_g = F.depend(matrix_g, g)
            g_shape = self.shape(matrix_g)
            g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
            damping = self.sqrt(damping_step)
            # The A factor is damped and inverted elementwise (ones instead of
            # an identity matrix, self.inv instead of self.cholesky).
            a_eye = P.OnesLike()(matrix_a)
            matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
            matrix_a = matrix_a + damping * a_eye
            matrix_a_inv = self.inv(matrix_a)
            matrix_g = self.mul(matrix_g, self.loss_scale)
            matrix_g = self.mul(matrix_g, self.batch_size_scale)
            matrix_g = matrix_g + damping * g_eye
            matrix_g_inv = self.cholesky(matrix_g)
            matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
            matrix_g_inv = self.matrix_combine(matrix_g_inv)
            matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
            matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
    return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
def _process_layernorm(self, damping_step, gradient):
    """Scale a LayerNorm gradient by the inverse of its damped squared value.

    Computes ``inv(gradient**2 / batch_size + sqrt(damping_step)) * gradient``
    using the ops stored on ``self``.

    Args:
        damping_step: damping value for the current step.
        gradient: gradient of the LayerNorm parameter.

    Returns:
        The rescaled gradient.
    """
    sqrt_damping = self.sqrt(damping_step)
    batch = self.cast(self.batch_size, mstype.float32)
    squared = self.square(gradient)
    damped = self.mul(squared, 1.0 / batch) + sqrt_damping
    return self.mul(self.inv(damped), gradient)
def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g):
    """Store the new A/G matrices of an FC layer and precondition its gradient.

    The selected matrices are written into ``self.matrix_a_cov`` /
    ``self.matrix_g_cov`` and the gradient becomes ``G @ g @ A``, evaluated in
    float16 and cast back to float32.

    Args:
        thor_layer_count: index of the layer in the A/G tuples.
        matrix_a_allreduce: tuple of processed A matrices.
        matrix_g_allreduce: tuple of processed G matrices.
        g: gradient of the FC weight.

    Returns:
        The preconditioned gradient in float32.
    """
    a_new = matrix_a_allreduce[thor_layer_count]
    g_new = matrix_g_allreduce[thor_layer_count]
    self.assign(self.matrix_a_cov[thor_layer_count], a_new)
    self.assign(self.matrix_g_cov[thor_layer_count], g_new)
    a_half = self.cast(a_new, mstype.float16)
    g_half = self.cast(g_new, mstype.float16)
    grad_half = self.cast(g, mstype.float16)
    left_applied = self.matmul(g_half, grad_half)
    both_applied = self.matmul(left_applied, a_half)
    return self.cast(both_applied, mstype.float32)
def _get_second_gradients_one(self, params_len, gradients, new_grads):
    """Precondition all gradients with the stored A/G matrices.

    FC and Conv gradients become ``G @ g @ A`` scaled by the stored
    ``matrix_max_inv`` entry, using the cube matmul ops where applicable and
    float16 matmuls otherwise. Gradients of other layer types are passed
    through unchanged.

    Args:
        params_len: number of parameters to process.
        gradients: per-parameter gradients.
        new_grads: tuple the results are appended to.

    Returns:
        ``new_grads`` extended by one entry per parameter.
    """
    for idx in range(params_len):
        grad = gradients[idx]
        fim_idx = self.weight_fim_idx_map[idx]
        conv_idx = self.weight_conv_idx_map[idx]
        kind = self.weight_layertype_idx_map[idx]
        a_inv = self.matrix_a[fim_idx]
        g_inv = self.matrix_g[fim_idx]
        max_scale = self.matrix_max_inv[fim_idx]
        grad_shape = self.shape(grad)
        if kind == FC:
            if grad_shape[0] == 1001:
                grad = self.cube_matmul_left_fc(g_inv, grad)
                grad = self.cube_matmul_right_fc(grad, a_inv, max_scale)
            else:
                a_half = self.cast(a_inv, mstype.float16)
                g_half = self.cast(g_inv, mstype.float16)
                grad_half = self.cast(grad, mstype.float16)
                product = self.matmul(self.matmul(g_half, grad_half), a_half)
                grad = self.mul(self.cast(product, mstype.float32), max_scale)
        elif kind == Conv:
            if self.conv_matmul_support_map[conv_idx] == 1:
                grad = self.cube_matmul_left(g_inv, grad)
                grad = self.cube_matmul_right_mul(grad, a_inv, max_scale)
            else:
                # Flatten the kernel dims so the gradient is a 2-D matrix.
                flat_shape = (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3])
                grad = self.reshape(grad, flat_shape)
                a_half = self.cast(a_inv, mstype.float16)
                g_half = self.cast(g_inv, mstype.float16)
                grad_half = self.cast(grad, mstype.float16)
                product = self.matmul(self.matmul(g_half, grad_half), a_half)
                grad = self.mul(self.cast(product, mstype.float32), max_scale)
                grad = self.reshape(grad, grad_shape)
        new_grads = new_grads + (grad,)
    return new_grads
def _get_second_gradients(self, new_grads, damping_step, gradients):
    """Precondition gradients using the statistics already stored on ``self``.

    Networks with Conv layers are handled by ``_get_second_gradients_one``.
    Otherwise Embedding and FC gradients are multiplied with the matrices in
    ``self.matrix_a_cov`` / ``self.matrix_g_cov`` and LayerNorm gradients go
    through ``_process_layernorm``; all other gradients are left unchanged.

    Args:
        new_grads: tuple the results are appended to.
        damping_step: damping value for the current step.
        gradients: per-parameter gradients.

    Returns:
        ``new_grads`` extended by one entry per parameter.
    """
    params_len = len(self.params)
    if self.conv_layer_count > 0:
        return self._get_second_gradients_one(params_len, gradients, new_grads)

    for idx in range(params_len):
        grad = gradients[idx]
        fim_idx = self.weight_fim_idx_map[idx]
        kind = self.weight_layertype_idx_map[idx]
        if kind == Embedding:
            a_stored = self.matrix_a_cov[fim_idx]
            g_stored = self.matrix_g_cov[fim_idx]
            a_expanded = self.expand(a_stored, 1)
            grad = self.mul(a_expanded, grad)
            g_half = self.cast(g_stored, mstype.float16)
            grad_half = self.cast(grad, mstype.float16)
            grad = self.cast(self.matmul(grad_half, g_half), mstype.float32)
        elif kind == FC:
            a_stored = self.matrix_a_cov[fim_idx]
            g_stored = self.matrix_g_cov[fim_idx]
            a_half = self.cast(a_stored, mstype.float16)
            g_half = self.cast(g_stored, mstype.float16)
            grad_half = self.cast(grad, mstype.float16)
            product = self.matmul(self.matmul(g_half, grad_half), a_half)
            grad = self.cast(product, mstype.float32)
        elif kind == LayerNorm:
            grad = self._process_layernorm(damping_step, grad)
        new_grads = new_grads + (grad,)
    return new_grads
def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
    """Precondition one gradient with freshly computed A/G matrices.

    FC and Conv gradients become ``temp_g @ g @ temp_a`` scaled by
    ``temp_max``. For Conv layers ``temp_max`` is first multiplied by
    ``batch_size / a_normalizer``. Other layer types leave both the gradient
    and ``temp_max`` untouched.

    Args:
        index: position of the parameter in ``self.params``.
        temp_a: A matrix of the layer.
        temp_g: G matrix of the layer.
        g: gradient of the parameter.
        temp_max: scale applied to the matmul result.

    Returns:
        Tuple ``(gradient, temp_max)``.
    """
    conv_idx = self.weight_conv_idx_map[index]
    kind = self.weight_layertype_idx_map[index]
    grad_shape = self.shape(g)
    if kind == FC:
        if grad_shape[0] == 1001:
            g = self.cube_matmul_left_fc(temp_g, g)
            g = self.cube_matmul_right_fc(g, temp_a, temp_max)
        else:
            a_half = self.cast(temp_a, mstype.float16)
            g_half = self.cast(temp_g, mstype.float16)
            grad_half = self.cast(g, mstype.float16)
            product = self.matmul(self.matmul(g_half, grad_half), a_half)
            g = self.mul(self.cast(product, mstype.float32), temp_max)
    elif kind == Conv:
        normalizer = F.depend(self.a_normalizer[conv_idx], g)
        temp_max = self.mul(temp_max, self.batch_size / normalizer)
        if self.conv_matmul_support_map[conv_idx] == 1:
            g = self.cube_matmul_left(temp_g, g)
            g = self.cube_matmul_right_mul(g, temp_a, temp_max)
        else:
            # Flatten the kernel dims so the gradient is a 2-D matrix.
            flat_shape = (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3])
            g = self.reshape(g, flat_shape)
            a_half = self.cast(temp_a, mstype.float16)
            g_half = self.cast(temp_g, mstype.float16)
            grad_half = self.cast(g, mstype.float16)
            product = self.matmul(self.matmul(g_half, grad_half), a_half)
            g = self.mul(self.cast(product, mstype.float32), temp_max)
            g = self.reshape(g, grad_shape)
    return g, temp_max
def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
    """Precondition one gradient according to its layer type.

    Embedding layers store the new matrices and compute
    ``(expand(A, 1) * g) @ G`` in float16; FC layers are handled by
    ``_process_thor_fc``; LayerNorm layers by ``_process_layernorm``. Any other
    layer type returns the gradient unchanged.

    Args:
        index: position of the parameter in ``self.params``.
        matrix_a_allreduce: tuple of processed A matrices.
        matrix_g_allreduce: tuple of processed G matrices.
        g: gradient of the parameter.
        damping_step: damping value for the current step.

    Returns:
        The preconditioned gradient.
    """
    fim_idx = self.weight_fim_idx_map[index]
    kind = self.weight_layertype_idx_map[index]
    if kind == Embedding:
        a_new = matrix_a_allreduce[fim_idx]
        g_new = matrix_g_allreduce[fim_idx]
        self.assign(self.matrix_a_cov[fim_idx], a_new)
        self.assign(self.matrix_g_cov[fim_idx], g_new)
        a_expanded = self.expand(a_new, 1)
        g = self.mul(a_expanded, g)
        g_half = self.cast(g_new, mstype.float16)
        grad_half = self.cast(g, mstype.float16)
        g = self.cast(self.matmul(grad_half, g_half), mstype.float32)
    elif kind == FC:
        g = self._process_thor_fc(fim_idx, matrix_a_allreduce, matrix_g_allreduce, g)
    elif kind == LayerNorm:
        g = self._process_layernorm(damping_step, g)
    return g
def construct(self, gradients):
    """Run one THOR optimizer step on ``gradients``.

    When ``self.thor`` is set, the A/G matrices are recomputed (and reduced
    across devices when ``self.is_distributed``), stored, and used to
    precondition the gradients. Otherwise the matrices already stored on
    ``self`` are reused through ``_get_second_gradients``. Afterwards
    ``cov_step`` is incremented, weight decay and gradient clipping are
    applied, and the momentum update is executed.

    Args:
        gradients: tuple of gradients, one per entry of ``self.params``.

    Returns:
        Result of the ``hyper_map`` over the momentum update.
    """
    params = self.params
    moments = self.moments
    gradients = self.scale_grad(gradients)
    damping_step = self.gather(self.damping, self.cov_step, self.axis)
    damping_step = self.cast(damping_step, mstype.float32)
    if self.thor:
        matrix_a_allreduce = ()
        matrix_g_allreduce = ()
        matrix_a_max_allreduce = ()
        matrix_g_max_allreduce = ()
        matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
            self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                               matrix_a_max_allreduce, matrix_g_max_allreduce)
        if self.is_distributed:
            matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
            matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
            if self.conv_layer_count > 0:
                matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

        new_grads = ()
        if self.conv_layer_count > 0:
            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                temp_a = matrix_a_allreduce[thor_layer_count]
                temp_g = matrix_g_allreduce[thor_layer_count]
                # exp(-log(x)) == 1 / x: normalise A by its abs-max.
                matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                matrix_a_inv_max = self.exp(matrix_a_inv_max)
                temp_a = self.mul(temp_a, matrix_a_inv_max)
                # Same normalisation for G.
                matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                matrix_g_inv_max = self.exp(matrix_g_inv_max)
                temp_g = self.mul(temp_g, matrix_g_inv_max)
                # A was divided by a_max and G by g_max, so the product
                # temp_g @ g @ temp_a must be rescaled by a_max * g_max
                # (not g_max * g_max) to undo both normalisations.
                temp_max = self.mul(matrix_a_max_allreduce[thor_layer_count],
                                    matrix_g_max_allreduce[thor_layer_count])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                # Persist the normalised matrices and scale for later steps.
                self.assign(self.matrix_a[thor_layer_count], temp_a)
                self.assign(self.matrix_g[thor_layer_count], temp_g)
                self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                new_grads = new_grads + (g,)
            gradients = new_grads
        else:
            for i in range(len(self.params)):
                g = gradients[i]
                g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                new_grads = new_grads + (g,)
            gradients = new_grads
    else:
        new_grads = ()
        gradients = self._get_second_gradients(new_grads, damping_step, gradients)

    self.cov_step = self.cov_step + self.one
    if self.weight_decay > 0:
        gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
    gradients = clip_gradient(self.enable_clip_grad, gradients)
    lr = self.get_lr()
    success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
    return success