# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Layer-type codes. get_net_layertype_mask() emits one code per recognised layer, and the
# THOR optimizers store them per parameter to decide how each gradient is preconditioned.
Other = -1
Conv, FC, Embedding, LayerNorm, BatchNorm = 1, 2, 3, 4, 5
# Multitype function-graph registries; their overloads are registered by the decorated functions below.
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
apply_decay = C.MultitypeFuncGraph("apply_decay")
# AddN primitive used by _tensor_apply_decay to sum the decay term with the gradient.
op_add = P.AddN()
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Return ``gradient + weight * weight_decay`` when ``if_apply`` is set, else ``gradient`` unchanged."""
    if not if_apply:
        return gradient
    return op_add((weight * weight_decay, gradient))
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Run the momentum primitive ``opt`` on one parameter and return True once it has executed.

    ``F.depend`` ties the returned flag to the update so the update is not pruned from the graph.
    """
    return F.depend(True, opt(weight, moment, learning_rate, gradient, momentum))
# Gradient-clipping configuration read by clip_gradient():
#   IS_ENABLE_GLOBAL_NORM -> clip the whole gradient tuple by global norm instead of per tensor;
#   GRADIENT_CLIP_TYPE    -> per-tensor mode for _clip_grad (0: by value, 1: by norm);
#   GRADIENT_CLIP_VALUE   -> the clipping bound.
IS_ENABLE_GLOBAL_NORM, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE = False, 1, 1.0
# HyperMap used by clip_gradient() to apply _clip_grad to every gradient tensor.
hyper_map_op = C.HyperMap()
# Registry for the per-tensor clipping overload (_clip_grad).
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    This overload is registered on ``clip_grad``. ``clip_gradient`` maps it over the gradient tuple
    with ``hyper_map_op``, so each call receives exactly one Tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. Any other value returns `grad` unchanged.
        clip_value (float): Specifies how much to clip. For 'value' the tensor is clipped to
            [-clip_value, clip_value]; for 'norm' it is the bound passed to `nn.ClipByNorm`.
        grad (Tensor): One gradient.

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in [0, 1]:
        return grad
    # Bounds are cast to the gradient's dtype so they match grad's dtype.
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
def clip_gradient(enable_clip_grad, gradients):
    """Optionally clip a tuple of gradients.

    Args:
        enable_clip_grad (bool): When False, `gradients` is returned untouched.
        gradients (tuple[Tensor]): Gradients to clip.

    Returns:
        tuple[Tensor]. If IS_ENABLE_GLOBAL_NORM is set, the gradients are clipped together by global norm
        (GRADIENT_CLIP_VALUE). Otherwise each tensor goes through `_clip_grad` with GRADIENT_CLIP_TYPE and
        GRADIENT_CLIP_VALUE.
    """
    if not enable_clip_grad:
        return gradients
    if IS_ENABLE_GLOBAL_NORM:
        return C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
    per_tensor_clip = F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    return hyper_map_op(per_tensor_clip, gradients)
# Edge length of the (C0, C0) tiles used in device-side matrix shapes
# (see caculate_device_shape and the shapes in is_conv_matmul_support_shape).
C0 = 16
def _check_param(momentum, frequency, lr, cls_name):
    """Validate the THOR-specific hyper-parameters, in this order: momentum, frequency, learning rate.

    Type checks go through ``Validator.check_value_type`` (momentum: float, frequency: int, lr: Tensor).

    Raises:
        ValueError: If ``momentum`` is below 0.0 or ``frequency`` is below 2.
    """
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    if isinstance(momentum, float):
        if momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))

    Validator.check_value_type("frequency", frequency, [int], cls_name)
    if isinstance(frequency, int):
        if frequency < 2:
            raise ValueError("frequency should be at least 2, but got frequency {}".format(frequency))

    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)
