Source code for mindsponge.cell.transition

# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transition"""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore import Parameter
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from .initializer import lecun_init
from .mask import MaskedLayerNorm
from ..common.utils import _memory_reduce
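
# NOTE (illustrative): ``_memory_reduce`` is used below to trade compute for
# peak memory when ``slice_num > 0``. A minimal sketch of the pattern, under
# the assumption that it splits the batched inputs along axis 0 into
# ``slice_num`` chunks, runs the compute function on each chunk, and
# concatenates the results; see ``mindsponge.common.utils`` for the actual
# implementation. ``_memory_reduce_sketch`` is a hypothetical name used only
# for this comment:
#
#     def _memory_reduce_sketch(body, batched_inputs, nonbatched_inputs, slice_num):
#         if slice_num <= 1:
#             return body(*batched_inputs, *nonbatched_inputs)
#         chunks = [np.array_split(x, slice_num, axis=0) for x in batched_inputs]
#         outs = [body(*xs, *nonbatched_inputs) for xs in zip(*chunks)]
#         return np.concatenate(outs, axis=0)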


class Transition(nn.Cell):
    r"""
    This is a 2-layer MLP in which the intermediate layer expands the number of
    channels of the input by a factor ``num_intermediate_factor``.

    .. math::
        Transition(\mathbf{act}) = Linear(ReLU(Linear(\mathbf{act})))

    Args:
        num_intermediate_factor (float): The expansion factor of the intermediate
            output channels relative to the input channels.
        input_dim (int): The number of channels of the input.
        batch_size (int): The batch size of parameters in Transition,
            used in while control flow. Default: ``None``.
        slice_num (int): The number of slices used in the transition layer
            when memory would otherwise overflow. Default: ``0``.

    Inputs:
        - **act** (Tensor) - The input with channels equal to ``input_dim``,
          shape is :math:`(..., input\_dim)`.
        - **index** (Tensor) - The index of the while loop, only used in case of
          while control flow. Default: ``None``.
        - **mask** (Tensor) - The mask of ``act`` used when doing layer
          normalization, with shape :math:`(32, input\_dim)`. Default: ``None``.

    Outputs:
        Tensor, the float tensor of the output of the layer with shape
        :math:`(..., input\_dim)`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import Transition
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = Transition(num_intermediate_factor=4, input_dim=128)
        >>> input = Tensor(np.ones((32, 128, 128)), mstype.float32)
        >>> output = model(input)
        >>> print(output.shape)
        (32, 128, 128)
    """

    def __init__(self, num_intermediate_factor, input_dim, batch_size=None, slice_num=0):
        super(Transition, self).__init__()
        self.matmul = P.MatMul(transpose_b=True)
        self.input_dim = input_dim
        self.num_intermediate = int(input_dim * num_intermediate_factor)
        self.batch_size = batch_size
        self.slice_num = slice_num
        self.relu = nn.ReLU()
        self.idx = Tensor(0, mstype.int32)
        self.masked_layer_norm = MaskedLayerNorm()
        self._init_parameter()

    def construct(self, act, index=None, mask=None):
        '''Compute transition'''
        if self.batch_size:
            # Select the parameter set of the current loop step.
            input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0)
            input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0)
            transition1_weight = P.Gather()(self.transition1_weights, index, 0)
            transition1_bias = P.Gather()(self.transition1_biases, index, 0)
            transition2_weight = P.Gather()(self.transition2_weights, index, 0)
            transition2_bias = P.Gather()(self.transition2_biases, index, 0)
        else:
            input_layer_norm_gamma = self.input_layer_norm_gammas
            input_layer_norm_beta = self.input_layer_norm_betas
            transition1_weight = self.transition1_weights
            transition1_bias = self.transition1_biases
            transition2_weight = self.transition2_weights
            transition2_bias = self.transition2_biases
        act = self.masked_layer_norm(act, input_layer_norm_gamma, input_layer_norm_beta, mask=mask)
        batched_inputs = (act,)
        nonbatched_inputs = (transition1_weight, transition1_bias, transition2_weight, transition2_bias)
        act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
        return act

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.input_layer_norm_gammas = Parameter(
                Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32))
            self.input_layer_norm_betas = Parameter(
                Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32))
            self.transition1_weights = Parameter(
                Tensor(np.zeros((self.batch_size, self.num_intermediate, self.input_dim)), mstype.float32))
            self.transition1_biases = Parameter(
                Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32))
            self.transition2_weights = Parameter(
                Tensor(np.zeros((self.batch_size, self.input_dim, self.num_intermediate)), mstype.float32))
            self.transition2_biases = Parameter(
                Tensor(np.zeros((self.batch_size, self.input_dim)), mstype.float32))
        else:
            self.input_layer_norm_gammas = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32))
            self.input_layer_norm_betas = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32))
            self.transition1_weights = Parameter(initializer(lecun_init(self.input_dim, initializer_name='relu'),
                                                             [self.num_intermediate, self.input_dim]))
            self.transition1_biases = Parameter(Tensor(np.zeros((self.num_intermediate)), mstype.float32))
            self.transition2_weights = Parameter(
                Tensor(np.zeros((self.input_dim, self.num_intermediate)), mstype.float32))
            self.transition2_biases = Parameter(Tensor(np.zeros((self.input_dim)), mstype.float32))

    def _compute(self, act, transition1_weight, transition1_bias, transition2_weight, transition2_bias):
        '''compute transition.'''
        act_shape = P.Shape()(act)
        # Flatten leading dimensions so the MLP runs as a single 2-D matmul.
        if len(act_shape) != 2:
            act = P.Reshape()(act, (-1, act_shape[-1]))
        act = self.relu(P.BiasAdd()(self.matmul(act, transition1_weight), transition1_bias))
        act = P.BiasAdd()(self.matmul(act, transition2_weight), transition2_bias)
        act = P.Reshape()(act, act_shape)
        return act
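
# Usage sketch (illustrative, not part of the module): the docstring example
# covers the default path, so this sketch shows the batched path, where one
# parameter set per loop step is stored and selected with ``index``. The
# ``batch_size``, shapes, and index value below are assumptions chosen only
# for illustration:
#
#     import numpy as np
#     from mindspore import Tensor
#     import mindspore.common.dtype as mstype
#     from mindsponge.cell import Transition
#
#     model = Transition(num_intermediate_factor=4, input_dim=64, batch_size=8)
#     act = Tensor(np.ones((16, 64)), mstype.float32)
#     index = Tensor(0, mstype.int32)   # select parameter set 0 of the 8 stored
#     out = model(act, index=index)     # output shape stays (16, 64)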