# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dim_reduce"""
from __future__ import absolute_import
import math
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
__all__ = ["DimReduce"]

_scale_grad = C.MultitypeFuncGraph("_scale_grad")


@_scale_grad.register("Tensor", "Tensor")
def _scale_grad_process(scale, grad):
    """Cast the gradient to float32 and divide out the loss scale."""
    grad = F.cast(grad, mstype.float32)
    grad = P.Div()(grad, scale)
    return grad


_save_weight = C.MultitypeFuncGraph("_save_weight")


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(parameter, new_parameter):
    """Assign new_parameter to parameter in place."""
    return P.Assign()(parameter, new_parameter)


_pca_projection = C.MultitypeFuncGraph("_pca_projection")


@_pca_projection.register("Tensor", "Tensor")
def _pca_projection_process(pca_mat, grad):
    """Project a flattened gradient into the reduced k-dimensional space."""
    grad_k = P.MatMul()(pca_mat, F.reshape(grad, (-1, 1)))
    return grad_k


_pca_back_projection = C.MultitypeFuncGraph("_pca_back_projection")


@_pca_back_projection.register("Tensor", "Tensor", "Tensor")
def _pca_back_projection_process(grad_k, pca_mat, grad):
    """Map a reduced-space gradient back to the original parameter shape."""
    grad_proj = P.MatMul()(F.transpose(pca_mat, (1, 0)), grad_k)
    grad_proj_reshape = F.reshape(grad_proj, F.shape(grad))
    return grad_proj_reshape


_update_grad_res_momentum = C.MultitypeFuncGraph("_update_grad_res_momentum")


@_update_grad_res_momentum.register("Float32", "Float32", "Tensor", "Tensor", "Tensor")
def _update_grad_res_momentum_process(gamma, alpha, grad_res_momentum, grad, grad_proj):
    """Accumulate the residual gradient (grad - grad_proj) with momentum gamma."""
    grad_res_momentum_new = gamma * grad_res_momentum + grad - grad_proj
    P.Assign()(grad_res_momentum, grad_res_momentum_new)
    res = alpha * grad_res_momentum_new
    return res


_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")


@_get_delta_weight.register("Tensor", "Tensor", "Tensor")
def _get_delta_weight_process(rho, dn, grad_res_momentum):
    """Combine the residual-momentum term with the projected search direction."""
    delta_weight = grad_res_momentum - rho * dn
    return delta_weight
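

# A hedged NumPy sketch of the projection pair above: _pca_projection maps a
# flattened gradient into the reduced space, and _pca_back_projection maps it
# back to the weight's shape. The sizes and the random basis are illustrative
# assumptions only; a real run uses the PCA matrix passed to DimReduce, and
# this helper is never called by the training path.
def _pca_round_trip_sketch():
    """Illustrative only: round-trip a gradient through a PCA-style basis."""
    n, k = 8, 3
    rng = np.random.default_rng(0)
    pca_mat = rng.standard_normal((k, n)).astype(np.float32)  # stands in for a PCA basis
    grad = rng.standard_normal((2, 4)).astype(np.float32)     # a toy weight-shaped gradient
    grad_k = pca_mat @ grad.reshape(-1, 1)                    # like _pca_projection
    grad_proj = (pca_mat.T @ grad_k).reshape(grad.shape)      # like _pca_back_projection
    return grad_k, grad_proj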


class DimReduce(Cell):
    r"""
    Dimension-reduce training is a novel algorithm for accelerating the convergence
    of deep-learning models: each optimization step is solved in a small subspace of
    the weight space obtained by PCA.

    .. math::

        \begin{align}
        grad\_k &= pca\_mat \cdot grad\\
        dk &= - bk \cdot grad\_k\\
        sk &= rho ^ m \cdot dk\\
        delta\_loss &= sigma \cdot grad\_k.T \cdot sk
        \end{align}

    Here:

    - pca_mat (array): Shape :math:`(k, n)`, where :math:`k` is the local slice of
      `n_components` held by this rank and :math:`n` is the total size of the weights.
    - bk (array): Shape :math:`(k, k)`, the symmetric positive-definite matrix
      maintained by the quasi-Newton method.

    We need to find the smallest :math:`m` that satisfies

    .. math::

        new\_loss < old\_loss + delta\_loss

    Then delta_grad is computed to update the weights of the model:

    .. math::

        \begin{align}
        grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\
        new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\
        delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk
        \end{align}

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        weight (Tuple(Parameter)): Tuple of parameters.
        pca_mat_local (numpy.ndarray): The local PCA matrix for this rank, of shape
            :math:`(k, n)`; :math:`k` is the local slice of `n_components` and
            :math:`n` is the size of the flattened weights.
        n_components (int): Total number of principal components kept by PCA.
        rho (float): Backtracking coefficient of the line search.
        gamma (float): Momentum coefficient for the residual gradient.
        alpha (float): Scale applied to the residual-momentum correction.
        sigma (float): Coefficient of the sufficient-decrease condition in the line search.
        rank (int): Rank number.
        rank_size (int): Rank size.

    Inputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
        - **loss_scale** (Tensor) - Loss scale; gradients are divided by it before projection.
        - **weight** (Tuple(Tensor)) - Tuple of parameters.
        - **weight_clone** (Tuple(Tensor)) - Clone of **weight**.
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
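
    Examples:
        A minimal single-device sketch. Here ``net`` and ``opt`` stand for a prepared
        training network and optimizer, the random matrix stands in for a real PCA
        basis, and every hyper-parameter value is illustrative rather than recommended:

        >>> import numpy as np
        >>> weights = net.trainable_params()
        >>> n = sum(w.asnumpy().size for w in weights)
        >>> pca_mat_local = np.random.randn(16, n).astype(np.float32)
        >>> dim_reduce = DimReduce(net, opt, weights, pca_mat_local, n_components=16,
        ...                        rho=0.55, gamma=0.9, alpha=0.001, sigma=0.4,
        ...                        rank=0, rank_size=1)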
"""

    def __init__(self, network, optimizer, weight, pca_mat_local, n_components, rho, gamma, alpha, sigma, rank,
                 rank_size):
        super(DimReduce, self).__init__()
        self.network = network
        self.optimizer = optimizer
        self.rank = rank
        self.rank_size = rank_size
        self.gamma = gamma
        self.alpha = alpha
        self.sigma = sigma
        self.float_type = mstype.float32

        self._set_rho_list(rho)
        self._set_local_pca_mat(pca_mat_local, n_components, weight)
        self._set_init_parameter(weight)

        self.hyper_map = C.HyperMap()
        self.concat = P.Concat()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.add = P.Add()

    def construct(self, loss, old_grad, loss_scale, weight, weight_clone, *inputs):
        gk, old_loss, gk_local = self._generate_gk(weight, loss, old_grad, loss_scale)

        # Back up the quasi-Newton state so it can be restored if the line search fails.
        _save_weight(self.gk_last_back, self.gk_last)
        _save_weight(self.bk_back, self.bk)

        dk = self._apply_quasi_newton_update(gk)
        if self.dk_pad_flag:
            dk_pad = self.concat((dk, self.dk_pad_part))
        else:
            dk_pad = dk
        dk_local = dk_pad[self.start_index: self.end_index, :]

        dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
        grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)
        dn = self.dn_init if self.rank_size > 1 else dn_local
        grad_proj = self.grad_proj_init if self.rank_size > 1 else grad_proj_local
        if self.rank_size > 1:
            # Sum the per-rank back-projections over all devices.
            for broadcast in self.broadcast_list:
                dn_part = broadcast(dn_local)
                dn = self.hyper_map(self.add, dn, dn_part)
                grad_proj_part = broadcast(grad_proj_local)
                grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)

        rho, find = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
        if not find:
            _save_weight(self.gk_last, self.gk_last_back)
            _save_weight(self.bk, self.bk_back)

        clone = self._res_loss(old_grad, grad_proj, weight, weight_clone, rho, dn)
        return F.depend(loss, clone)

    def _set_rho_list(self, rho):
        """set rho list info."""
        self.max_search_time = 2
        self.rho_list = []
        for i in range(self.max_search_time):
            self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))
        self.rho_list.append(Tensor(0, dtype=self.float_type))

    def _set_local_pca_mat(self, pca_mat_local, n_components, parameter_tuple):
        """set pca info."""
        self.n_components = n_components
        # Each rank owns a contiguous slice of the principal components; use true
        # division so math.ceil rounds up when n_components does not divide evenly.
        local_dim = math.ceil(self.n_components / self.rank_size)
        self.start_index = self.rank * local_dim
        self.end_index = (self.rank + 1) * local_dim

        start = 0
        self.pca_list_local = ()
        for param in parameter_tuple:
            size = np.shape(param.asnumpy().reshape((-1, 1)))[0]
            self.pca_list_local += (Tensor(pca_mat_local[:, start:start + size], dtype=self.float_type),)
            start += size

        self.dk_pad_flag = False
        pad_num = self.rank_size * local_dim - self.n_components
        if pad_num:
            self.dk_pad_flag = True
            self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)

        if self.rank_size > 1:
            self.broadcast_list = []
            for i in range(self.rank_size):
                broadcast = P.Broadcast(i)
                self.broadcast_list.append(broadcast)
            self.allreduce = P.AllReduce()
            self.allgather = P.AllGather()

    def _set_init_parameter(self, parameter_tuple):
        """init parameters."""
        self.true_flag = Tensor(True)
        self.false_flag = Tensor(False)
        self.epsilon = np.power(10.0, -20)
        self.gk_last = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="gk_last")
        self.gk_last_init = Parameter(Tensor(False), name="gk_last_init")
        self.bk = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk")
        self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
        self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
        self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")
        self.gk_last_back = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type),
                                      name="gk_last_back")
        self.bk_back = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk_back")
        self.grad_proj_init = ParameterTuple(parameter_tuple).clone(prefix="grad_proj_init", init="zeros")
        self.dn_init = ParameterTuple(parameter_tuple).clone(prefix="dn_init", init="zeros")

    def _res_loss(self, old_grad, grad_proj, weight, weight_clone, rho, dn):
        """update weights with the residual-momentum correction."""
        update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
                                     self.grad_res_momentum, old_grad, grad_proj)
        delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
        update = self.optimizer(delta_weight)
        weight = F.depend(weight, update)
        clone = self.hyper_map(_save_weight, weight_clone, weight)
        return clone

    def _generate_gk(self, weight, loss, old_grad, loss_scale):
        """generate gk"""
        weight = F.depend(weight, loss)
        old_grad = F.depend(old_grad, weight)
        old_grad = self.hyper_map(F.partial(_scale_grad, loss_scale), old_grad)
        # Average the loss across ranks so every device line-searches against the same value.
        old_loss = self.allreduce(loss) / self.rank_size if self.rank_size > 1 else loss
        gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
        gk_local = F.addn(gk_local)
        gk_pad = self.allgather(gk_local) if self.rank_size > 1 else gk_local
        gk_pad = F.reshape(gk_pad, (-1, 1))
        gk = gk_pad[0:self.n_components, :]
        return gk, old_loss, gk_local

    def _line_search(self, gk, dk, dn, old_loss, weight, weight_clone, *inputs):
        """line search rho."""
        res = self.rho_list[-1]
        find = self.false_flag
        for i in range(self.max_search_time):
            find = self._find_rho(gk, dk, dn, old_loss, weight, weight_clone, self.rho_list[i], *inputs)
            if find:
                res = self.rho_list[i]
                break
        return res, find

    def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
        """search rho."""
        res = self.false_flag
        # Take a trial step of -rho * dn and evaluate the new loss.
        sn = self.hyper_map(F.partial(self.mul, -1 * rho), dn)
        sn = F.depend(sn, old_loss)
        update = self.optimizer(sn)
        new_loss = F.depend(self.network(*inputs), update)
        if self.rank_size > 1:
            new_loss = self.allreduce(new_loss) / self.rank_size
        # Sufficient-decrease test: accept rho if the loss drops below the predicted bound.
        old_loss_delta = old_loss + self.sigma * rho * F.squeeze(self.matmul(F.transpose(gk, (1, 0)), dk))
        if old_loss_delta > new_loss:
            _save_weight(self.sk, rho * dk)
            res = self.true_flag
        # Restore the weights from the clone before the next trial.
        weight_clone = F.depend(weight_clone, old_loss_delta)
        restore = self.hyper_map(_save_weight, weight, weight_clone)
        res = F.depend(res, restore)
        return res

    def _apply_quasi_newton_update(self, gk):
        """apply quasi_newton update."""
        if self.gk_last_init:
            # BFGS-style update of the inverse-Hessian approximation bk.
            yk = gk - self.gk_last
            g = self.matmul(F.transpose(yk, (1, 0)), self.sk)
            g = F.squeeze(g)
            if g > self.epsilon:
                pk = 1. / g
                t1 = self.eye - self.matmul(pk * yk, F.transpose(self.sk, (1, 0)))
                new_bk = self.matmul(self.matmul(F.transpose(t1, (1, 0)), self.bk), t1) + \
                         self.matmul(pk * self.sk, F.transpose(self.sk, (1, 0)))
                _save_weight(self.bk, new_bk)
        else:
            _save_weight(self.gk_last_init, self.true_flag)
        _save_weight(self.gk_last, gk)
        dk = -1 * self.matmul(self.bk, gk)
        return dk
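

# A hedged, self-contained NumPy sketch of the BFGS-style update performed by
# _apply_quasi_newton_update above. It is illustration only and is never called
# by the training path; the 3-dimensional reduced space and the random vectors
# are assumptions made purely for demonstration.
def _quasi_newton_update_sketch():
    """Illustrative only: mirror _apply_quasi_newton_update with NumPy."""
    k = 3
    rng = np.random.default_rng(0)
    bk = np.eye(k)                         # inverse-Hessian approximation
    gk_last = rng.standard_normal((k, 1))  # previous projected gradient
    gk = rng.standard_normal((k, 1))       # current projected gradient
    sk = rng.standard_normal((k, 1))       # last accepted step, sk = rho^m * dk
    yk = gk - gk_last
    g = (yk.T @ sk).item()
    if g > 1e-20:                          # curvature guard, as in the cell
        pk = 1.0 / g
        t1 = np.eye(k) - pk * (yk @ sk.T)
        bk = t1.T @ bk @ t1 + pk * (sk @ sk.T)
    return -bk @ gk                        # descent direction dk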