Source code for mindsponge.cell.basic

# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from .initializer import glorot_uniform


class Attention(nn.Cell):
    r"""
    This is an implementation of multihead attention in the paper `Attention is all you need
    <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given a query tensor with query sequence length
    and key/value tensors with value sequence length, the attention is computed as follows.

    .. math::

        Attention(query, key, value) = Concat(head_1, \dots, head_h)W^O

    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. A bias is added by default.
    If the query, key and value tensors are the same, this becomes a (gated) form of self-attention.

    Args:
        num_head(int):    The number of heads.
        hidden_size(int): The hidden size of the input.
        gating(bool):     Indicator of whether the attention is gated.
        q_data_dim(int):  The last dimension length of the query tensor.
        m_data_dim(int):  The last dimension length of the key and value tensors.
        output_dim(int):  The last dimension length of the output tensor.
        batch_size(int):  The batch size of parameters in attention, used in while control flow.
                          Default: ``None``.

    Inputs:
        - **q_data** (Tensor) - The query tensor with shape (batch_size, query_seq_length, q_data_dim),
          where query_seq_length is the query sequence length.
        - **m_data** (Tensor) - The key/value tensor with shape (batch_size, value_seq_length, m_data_dim),
          where value_seq_length is the value sequence length.
        - **attention_mask** (Tensor) - The mask for the attention matrix with shape
          (batch_size, num_head, query_seq_length, value_seq_length).
        - **index** (Tensor) - The index of the while loop, only used in case of while control flow.
          Default: ``None``.
        - **nonbatched_bias** (Tensor) - Non-batched bias for the attention matrix with shape
          (num_head, query_seq_length, value_seq_length). Default: ``None``.

    Outputs:
        Tensor, output tensor of the Attention layer with shape
        :math:`(batch_size, query_seq_length, hidden_size)`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import Attention
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
        ...                   m_data_dim=64, output_dim=64)
        >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
        >>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32)
        >>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32)
        >>> attn_out = model(q_data, m_data, attention_mask)
        >>> print(attn_out.shape)
        (32, 128, 64)
    """

    def __init__(self, num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, batch_size=None):
        super(Attention, self).__init__()
        self.q_data_dim = q_data_dim
        self.m_data_dim = m_data_dim
        self.output_dim = output_dim
        self.num_head = num_head
        self.gating = gating
        self.hidden_size = hidden_size
        self.dim_per_head = self.hidden_size // self.num_head
        self.batch_size = batch_size
        self.matmul = P.MatMul(transpose_b=True)
        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self._init_parameter()

    def construct(self, q_data, m_data, attention_mask, index=None, nonbatched_bias=None):
        '''construct'''
        if self.batch_size:
            # Parameters are stacked along batch_size; gather the slice for the current loop step.
            linear_q_weight = P.Gather()(self.linear_q_weights, index, 0)
            linear_k_weight = P.Gather()(self.linear_k_weights, index, 0)
            linear_v_weight = P.Gather()(self.linear_v_weights, index, 0)
            linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
            o_bias = P.Gather()(self.o_biases, index, 0)
            linear_gating_weight = 0
            gating_bias = 0
            if self.gating:
                linear_gating_weight = P.Gather()(self.linear_gating_weights, index, 0)
                gating_bias = P.Gather()(self.gating_biases, index, 0)
        else:
            linear_q_weight = self.linear_q_weights
            linear_k_weight = self.linear_k_weights
            linear_v_weight = self.linear_v_weights
            linear_output_weight = self.linear_output_weights
            o_bias = self.o_biases
            linear_gating_weight = 0
            gating_bias = 0
            if self.gating:
                linear_gating_weight = self.linear_gating_weights
                gating_bias = self.gating_biases

        dim_b, dim_q, dim_a = q_data.shape
        _, dim_k, dim_c = m_data.shape
        dim_h = self.num_head

        q_data = P.Reshape()(q_data, (-1, dim_a))
        m_data = P.Reshape()(m_data, (-1, dim_c))

        # Scaled query/key/value projections.
        q = self.matmul(q_data, linear_q_weight) * self.dim_per_head ** (-0.5)
        k = self.matmul(m_data, linear_k_weight)
        v = self.matmul(m_data, linear_v_weight)

        q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1))
        k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1))
        v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1))

        tmp_q = P.Transpose()(q, (0, 2, 1, 3))
        tmp_k = P.Transpose()(k, (0, 2, 1, 3))
        logits = P.Add()(self.batch_matmul_trans_b(tmp_q, tmp_k), attention_mask)

        if nonbatched_bias is not None:
            bias = P.ExpandDims()(nonbatched_bias, 0)
            logits = P.Add()(logits, bias)
        weights = self.softmax(logits)

        tmp_v = P.Transpose()(v, (0, 2, 3, 1))
        weighted_avg = P.Transpose()(self.batch_matmul_trans_b(weights, tmp_v), (0, 2, 1, 3))

        if self.gating:
            # Per-head sigmoid gate applied to the attention output.
            gating_bias = P.ExpandDims()(P.ExpandDims()(gating_bias, 0), 0)
            gate_values = P.Add()(P.Reshape()(self.matmul(q_data, linear_gating_weight),
                                              (dim_b, dim_q, dim_h, -1)),
                                  gating_bias)
            gate_values = self.sigmoid(gate_values)
            weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1))

        weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1))
        output = P.Add()(P.Reshape()(self.matmul(weighted_avg, linear_output_weight),
                                     (dim_b, dim_q, -1)),
                         P.ExpandDims()(o_bias, 0))
        return output

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size,
                                                               self.num_head * self.dim_per_head,
                                                               self.q_data_dim]), mstype.float32))
            self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size,
                                                               self.num_head * self.dim_per_head,
                                                               self.m_data_dim]), mstype.float32))
            self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size,
                                                               self.num_head * self.dim_per_head,
                                                               self.m_data_dim]), mstype.float32))
            self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size,
                                                                    self.output_dim,
                                                                    self.num_head * self.dim_per_head]),
                                                          mstype.float32))
            self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), mstype.float32))
            if self.gating:
                self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size,
                                                                        self.num_head * self.dim_per_head,
                                                                        self.q_data_dim]), mstype.float32))
                self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_head,
                                                                self.dim_per_head)), mstype.float32),
                                               name="gating_b")
        else:
            self.linear_q_weights = Parameter(Tensor(
                glorot_uniform(self.num_head * self.q_data_dim, self.dim_per_head * self.q_data_dim,
                               [self.num_head * self.dim_per_head, self.q_data_dim]),
                mstype.float32))
            self.linear_k_weights = Parameter(Tensor(
                glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim,
                               [self.num_head * self.dim_per_head, self.m_data_dim]),
                mstype.float32))
            self.linear_v_weights = Parameter(Tensor(
                glorot_uniform(self.num_head * self.m_data_dim, self.dim_per_head * self.m_data_dim,
                               [self.num_head * self.dim_per_head, self.m_data_dim]),
                mstype.float32))
            self.linear_output_weights = Parameter(Tensor(np.zeros([self.output_dim,
                                                                    self.num_head * self.dim_per_head]),
                                                          mstype.float32))
            self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32))
            if self.gating:
                self.linear_gating_weights = Parameter(Tensor(np.zeros([self.num_head * self.dim_per_head,
                                                                        self.q_data_dim]),
                                                              mstype.float32))
                self.gating_biases = Parameter(Tensor(np.ones((self.num_head, self.dim_per_head)),
                                                      mstype.float32), name="gating_b")
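
# Editorial sketch of the stacked-parameter path (an assumption about intended usage, not code
# taken from this module): when ``batch_size`` is set, every weight holds one slice per recurrent
# step and ``index`` selects the slice inside a while-style control flow. The sizes below
# (4 heads, 8 stacked slices, an all-zero additive mask) are illustrative only.
#
#     model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
#                       m_data_dim=64, output_dim=64, batch_size=8)
#     q_data = Tensor(np.ones((1, 128, 64)), mstype.float32)
#     m_data = Tensor(np.ones((1, 256, 64)), mstype.float32)
#     attention_mask = Tensor(np.zeros((1, 4, 128, 256)), mstype.float32)
#     index = Tensor(0, mstype.int32)                       # use the 0-th stacked weight slice
#     attn_out = model(q_data, m_data, attention_mask, index)   # shape (1, 128, 64)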

class GlobalAttention(nn.Cell):
    r"""
    This is an implementation of global gated self attention in the paper `Highly accurate protein
    structure prediction with AlphaFold <https://www.nature.com/articles/s41586-021-03819-2.pdf>`_.
    For this attention, the query, key and value tensors should have the same shape.

    Args:
        num_head(int):   The number of heads.
        gating(bool):    Indicator of whether the attention is gated.
        input_dim(int):  The last dimension length of the input tensor.
        output_dim(int): The last dimension length of the output tensor.
        batch_size(int): The batch size of parameters in attention, used in while control flow.
                         Default: ``None``.

    Inputs:
        - **q_data** (Tensor) - The query tensor with shape (batch_size, seq_length, input_dim),
          where seq_length is the sequence length.
        - **m_data** (Tensor) - The key/value tensor with shape (batch_size, seq_length, input_dim).
        - **q_mask** (Tensor) - A binary mask for q_data with shape (batch_size, seq_length, 1).
        - **index** (Tensor) - The index of the while loop, only used in case of while control flow.
          Default: ``None``.

    Outputs:
        Tensor, output tensor of the GlobalAttention layer with shape (batch_size, seq_length, output_dim).

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import GlobalAttention
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256)
        >>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
        >>> m_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
        >>> q_mask = Tensor(np.ones((32, 128, 1)), mstype.float32)
        >>> attn_out = model(q_data, m_data, q_mask)
        >>> print(attn_out.shape)
        (32, 128, 256)
    """

    def __init__(self, num_head, gating, input_dim, output_dim, batch_size=None):
        super(GlobalAttention, self).__init__()
        self.input_dim = input_dim
        self.num_head = num_head
        self.dim_per_head = self.input_dim // self.num_head
        self.output_dim = output_dim
        self.matmul_trans_b = P.MatMul(transpose_b=True)
        self.batch_matmul = P.BatchMatMul()
        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.matmul = P.MatMul()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.gating = gating
        self.batch_size = batch_size
        self._init_parameter()

    def construct(self, q_data, m_data, q_mask, index=None):
        '''construct'''
        if self.batch_size:
            # Parameters are stacked along batch_size; gather the slice for the current loop step.
            q_weights = P.Gather()(self.linear_q_weights, index, 0)
            k_weights = P.Gather()(self.linear_k_weights, index, 0)
            v_weights = P.Gather()(self.linear_v_weights, index, 0)
            output_weights = P.Gather()(self.linear_output_weights, index, 0)
            output_bias = P.Gather()(self.o_biases, index, 0)
            gating_weights = 0
            gating_bias = 0
            if self.gating:
                gating_weights = P.Gather()(self.linear_gating_weights, index, 0)
                gating_bias = P.Gather()(self.gating_biases, index, 0)
        else:
            q_weights = self.linear_q_weights
            k_weights = self.linear_k_weights
            v_weights = self.linear_v_weights
            output_weights = self.linear_output_weights
            output_bias = self.o_biases
            gating_weights = 0
            gating_bias = 0
            if self.gating:
                gating_weights = self.linear_gating_weights
                gating_bias = self.gating_biases

        b, _, _ = m_data.shape

        v_weights = P.BroadcastTo((b, self.dim_per_head * self.num_head, self.dim_per_head))(v_weights)
        v = self.batch_matmul(m_data, v_weights)

        # Mask-weighted average of the queries over the sequence dimension (global attention).
        mask_shape = q_mask.shape
        value_shape = q_data.shape
        broadcast_factor = 1.
        value_size = value_shape[1]
        mask_size = mask_shape[1]
        if mask_size == 1:
            broadcast_factor = broadcast_factor * value_size
        qa = P.ReduceSum()(q_mask * q_data, 1)
        qb = P.ReduceSum()(q_mask, 1) * broadcast_factor + 1e-10
        q_avg = P.RealDiv()(qa, qb)

        q = P.Reshape()(self.matmul(q_avg, q_weights),
                        (-1, self.num_head, self.dim_per_head)) * (self.dim_per_head ** (-0.5))

        k_weights = P.BroadcastTo((b, self.dim_per_head * self.num_head, self.dim_per_head))(k_weights)
        k = self.batch_matmul(m_data, k_weights)

        # Masked positions (q_mask == 0) receive a large negative logit before the softmax.
        attention_mask = 1e9 * (P.Transpose()(q_mask, (0, 2, 1)) - 1.0)
        logits = P.Add()(self.batch_matmul_trans_b(q, k), attention_mask)

        weights = self.softmax(logits)
        weighted_avg = self.batch_matmul(weights, v)

        if self.gating:
            q_data_shape = P.Shape()(q_data)
            if len(q_data_shape) != 2:
                q_data = P.Reshape()(q_data, (-1, q_data_shape[-1]))
            out_shape = q_data_shape[:-1] + (-1,)
            gate_values = P.Reshape()(self.matmul_trans_b(q_data, gating_weights) + gating_bias, out_shape)
            gate_values = P.Reshape()(self.sigmoid(gate_values),
                                      (b, -1, self.num_head, self.dim_per_head))
            weighted_avg = P.Reshape()(P.ExpandDims()(weighted_avg, 1) * gate_values,
                                       (-1, self.num_head * self.dim_per_head))
            weighted_avg_shape = P.Shape()(weighted_avg)
            if len(weighted_avg_shape) != 2:
                weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
            output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, output_weights), output_bias),
                                 (b, -1, self.output_dim))
        else:
            weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.dim_per_head))
            weighted_avg_shape = P.Shape()(weighted_avg)
            if len(weighted_avg_shape) != 2:
                weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
            out_shape = weighted_avg_shape[:-1] + (-1,)
            output = P.Reshape()(P.Add()(self.matmul_trans_b(weighted_avg, output_weights), output_bias),
                                 out_shape)
            output = P.ExpandDims()(output, -1)
        return output

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.linear_q_weights = Parameter(
                Tensor(np.zeros((self.batch_size, self.input_dim, self.num_head * self.dim_per_head)),
                       mstype.float32))
            self.linear_k_weights = Parameter(
                Tensor(np.zeros((self.batch_size, 1, self.input_dim, self.dim_per_head)), mstype.float32))
            self.linear_v_weights = Parameter(
                Tensor(np.zeros((self.batch_size, 1, self.input_dim, self.dim_per_head)), mstype.float32))
            self.linear_output_weights = Parameter(
                Tensor(np.zeros((self.batch_size, self.output_dim, self.num_head * self.dim_per_head)),
                       mstype.float32))
            self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), mstype.float32))
            if self.gating:
                self.linear_gating_weights = Parameter(
                    Tensor(np.zeros((self.batch_size, self.num_head * self.dim_per_head, self.input_dim)),
                           mstype.float32))
                self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.input_dim)),
                                                      mstype.float32))
        else:
            self.linear_q_weights = Parameter(Tensor(
                glorot_uniform(self.num_head * self.input_dim, self.dim_per_head * self.input_dim,
                               (self.input_dim, self.num_head * self.dim_per_head)),
                mstype.float32))
            self.linear_k_weights = Parameter(
                Tensor(glorot_uniform(self.input_dim, self.dim_per_head,
                                      (1, self.input_dim, self.dim_per_head)),
                       mstype.float32))
            self.linear_v_weights = Parameter(
                Tensor(glorot_uniform(self.input_dim, self.dim_per_head,
                                      (1, self.input_dim, self.dim_per_head)),
                       mstype.float32))
            self.linear_output_weights = Parameter(
                Tensor(np.zeros((self.output_dim, self.num_head * self.dim_per_head)), mstype.float32))
            self.o_biases = Parameter(Tensor(np.zeros((self.output_dim)), mstype.float32))
            if self.gating:
                self.linear_gating_weights = Parameter(
                    Tensor(np.zeros((self.num_head * self.dim_per_head, self.input_dim)), mstype.float32))
                self.gating_biases = Parameter(Tensor(np.ones((self.input_dim)), mstype.float32))
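
# Editorial note on the masking arithmetic in ``GlobalAttention.construct`` (a worked sketch of
# the formula above, with made-up numbers): positions where ``q_mask == 0`` get
# ``1e9 * (0 - 1) = -1e9`` added to their logits, so they contribute essentially nothing after
# the softmax, while positions with ``q_mask == 1`` are left unchanged. The same arithmetic in
# plain NumPy:
#
#     q_mask = np.array([[[1.], [1.], [0.]]])                   # (batch=1, seq=3, 1)
#     add_on = 1e9 * (np.transpose(q_mask, (0, 2, 1)) - 1.0)
#     # add_on == [[[0., 0., -1e9]]] -> the third key is effectively excluded from attention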