def caculate_device_shape(matrix_dim, channel, is_a):
    """Return the tiled device shape and the side length of a square THOR matrix.

    Args:
        matrix_dim: Side length of the square matrix.
        channel: Input channel count; only consulted when ``is_a`` is True.
        is_a: True for an A matrix. If ``channel`` is smaller than ``C0``, the side length is rescaled to
            ``matrix_dim / channel * C0``. For a conv A matrix of size ``channel * kh * kw``, this matches
            padding the channels up to ``C0``.

    Returns:
        ``((n, n, C0, C0), dim)``, where ``dim`` is ``int(matrix_dim)`` after the optional rescale and
        ``n = int(matrix_dim // C0)``.
    """
    needs_channel_pad = is_a and channel // C0 == 0
    if needs_channel_pad:
        matrix_dim = (matrix_dim / channel) * C0
    num_blocks = int(matrix_dim // C0)
    return (num_blocks, num_blocks, C0, C0), int(matrix_dim)
# (G blocks, A blocks) pairs for which the conv-layer cube matmul is supported; every block is 16x16.
_CONV_MATMUL_SUPPORT_BLOCKS = (
    (4, 49), (4, 4), (4, 36), (16, 4), (4, 16),
    (8, 16), (8, 72), (32, 8), (32, 16), (8, 32),
    (16, 32), (16, 144), (64, 16), (64, 32), (16, 64),
    (32, 64), (32, 288), (128, 32), (128, 64), (32, 128),
)
_CONV_MATMUL_SUPPORT_SHAPES = tuple(
    ((g_blocks, g_blocks, 16, 16), (a_blocks, a_blocks, 16, 16))
    for g_blocks, a_blocks in _CONV_MATMUL_SUPPORT_BLOCKS
)


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """Return True if the (matrix_g_shape, matrix_a_shape) pair is a supported conv-layer matmul shape.

    Shapes are compared by equality against tuples, so list-typed shapes never match.
    """
    return (matrix_g_shape, matrix_a_shape) in _CONV_MATMUL_SUPPORT_SHAPES
def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """Compute the blocked shapes of the A and G inverse buffers.

    Each dimension is tiled into blocks of ``split_dim``:

    * an exact multiple gives ``dim // split_dim`` blocks of size ``split_dim``;
    * a dimension smaller than ``split_dim`` gives one block of size ``dim``;
    * otherwise the count is rounded up (a partial trailing block counts as a full one).

    Returns:
        ``((batch_h, batch_w, split_a, split_a), (batch_h, split_g, split_g))``, where ``batch_w``/``split_a``
        come from ``matrix_a_dim`` and ``batch_h``/``split_g`` come from ``matrix_g_dim``.
    """
    def _tile(dim):
        """Return (number of blocks, block size) for one matrix dimension."""
        if dim % split_dim == 0:
            return dim // split_dim, split_dim
        if dim < split_dim:
            return 1, dim
        return dim // split_dim + 1, split_dim

    batch_w, split_dima = _tile(matrix_a_dim)
    batch_h, split_dimg = _tile(matrix_g_dim)
    return (batch_h, batch_w, split_dima, split_dima), (batch_h, split_dimg, split_dimg)
def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """Record a plain (non-THOR) Dense/Conv2d cell as ``Other`` in ``layertype_map``.

    Nothing is recorded when the cell's weight is not trainable. Inside
    ``rpn_with_loss.rpn_convs_list.`` (case-insensitive prefix match), only entry ``0.`` is recorded.
    """
    if not subcell.weight.requires_grad:
        return
    lowered = prefix.lower()
    in_rpn_list = "rpn_with_loss.rpn_convs_list." in lowered
    if not in_rpn_list or "rpn_with_loss.rpn_convs_list.0." in lowered:
        layertype_map.append(Other)
def find_net_layertype_recur(net, layertype_map):
    """Walk ``net``'s child cells in order, appending one layer-type code per recognised layer.

    Mapping (checked in this order):

    * Conv2dThor -> Conv; DenseThor -> FC; EmbeddingThor / EmbeddingLookupThor -> Embedding;
    * nn.LayerNorm -> LayerNorm; nn.BatchNorm2d -> BatchNorm (only if ``gamma`` is trainable);
    * plain nn.Dense / nn.Conv2d -> delegated to get_layer_type_for_dense_and_conv;
    * the other listed plain nn layers -> Other;
    * any other cell -> recursed into.

    A child that compares equal to ``net`` itself is skipped.
    """
    children = net.name_cells()
    for child_name in children:
        child = children[child_name]
        child_prefix = child.param_prefix
        if child == net:
            continue
        if isinstance(child, Conv2dThor):
            layertype_map.append(Conv)
        elif isinstance(child, DenseThor):
            layertype_map.append(FC)
        elif isinstance(child, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
        elif isinstance(child, nn.LayerNorm):
            layertype_map.append(LayerNorm)
        elif isinstance(child, nn.BatchNorm2d):
            if child.gamma.requires_grad:
                layertype_map.append(BatchNorm)
        elif isinstance(child, (nn.Dense, nn.Conv2d)):
            get_layer_type_for_dense_and_conv(child, child_prefix, layertype_map)
        elif isinstance(child, (nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                                nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            layertype_map.append(Other)
        else:
            find_net_layertype_recur(child, layertype_map)
def get_net_layertype_mask(net):
    """Return the list of layer-type codes of ``net``, in traversal order (see find_net_layertype_recur)."""
    mask = []
    find_net_layertype_recur(net, mask)
    return mask
def get_layer_counter(layer_type, layer_counter, params, idx):
    """Return ``layer_counter``, advanced by one when ``params[idx]`` looks like the last parameter of its layer.

    Name-based rules (case-insensitive):

    * Conv / FC: advance on a "bias" parameter, or on a weight whose next parameter is not a bias.
    * LayerNorm / BatchNorm: advance on the "beta" parameter.
    * Other types: advance on "bias"; on "weight" only if the next parameter is not a bias; otherwise always.

    A weight that is the final entry of ``params`` never advances the counter.
    """
    current = params[idx].name.lower()

    def next_is_not_bias():
        # The next parameter's name is only inspected when it exists.
        return idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower()

    if layer_type in [Conv, FC]:
        advance = "bias" in current or next_is_not_bias()
    elif layer_type in [LayerNorm, BatchNorm]:
        advance = "beta" in current
    elif "bias" in current:
        advance = True
    elif "weight" in current:
        advance = next_is_not_bias()
    else:
        advance = True
    return layer_counter + 1 if advance else layer_counter
def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by second-order algorithm--THOR.

    Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:

    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            A_i = a_i{a_i}^T \\
            G_i = D_{s_i}{ D_{s_i}}^T \\
            m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
            w_i = w_i - \alpha * m_i \\
        \end{array}

    :math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer,
    :math:`a_{i-1}` represents the input of i-th layer, which is the activations of previous layer,
    :math:`\beta` represents momentum, :math:`I` represents the identity matrix,
    :math:`\overline A` represents the transpose of matrix A,
    :math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
    :math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'.

    Note:
        The optimizer is built from all trainable parameters of `net`; parameter groups are not supported.
        If `weight_decay` is positive, weight decay is applied to every parameter for which `decay_filter`
        returns True. The default `decay_filter` returns True for every parameter.

    Args:
        net (Cell): The training network. It is converted to its THOR form before the optimizer is built.
        learning_rate (Tensor): A value for the learning rate.
        damping (Tensor): A value for the damping. The GPU implementation reads the entry at the current
            step index each step, so this is typically a per-step schedule.
        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.
        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: 1.0.
        batch_size (int): The size of a batch. Default: 32
        use_nesterov (bool): Enable Nesterov momentum. Only used by the GPU implementation; it is not
            passed to the Ascend implementation. Default: False.
        decay_filter (function): A function to determine which layers the weight decay applied to. And it
            only works when the weight_decay > 0. Default: lambda x: x.name not in []
        split_indices (list): Set allreduce fusion strategy by A/G layer indices. Only works when distributed
            computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
            to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
            is 27~53. Default: None
        enable_clip_grad (bool): Whether to clip the gradients. Default: False
        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps, and other steps
            will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `frequency` is not an int.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool (checked by the GPU implementation).
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn import thor
        >>> from mindspore import Model
        >>> from mindspore import FixedLossScaleManager
        >>> from mindspore.train.callback import LossMonitor
        >>> from mindspore.train.train_thor import ConvertModelUtils
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>>
        >>> net = Net()
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
        >>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)
        >>> loss_cb = LossMonitor()
        >>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)
    """
    context.set_context(max_call_depth=10000)
    # The conversion result is not captured, so `net` is modified in place and then handed to the optimizer.
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        # NOTE: use_nesterov is not forwarded to the Ascend implementation.
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)
class ThorGpu(Optimizer):
    """
    ThorGpu: the THOR optimizer returned by `thor` when the device target is not "Ascend".

    For every Conv/FC/Embedding weight, the gradient is preconditioned with damped inverses of the layer's
    second-order statistics. These statistics are the network parameters whose names contain 'matrix_a' /
    'matrix_g'. LayerNorm gradients get a diagonal rescaling, and all gradients are then applied with
    ApplyMomentum.
    """
    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        # Only trainable parameters of `net` are optimized.
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        # One zero-initialized momentum buffer per optimized parameter.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        # Second-order statistics and normalizers live as parameters of the (converted) network and are
        # selected by name substring.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        # 1 / loss_scale**2 and batch_size**2 both multiply matrix_g in _get_ainv_ginv_list.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        # construct gathers damping[cov_step] each step.
        self.damping = damping
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        # When True, construct recomputes (and stores) the inverse matrices; when False it reuses
        # self.matrix_a / self.matrix_g. Presumably toggled externally every `frequency` steps - TODO confirm.
        self.thor = True
        # Inverse buffers, one per THOR layer, plus the shape each A inverse is broadcast to.
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        # Per-parameter maps filled by _process_matrix_init_and_weight_idx_map:
        #   weight_fim_idx_map[i]       -> THOR layer index (into matrix_a/matrix_g) or -1
        #   weight_conv_idx_map[i]      -> conv layer index (into a_normalizer/g_normalizer) or -1
        #   weight_layertype_idx_map[i] -> layer-type code deciding how gradient i is processed
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        # decay_flags[i] decides whether apply_decay adds weight decay to parameter i.
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        # self.split_dim is set by _define_gpu_operator, which runs above.
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """Create the primitives and constants used by the GPU THOR computation."""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.GatherV2()
        self.one = Tensor(1, mstype.int32)
        # Multiplier for the inverses of non-conv layers (conv layers use sqrt(1 / a_normalizer) instead).
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        # Step counter: indexes self.damping in construct and is incremented once per construct call.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        # Block size shared by the blocked Cholesky op, caculate_matmul_shape and UpdateThorGradient.
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        # Element-wise reciprocal (diagonal embedding A matrix, LayerNorm scaling).
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """Set up all-reduce of the A/G inverses when running in any mode other than STAND_ALONE."""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            # Default: a single fusion group ending at the last A/G layer.
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            # NOTE(review): the G reducer is built from matrix_a_cov, not matrix_g_cov. This presumably only relies
            # on both tuples having the same length - confirm before changing.
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map

        Allocates zero-filled A/G inverse buffers for every Conv/FC/Embedding weight and fills the
        per-parameter index maps. `layer_counter` walks get_net_layertype_mask(net) in step with the
        parameter list, advanced by get_layer_counter.
        """
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                # Conv: A covers in_channels * kernel_h * kernel_w inputs.
                if layer_type == Conv:
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                # Embedding: A is kept as a vector of length vocab_size, G as a dense
                # (embedding_size, embedding_size) matrix.
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                # Biases and non-THOR parameters: no inverse matrices; only LayerNorm keeps its own type.
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """get matrixA inverse list and matrix G inverse list

        For each Conv/FC/Embedding weight, in parameter order, appends the damped inverse of its A and
        G statistics to `matrix_a_allreduce` / `matrix_g_allreduce` and returns both tuples.
        """
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                # Order the statistic reads after this layer's gradient is available.
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    # Conv layers: damping is divided by the layer's normalizers and the result is scaled
                    # by sqrt(1 / a_normalizer).
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                # The square root of the damping is what gets added to the diagonal.
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                # Rescale G by 1/loss_scale**2 and batch_size**2 before damping.
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    # Embedding A is a vector: average over the batch, damp, and invert element-wise.
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    # NOTE(review): this uses MatMul without transpose, whereas the non-embedding path uses
                    # BatchMatMul(transpose_a=True) on the CholeskyTrsm output - confirm this is intended.
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    matrix_a = matrix_a + damping_a * a_eye
                    # Blocked CholeskyTrsm followed by X^T X (transpose_a=True); presumably this yields the
                    # block-wise inverse - TODO confirm against the CholeskyTrsm op definition.
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    # Broadcast A to the (batch_h, batch_w, split, split) layout from caculate_matmul_shape.
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm

        Diagonal preconditioning: gradient / (gradient**2 / batch_size + sqrt(damping_step)).
        """
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """Restore a flattened conv gradient to `g_shape`; other gradients are returned as-is."""
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        # Per-step pipeline: unscale -> precondition (THOR) -> weight decay -> clip -> momentum update.
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            # Recompute the inverses, optionally all-reduce them, store them, and precondition.
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    # Flatten to 2-D (out_channels, -1) for UpdateThorGradient.
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    # Keep the fresh inverses for the steps where self.thor is False.
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    # Row-wise scale by the A vector, then right-multiply by G.
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            # Reuse the inverses stored by the last THOR step.
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads
        self.cov_step = self.cov_step + self.one
        # Weight decay is applied after preconditioning, only to parameters whose decay flag is set.
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
class ThorAscend(Optimizer):
"""ThorAscend"""
def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
params = filter(lambda x: x.requires_grad, net.get_parameters())
super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param(momentum, frequency, learning_rate, self.__class__.__name__)
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
self._define_ascend_operator()
self.C0 = 16
self.device_shape_pad_flag = ()
self.diag_block_dim = 128
self.matrix_a = ()
self.matrix_g = ()
self.thor_layer_count = 0
self.conv_layer_count = 0
self.weight_conv_idx_map = ()
self.weight_fim_idx_map = ()
self.weight_layertype_idx_map = ()
self.a_split_pad_dim_map = ()
self.g_split_pad_dim_map = ()
self.conv_matmul_support_map = ()
self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
self._process_matrix_init_and_weight_idx_map(self.net)
self.matrix_a = ParameterTuple(self.matrix_a)
self.matrix_g = ParameterTuple(self.matrix_g)
self.matrix_max_inv = ()
for i in range(len(self.matrix_a)):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.thor = True
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.damping = damping
self.batch_size = Tensor(batch_size, mstype.float32)
self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
self.enable_clip_grad = enable_clip_grad
self.frequency = frequency
self._define_ascend_reducer(split_indices)
def get_frequency(self):
"""get thor frequency"""
return self.frequency
def _get_pad_dim(self, matrix_dim):
"""get diag split pad dim """
split_pad_dim = 0
if matrix_dim == 64:
return split_pad_dim
res = matrix_dim % self.diag_block_dim
if res != 0:
split_pad_dim = self.diag_block_dim - res
return split_pad_dim
def _define_ascend_operator(self):
"""define ascend operator"""
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.gather = P.GatherV2()
self.assign = P.Assign()
self.cast = P.Cast()
self.eye = P.Eye()
self.concat = P.Concat(0)
self.cholesky = P.CusCholeskyTrsm()
self.vector_matmul = P.CusBatchMatMul()
self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.matrix_combine = P.CusMatrixCombine()
self.slice = P.Slice()
self.expand = P.ExpandDims()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.square = P.Square()
self.inv = P.Inv()
self.matmul = P.MatMul()
self.axis = 0
self.one = Tensor(1, mstype.int32)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
def _define_ascend_reducer(self, split_indices):
"""define ascend reducer"""
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
if self.is_distributed:
mean = _get_gradients_mean()
degree = _get_device_num()
if not split_indices:
self.split_indices = [len(self.matrix_a_cov) - 1]
else:
self.split_indices = split_indices
if self.conv_layer_count > 0:
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
def _get_weight_idx_map(self, layer_type, idx, weight_shape):
"""for Ascend, get weight idx map"""
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
if layer_type == Embedding:
a_pad_dim = 0
g_pad_dim = 0
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
else:
out_channels = weight_shape[0]
g_pad_dim = self._get_pad_dim(out_channels)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
matrix_a_dim = weight_shape[1]
if layer_type == Conv:
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
a_pad_dim = self._get_pad_dim(matrix_a_dim)
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
if layer_type == LayerNorm:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
else:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
def _get_fc_matrix(self, weight_shape):
    """For Ascend, allocate the cached matrix_a / matrix_g inverse parameters of one FC layer.

    Nothing is allocated unless the network contains conv layers (``self.conv_layer_count > 0``).
    For an FC layer with 1001 output channels, both parameters are zero-filled float16 tensors of
    the hard-coded shapes [128, 128, 16, 16] and [63, 63, 16, 16]. Any other FC layer gets float16
    identity matrices of size in_channels and out_channels.

    Args:
        weight_shape: shape of the FC weight, read as (out_channels, in_channels).
    """
    out_channels, in_channels = weight_shape[0], weight_shape[1]
    if self.conv_layer_count <= 0:
        return

    if out_channels == 1001:
        # Fixed blocked layout for the 1001-class head; sizes are constants, not derived from the shape.
        init_a = np.zeros([128, 128, 16, 16])
        init_g = np.zeros([63, 63, 16, 16])
    else:
        init_a = np.eye(in_channels)
        init_g = np.eye(out_channels)

    layer_suffix = str(self.thor_layer_count)
    fc_matrix_a = Parameter(Tensor(init_a.astype(np.float16)),
                            name='matrix_a_inv_' + layer_suffix,
                            requires_grad=False)
    fc_matrix_g = Parameter(Tensor(init_g.astype(np.float16)),
                            name="matrix_g_inv_" + layer_suffix,
                            requires_grad=False)
    self.matrix_a += (fc_matrix_a,)
    self.matrix_g += (fc_matrix_g,)
def _process_matrix_init_and_weight_idx_map(self, net):
    """For Ascend, allocate the cached inverse matrices and build the per-weight index maps.

    Walks ``self.params`` alongside the layer-type mask returned by ``get_net_layertype_mask(net)``.

    Each Conv weight (name without "bias") gets:
    - a matrix_a / matrix_g inverse parameter, shaped by ``caculate_device_shape`` when
      ``is_conv_matmul_support_shape`` accepts the device shapes, otherwise plain float16
      identity matrices;
    - an entry in ``conv_matmul_support_map`` (1 or 0);
    - an entry in ``device_shape_pad_flag``.

    FC weights are delegated to ``_get_fc_matrix``. Every parameter is then registered through
    ``_get_weight_idx_map``.

    Args:
        net: network whose layer types are read by ``get_net_layertype_mask``.
    """
    layer_types = get_net_layertype_mask(net)
    layer_cursor = 0
    for idx, weight in enumerate(self.params):
        layer_type = layer_types[layer_cursor]
        weight_shape = self.shape(weight)
        param_name = weight.name.lower()
        is_weight = "bias" not in param_name
        if layer_type == Conv and is_weight:
            out_channels, in_channels = weight_shape[0], weight_shape[1]
            matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
            matrix_g_dim = out_channels
            a_device_shape, a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
            # NOTE(review): the G device shape is computed with in_channels as the channel
            # argument; confirm that is intended.
            g_device_shape, g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
            a_name = 'matrix_a_inv_' + str(self.thor_layer_count)
            g_name = "matrix_g_inv_" + str(self.thor_layer_count)
            if is_conv_matmul_support_shape(a_device_shape, g_device_shape):
                # Identity matrices laid out in the device-specific blocked shape.
                a_init = np.reshape(np.identity(a_device_dim).astype(np.float16), a_device_shape)
                g_init = np.reshape(np.identity(g_device_dim).astype(np.float16), g_device_shape)
                support_flag = 1
            else:
                a_init = np.eye(matrix_a_dim).astype(np.float16)
                g_init = np.eye(matrix_g_dim).astype(np.float16)
                support_flag = 0
            matrix_a_inv = Parameter(Tensor(a_init), name=a_name, requires_grad=False)
            matrix_g_inv = Parameter(Tensor(g_init), name=g_name, requires_grad=False)
            self.conv_matmul_support_map += (support_flag,)
            self.matrix_a += (matrix_a_inv,)
            self.matrix_g += (matrix_g_inv,)
            # True when the device layout enlarged the A dimension (padding is needed later).
            self.device_shape_pad_flag += (bool(matrix_a_dim != a_device_dim),)
        elif layer_type == FC and is_weight:
            self._get_fc_matrix(weight_shape)
        self._get_weight_idx_map(layer_type, idx, weight_shape)
        # bert.cls1.output_bias: not a network layer, only a trainable param
        if "output_bias" not in param_name:
            layer_cursor = get_layer_counter(layer_type, layer_cursor, self.params, idx)
def _process_batch_matmul(self, input_matrix):
    """Apply a batched matmul kernel to the pair ``(input_matrix, input_matrix)``.

    ``self.vector_matmul`` is used when the leading dimension of ``input_matrix`` appears in
    ``self.batch_matmul_support_list``; otherwise the generic ``self.tbe_batch_matmul`` is used.

    Args:
        input_matrix: tensor passed as both operands of the kernel.

    Returns:
        The kernel's output.
    """
    leading_dim = self.shape(input_matrix)[0]
    if leading_dim not in self.batch_matmul_support_list:
        result = self.tbe_batch_matmul(input_matrix, input_matrix)
    else:
        result = self.vector_matmul(input_matrix, input_matrix)
    return result
def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
    """Grow a square matrix by ``pad_dim`` rows and columns ahead of the Cholesky kernel.

    When ``pad_dim > 0``, ``input_matrix`` gets ``pad_dim`` zero columns on the right. A
    ``pad_dim``-row block is built with ``matrix_shape0`` zero columns followed by a float32
    identity, and the two are joined with ``self.concat``. Presumably the join is along axis 0,
    so the result is ``[[A, 0], [0, I]]`` — TODO confirm the axis of ``self.concat``. With
    ``pad_dim <= 0`` the input is returned as-is.

    Args:
        pad_dim: number of rows/columns to add.
        input_matrix: the square matrix to grow.
        matrix_shape0: number of rows of ``input_matrix``.
    """
    if pad_dim > 0:
        identity_rows = self.eye(pad_dim, pad_dim, mstype.float32)
        identity_rows = P.Pad(((0, 0), (matrix_shape0, 0)))(identity_rows)
        widened = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
        input_matrix = self.concat((widened, identity_rows))
    return input_matrix
def _get_abs_max(self, matrix_inv, origin_dim):
    """Return the largest absolute entry of an inverse factor together with its combined form.

    If the leading dim of ``matrix_inv`` is in ``self.abs_max_support_list``, the maximum comes
    from the fused custom kernels (``CusFusedAbsMax1`` then ``self.fused_abs_max2``). These are
    presumably equivalent to the generic path — TODO confirm. Otherwise the matrix is first
    passed through ``self.matrix_combine`` and reduced with ``Abs`` + ``ReduceMax``.

    Args:
        matrix_inv: inverse factor as produced by the batched matmul kernels.
        origin_dim: unpadded matrix dimension, forwarded to ``CusFusedAbsMax1``.

    Returns:
        ``(abs_max, matrix_combine(matrix_inv))``.
    """
    batched_shape = self.shape(matrix_inv)
    if batched_shape[0] not in self.abs_max_support_list:
        combined = self.matrix_combine(matrix_inv)
        abs_max = P.ReduceMax(keep_dims=False)(P.Abs()(combined))
    else:
        partial_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
        abs_max = self.fused_abs_max2(partial_max)
        combined = self.matrix_combine(matrix_inv)
    return abs_max, combined
def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                      matrix_a_max_allreduce, matrix_g_max_allreduce):
    """Compute the inverse factors of one FC layer and append them to the running tuples.

    The A covariance is damped with ``sqrt(damping_step)`` times the identity, padded by
    ``_process_cholesky_pad`` and passed through ``self.cholesky`` and ``_process_batch_matmul``.

    A layer with 2 output channels keeps an identity G inverse. Every other layer:
    - scales G by ``loss_scale`` and ``batch_size_scale`` and damps it the same way;
    - inverts G like A;
    - when the network has conv layers, appends the abs-max of both inverses;
    - slices away the padding.
    A 2048-in / 1001-out layer is finally rearranged into the blocked (n/16, 16, n/16, 16) layout.

    Args:
        index: position of the FC weight in ``self.params``.
        damping_step: damping value for the current step.
        gradients: tuple of gradients; ``gradients[index]`` is only used to order the computation.
        matrix_a_allreduce: tuple the A inverse is appended to.
        matrix_g_allreduce: tuple the G inverse is appended to.
        matrix_a_max_allreduce: tuple the A abs-max is appended to (conv networks only).
        matrix_g_max_allreduce: tuple the G abs-max is appended to (conv networks only).

    Returns:
        The four tuples, extended for this layer.
    """
    thor_layer_count = self.weight_fim_idx_map[index]
    g = gradients[index]
    matrix_a = self.matrix_a_cov[thor_layer_count]
    matrix_g = self.matrix_g_cov[thor_layer_count]
    # F.depend makes the factor reads wait for this layer's gradient.
    matrix_a = F.depend(matrix_a, g)
    matrix_g = F.depend(matrix_g, g)
    a_shape = self.shape(matrix_a)
    a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
    g_shape = self.shape(matrix_g)
    g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
    damping = self.sqrt(damping_step)
    matrix_a = matrix_a + damping * a_eye
    a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
    matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
    matrix_a_inv = self.cholesky(matrix_a)
    matrix_a_inv = self._process_batch_matmul(matrix_a_inv)

    weight_shape = self.shape(self.params[index])
    out_channels = weight_shape[0]
    in_channels = weight_shape[1]
    if out_channels == 2:
        # NOTE(review): this branch neither slices off a_pad_dim nor appends abs-max values.
        # With conv layers present, the *_max_allreduce tuples would then fall out of step with
        # thor_layer_count. Confirm 2-output FC layers only occur in conv-free networks.
        matrix_a_inv = self.matrix_combine(matrix_a_inv)
        matrix_g_inv = g_eye
    else:
        matrix_g = self.mul(matrix_g, self.loss_scale)
        matrix_g = self.mul(matrix_g, self.batch_size_scale)
        matrix_g = matrix_g + damping * g_eye
        g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
        matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
        matrix_g_inv = self.cholesky(matrix_g)
        matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
        if self.conv_layer_count > 0:
            a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
            g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
            a_max = F.depend(a_max, g)
            g_max = F.depend(g_max, g)
            matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
            matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
        else:
            matrix_a_inv = self.matrix_combine(matrix_a_inv)
            matrix_g_inv = self.matrix_combine(matrix_g_inv)

        # Drop the rows/columns added by _process_cholesky_pad.
        if a_pad_dim > 0:
            matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
        if g_pad_dim > 0:
            matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
        matrix_a_inv_shape = self.shape(matrix_a_inv)
        matrix_g_combine_shape = self.shape(matrix_g_inv)
        if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
            # Reshape dims must be integers: use floor division (exact here, 2048 and 1008
            # are multiples of 16) instead of true division, which yields floats.
            a_blocks = matrix_a_inv_shape[0] // 16
            matrix_a_inv = self.reshape(matrix_a_inv, (a_blocks, 16, a_blocks, 16))
            matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
            # Pad 1001 up to 1008 so G can be split into 16x16 blocks.
            matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
            matrix_g_inv_shape = self.shape(matrix_g_inv)
            g_blocks = matrix_g_inv_shape[0] // 16
            matrix_g_inv = self.reshape(matrix_g_inv, (g_blocks, 16, g_blocks, 16))
            matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

    matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
    matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
    return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
    """Zero-pad the channel axes of a conv A inverse up to ``self.C0`` when the layer needs it.

    Only applies when ``self.device_shape_pad_flag[conv_layer_count]`` is set. In that case the
    matrix is reshaped to (weight_shape[2] * weight_shape[3], in_channels, same, in_channels),
    and both in_channels axes are padded with zeros up to ``self.C0``. Otherwise
    ``matrix_a_inv`` is returned unchanged.

    Args:
        conv_layer_count: index of the conv layer among all conv layers.
        weight_shape: shape of the conv weight; entry 1 is in_channels.
        matrix_a_inv: A inverse of the layer.
    """
    if self.device_shape_pad_flag[conv_layer_count]:
        channels = weight_shape[1]
        spatial = weight_shape[2] * weight_shape[3]
        missing = self.C0 - channels
        blocked = self.reshape(matrix_a_inv, (spatial, channels, spatial, channels))
        matrix_a_inv = P.Pad(((0, 0), (0, missing), (0, 0), (0, missing)))(blocked)
    return matrix_a_inv
def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                  matrix_a_max_allreduce, matrix_g_max_allreduce):
    """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list

    Iterates over ``self.params``. Every Conv, FC and Embedding weight appends its inverse
    factors to the running tuples; other parameters contribute nothing. Conv layers also append
    their abs-max values. FC layers do so only on the conv-network path of
    ``_get_fc_ainv_ginv``. The per-layer work is delegated to ``_compute_conv_layer_inverses``,
    ``_get_fc_ainv_ginv`` and ``_compute_embedding_layer_inverses``.

    Returns:
        (matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce)
    """
    for i in range(len(self.params)):
        layer_type = self.weight_layertype_idx_map[i]
        if layer_type == Conv:
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._compute_conv_layer_inverses(i, damping_step, gradients, matrix_a_allreduce,
                                                  matrix_g_allreduce, matrix_a_max_allreduce,
                                                  matrix_g_max_allreduce)
        elif layer_type == FC:
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                       matrix_a_max_allreduce, matrix_g_max_allreduce)
        elif layer_type == Embedding:
            matrix_a_allreduce, matrix_g_allreduce = \
                self._compute_embedding_layer_inverses(i, damping_step, gradients, matrix_a_allreduce,
                                                       matrix_g_allreduce)
    return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

def _compute_conv_layer_inverses(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                 matrix_a_max_allreduce, matrix_g_max_allreduce):
    """Compute the inverse factors and their abs-max values for one Conv weight.

    A is damped with ``sqrt(damping_step * batch_size / a_normalizer)`` times the identity. G is
    scaled by ``loss_scale`` and ``batch_size_scale``, then damped with
    ``sqrt(damping_step * batch_size / g_normalizer)``.

    Each factor is:
    - padded by ``_process_cholesky_pad``;
    - passed through ``self.cholesky`` and ``_process_batch_matmul``;
    - reduced to its abs-max by ``_get_abs_max``;
    - sliced back to its unpadded size.

    When the layer's device matmul path is supported (``conv_matmul_support_map == 1``), both
    inverses are reshaped using the dims of the cached ``self.matrix_a`` / ``self.matrix_g``
    parameters and transposed with (2, 0, 1, 3).

    Returns:
        The four input tuples, each extended with this layer's entry.
    """
    thor_layer_count = self.weight_fim_idx_map[index]
    conv_layer_count = self.weight_conv_idx_map[index]
    weight_shape = self.shape(self.params[index])
    out_channels = weight_shape[0]
    g = gradients[index]
    matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
    matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
    matrix_a = self.matrix_a_cov[thor_layer_count]
    matrix_g = self.matrix_g_cov[thor_layer_count]
    # F.depend makes the factor/normalizer reads wait for this layer's gradient.
    matrix_a = F.depend(matrix_a, g)
    matrix_g = F.depend(matrix_g, g)
    a_shape = self.shape(matrix_a)
    a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
    g_shape = self.shape(matrix_g)
    g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
    a_normalizer = self.a_normalizer[conv_layer_count]
    g_normalizer = self.g_normalizer[conv_layer_count]
    a_normalizer = F.depend(a_normalizer, g)
    g_normalizer = F.depend(g_normalizer, g)
    damping_a = self.mul(damping_step, self.batch_size / a_normalizer)
    damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
    damping_a = self.sqrt(damping_a)
    matrix_a = matrix_a + damping_a * a_eye
    a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
    matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
    matrix_a_inv = self.cholesky(matrix_a)
    matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
    a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)
    damping_g = self.sqrt(damping_g)
    matrix_g = self.mul(matrix_g, self.loss_scale)
    matrix_g = self.mul(matrix_g, self.batch_size_scale)
    matrix_g = matrix_g + damping_g * g_eye
    g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
    matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
    matrix_g_inv = self.cholesky(matrix_g)
    matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
    g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
    # Drop the rows/columns added by _process_cholesky_pad.
    if a_pad_dim > 0:
        matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
    if g_pad_dim > 0:
        matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
    if matmul_support_flag == 1:
        matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
        matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
        matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                      matrix_a_inv_shape[1], matrix_a_inv_shape[3])
        matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
        matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
        matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
        matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
                                      matrix_g_inv_shape[1], matrix_g_inv_shape[3])
        matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
        matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
    a_max = F.depend(a_max, g)
    g_max = F.depend(g_max, g)
    matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
    matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
    matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
    matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
    return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

def _compute_embedding_layer_inverses(self, index, damping_step, gradients, matrix_a_allreduce,
                                      matrix_g_allreduce):
    """Compute the inverse factors for one Embedding weight.

    The A factor is handled element-wise: it is divided by ``batch_size``, damped by
    ``sqrt(damping_step)`` and inverted with ``self.inv``. The G factor is scaled by
    ``loss_scale`` and ``batch_size_scale`` and damped with ``sqrt(damping_step)`` times the
    identity. It is then passed through ``self.cholesky``, ``_process_batch_matmul`` and
    ``matrix_combine``.

    Returns:
        (matrix_a_allreduce, matrix_g_allreduce), each extended with this layer's entry.
    """
    thor_layer_count = self.weight_fim_idx_map[index]
    g = gradients[index]
    matrix_a = self.matrix_a_cov[thor_layer_count]
    matrix_g = self.matrix_g_cov[thor_layer_count]
    matrix_a = F.depend(matrix_a, g)
    matrix_g = F.depend(matrix_g, g)
    g_shape = self.shape(matrix_g)
    g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
    damping = self.sqrt(damping_step)
    a_eye = P.OnesLike()(matrix_a)
    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
    matrix_a = matrix_a + damping * a_eye
    matrix_a_inv = self.inv(matrix_a)
    matrix_g = self.mul(matrix_g, self.loss_scale)
    matrix_g = self.mul(matrix_g, self.batch_size_scale)
    matrix_g = matrix_g + damping * g_eye
    matrix_g_inv = self.cholesky(matrix_g)
    matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
    matrix_g_inv = self.matrix_combine(matrix_g_inv)
    matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
    matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
    return matrix_a_allreduce, matrix_g_allreduce
def _process_layernorm(self, damping_step, gradient):
    """Precondition a LayerNorm gradient element-wise.

    Computes ``gradient / (gradient**2 / batch_size + sqrt(damping_step))`` using the
    ``square``, ``mul``, ``inv`` and ``sqrt`` ops held on ``self``.

    Args:
        damping_step: damping value for the current step.
        gradient: gradient of the LayerNorm parameter.

    Returns:
        The preconditioned gradient.
    """
    batch = self.cast(self.batch_size, mstype.float32)
    curvature = self.mul(self.square(gradient), 1.0 / batch)
    curvature = curvature + self.sqrt(damping_step)
    return self.mul(self.inv(curvature), gradient)
def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g):
    """Cache an FC layer's reduced inverse factors and precondition its gradient.

    The factors at ``thor_layer_count`` are written into ``matrix_a_cov`` / ``matrix_g_cov``.
    The gradient is then computed as ``G_inv @ g @ A_inv`` in float16 and cast back to float32.

    Returns:
        The preconditioned gradient.
    """
    inv_a = matrix_a_allreduce[thor_layer_count]
    inv_g = matrix_g_allreduce[thor_layer_count]
    # Keep the factors for steps that reuse them without recomputation.
    self.assign(self.matrix_a_cov[thor_layer_count], inv_a)
    self.assign(self.matrix_g_cov[thor_layer_count], inv_g)
    inv_a_fp16 = self.cast(inv_a, mstype.float16)
    inv_g_fp16 = self.cast(inv_g, mstype.float16)
    grad_fp16 = self.cast(g, mstype.float16)
    preconditioned = self.matmul(self.matmul(inv_g_fp16, grad_fp16), inv_a_fp16)
    return self.cast(preconditioned, mstype.float32)
def _get_second_gradients_one(self, params_len, gradients, new_grads):
    """Precondition gradients with the cached inverses in a network that has conv layers.

    Every FC / Conv gradient is multiplied on the left by the cached ``matrix_g``, on the right
    by the cached ``matrix_a``, and scaled by the cached ``matrix_max_inv``. This uses the fused
    cube kernels where available and a float16 matmul path otherwise. All other gradients pass
    through unchanged.

    Args:
        params_len: number of parameters to process.
        gradients: tuple of gradients, one per parameter.
        new_grads: tuple the processed gradients are appended to.

    Returns:
        ``new_grads`` extended with one gradient per parameter.
    """
    for i in range(params_len):
        grad = gradients[i]
        layer_kind = self.weight_layertype_idx_map[i]
        fim_idx = self.weight_fim_idx_map[i]
        inv_a = self.matrix_a[fim_idx]
        inv_g = self.matrix_g[fim_idx]
        scale = self.matrix_max_inv[fim_idx]
        grad_shape = self.shape(grad)
        if layer_kind == Conv:
            conv_idx = self.weight_conv_idx_map[i]
            if self.conv_matmul_support_map[conv_idx] == 1:
                grad = self.cube_matmul_left(inv_g, grad)
                grad = self.cube_matmul_right_mul(grad, inv_a, scale)
            else:
                # Flatten the conv gradient to (out_channels, rest) for a plain 2-D matmul.
                flat_shape = (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3])
                grad = self.cast(self.reshape(grad, flat_shape), mstype.float16)
                grad = self.matmul(self.cast(inv_g, mstype.float16), grad)
                grad = self.matmul(grad, self.cast(inv_a, mstype.float16))
                grad = self.mul(self.cast(grad, mstype.float32), scale)
                grad = self.reshape(grad, grad_shape)
        elif layer_kind == FC:
            if grad_shape[0] == 1001:
                grad = self.cube_matmul_left_fc(inv_g, grad)
                grad = self.cube_matmul_right_fc(grad, inv_a, scale)
            else:
                grad = self.matmul(self.cast(inv_g, mstype.float16), self.cast(grad, mstype.float16))
                grad = self.matmul(grad, self.cast(inv_a, mstype.float16))
                grad = self.mul(self.cast(grad, mstype.float32), scale)
        new_grads = new_grads + (grad,)
    return new_grads
def _get_second_gradients(self, new_grads, damping_step, gradients):
    """Precondition gradients on steps that reuse previously stored factors.

    With conv layers present, the work is delegated to ``_get_second_gradients_one``. Otherwise:
    - Embedding gradients are multiplied by the stored ``matrix_a_cov`` entry (expanded on
      axis 1), then on the right by the stored ``matrix_g_cov`` entry, in float16.
    - FC gradients become ``matrix_g_cov @ g @ matrix_a_cov`` in float16.
    - LayerNorm gradients go through ``_process_layernorm``.
    - Any other gradient is passed through.
    Float16 results are cast back to float32.

    Args:
        new_grads: tuple the processed gradients are appended to.
        damping_step: damping value for the current step.
        gradients: tuple of gradients, one per parameter.

    Returns:
        ``new_grads`` extended with one gradient per parameter.
    """
    params_len = len(self.params)
    if self.conv_layer_count > 0:
        new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
    else:
        for i in range(params_len):
            grad = gradients[i]
            layer_kind = self.weight_layertype_idx_map[i]
            fim_idx = self.weight_fim_idx_map[i]
            if layer_kind == FC:
                inv_a = self.cast(self.matrix_a_cov[fim_idx], mstype.float16)
                inv_g = self.cast(self.matrix_g_cov[fim_idx], mstype.float16)
                grad = self.matmul(inv_g, self.cast(grad, mstype.float16))
                grad = self.cast(self.matmul(grad, inv_a), mstype.float32)
            elif layer_kind == Embedding:
                row_scale = self.expand(self.matrix_a_cov[fim_idx], 1)
                inv_g = self.cast(self.matrix_g_cov[fim_idx], mstype.float16)
                grad = self.cast(self.mul(row_scale, grad), mstype.float16)
                grad = self.cast(self.matmul(grad, inv_g), mstype.float32)
            elif layer_kind == LayerNorm:
                grad = self._process_layernorm(damping_step, grad)
            new_grads = new_grads + (grad,)
    return new_grads
def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
    """Precondition one gradient with already-normalized inverse factors.

    Conv weights: ``temp_max`` is first multiplied by ``batch_size / a_normalizer``. The
    gradient then goes through the ``cube_matmul_*`` kernels when the layer supports them, or
    otherwise through a flattened float16 ``temp_g @ g @ temp_a`` path followed by the scale.

    FC weights: gradients with 1001 rows use the fused ``cube_matmul_*_fc`` kernels. Other FC
    gradients are computed as ``temp_g @ g @ temp_a`` in float16 and scaled by ``temp_max``.

    Any other layer type returns ``g`` unchanged.

    Args:
        index: position of the parameter in ``self.params``.
        temp_a: normalized A inverse factor of the layer.
        temp_g: normalized G inverse factor of the layer.
        g: gradient of the parameter.
        temp_max: scale applied after the matmuls.

    Returns:
        (preconditioned gradient, scale that was actually used).
    """
    layer_kind = self.weight_layertype_idx_map[index]
    grad_shape = self.shape(g)
    if layer_kind == Conv:
        conv_idx = self.weight_conv_idx_map[index]
        normalizer = F.depend(self.a_normalizer[conv_idx], g)
        temp_max = self.mul(temp_max, self.batch_size / normalizer)
        if self.conv_matmul_support_map[conv_idx] == 1:
            g = self.cube_matmul_left(temp_g, g)
            g = self.cube_matmul_right_mul(g, temp_a, temp_max)
        else:
            # Flatten the conv gradient to (out_channels, rest) for a plain 2-D matmul.
            flat_shape = (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3])
            g = self.cast(self.reshape(g, flat_shape), mstype.float16)
            g = self.matmul(self.cast(temp_g, mstype.float16), g)
            g = self.matmul(g, self.cast(temp_a, mstype.float16))
            g = self.mul(self.cast(g, mstype.float32), temp_max)
            g = self.reshape(g, grad_shape)
    elif layer_kind == FC:
        if grad_shape[0] == 1001:
            g = self.cube_matmul_left_fc(temp_g, g)
            g = self.cube_matmul_right_fc(g, temp_a, temp_max)
        else:
            g = self.matmul(self.cast(temp_g, mstype.float16), self.cast(g, mstype.float16))
            g = self.matmul(g, self.cast(temp_a, mstype.float16))
            g = self.mul(self.cast(g, mstype.float32), temp_max)
    return g, temp_max
def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
    """Precondition one gradient of a conv-free network on a THOR step.

    - FC: delegated to ``_process_thor_fc``.
    - LayerNorm: delegated to ``_process_layernorm``.
    - Embedding: the reduced factors are cached into ``matrix_a_cov`` / ``matrix_g_cov``. ``g``
      is multiplied by the A factor expanded on axis 1, then on the right by the G factor,
      in float16.
    - Any other layer type returns ``g`` unchanged.

    Returns:
        The preconditioned gradient (float32 for the Embedding path).
    """
    fim_idx = self.weight_fim_idx_map[index]
    layer_kind = self.weight_layertype_idx_map[index]
    if layer_kind == FC:
        g = self._process_thor_fc(fim_idx, matrix_a_allreduce, matrix_g_allreduce, g)
    elif layer_kind == LayerNorm:
        g = self._process_layernorm(damping_step, g)
    elif layer_kind == Embedding:
        inv_a = matrix_a_allreduce[fim_idx]
        inv_g = matrix_g_allreduce[fim_idx]
        self.assign(self.matrix_a_cov[fim_idx], inv_a)
        self.assign(self.matrix_g_cov[fim_idx], inv_g)
        scaled = self.cast(self.mul(self.expand(inv_a, 1), g), mstype.float16)
        g = self.cast(self.matmul(scaled, self.cast(inv_g, mstype.float16)), mstype.float32)
    return g
def construct(self, gradients):
    """Apply one optimizer step: precondition the gradients, then run the momentum update.

    On THOR steps (``self.thor``):
    - inverse factors are computed for every Conv / FC / Embedding weight;
    - they are optionally all-reduced when ``self.is_distributed``;
    - they are used to precondition the gradients;
    - they are cached, in ``matrix_a`` / ``matrix_g`` / ``matrix_max_inv`` for networks with
      conv layers, or in ``matrix_a_cov`` / ``matrix_g_cov`` for conv-free networks.
    Other steps reuse the cached values via ``_get_second_gradients``. ``cov_step`` is advanced
    on every step. Weight decay (if > 0) and gradient clipping run before the momentum update.

    Args:
        gradients: tuple of gradients, one per parameter in ``self.params``.

    Returns:
        The result of the momentum update over all parameters.
    """
    params = self.params
    moments = self.moments
    gradients = self.scale_grad(gradients)
    damping_step = self.gather(self.damping, self.cov_step, self.axis)
    damping_step = self.cast(damping_step, mstype.float32)
    if self.thor:
        matrix_a_allreduce = ()
        matrix_g_allreduce = ()
        matrix_a_max_allreduce = ()
        matrix_g_max_allreduce = ()
        matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
            self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                               matrix_a_max_allreduce, matrix_g_max_allreduce)
        if self.is_distributed:
            matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
            matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
            if self.conv_layer_count > 0:
                matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

        new_grads = ()
        if self.conv_layer_count > 0:
            for i in range(len(self.params)):
                g = gradients[i]
                layer_type = self.weight_layertype_idx_map[i]
                # Only Conv/FC weights own a slot in the allreduce/cache tuples. Other parameters
                # map to index -1; processing them would re-assign the last THOR layer's
                # matrix_max_inv with a value lacking the Conv a_normalizer scaling. They are
                # therefore passed through unchanged.
                if layer_type == Conv or layer_type == FC:
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
                    # Normalize each inverse by 1 / abs-max, computed as exp(-log(max)).
                    matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    # NOTE(review): temp_a is normalized by 1/a_max and temp_g by 1/g_max, but the
                    # rescale below is g_max * g_max rather than a_max * g_max. Kept unchanged;
                    # confirm whether this is intentional.
                    temp_max = self.mul(matrix_g_max_allreduce[thor_layer_count],
                                        matrix_g_max_allreduce[thor_layer_count])
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                new_grads = new_grads + (g,)
            gradients = new_grads
        else:
            for i in range(len(self.params)):
                g = gradients[i]
                g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                new_grads = new_grads + (g,)
            gradients = new_grads
    else:
        new_grads = ()
        gradients = self._get_second_gradients(new_grads, damping_step, gradients)

    self.cov_step = self.cov_step + self.one
    if self.weight_decay > 0:
        gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
    gradients = clip_gradient(self.enable_clip_grad, gradients)
    lr = self.get_lr()
    success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
    return success