Source code for mindsponge.cell.msa

# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MSA"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindsponge.common.utils import _memory_reduce
from .basic import Attention, GlobalAttention
from .mask import MaskedLayerNorm

class MSARowAttentionWithPairBias(nn.Cell):
    r"""
    MSA row attention. The pair activation is projected into a bias for the MSA row
    attention matrix, so that the MSA representation is updated with pair information.

    Reference:
        `Jumper et al. (2021) Suppl. Alg. 7 'MSARowAttentionWithPairBias'
        <https://www.nature.com/articles/s41586-021-03819-2>`_.

    Args:
        num_head (int):     The number of attention heads.
        key_dim (int):      The dimension of the attention hidden layer.
        gating (bool):      Whether the attention is gated.
        msa_act_dim (int):  The dimension of msa_act.
        pair_act_dim (int): The dimension of pair_act.
        batch_size (int):   The batch size of parameters in MSA row attention, used in
                            while control flow. Default: ``None``.
        slice_num (int):    The number of slices to be made to reduce memory. Default: ``0``.

    Inputs:
        - **msa_act** (Tensor) - Tensor of msa_act with shape
          :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
        - **msa_mask** (Tensor) - The mask for the MSA row attention matrix with shape
          :math:`(N_{seqs}, N_{res})` .
        - **pair_act** (Tensor) - Tensor of pair_act with shape
          :math:`(N_{res}, N_{res}, pair\_act\_dim)` . Data type is float.
        - **index** (Tensor) - The index of the while loop, only used in case of while
          control flow. Default: ``None``.
        - **norm_msa_mask** (Tensor) - The mask applied to msa_act during layer
          normalization, with shape :math:`(N_{seqs}, N_{res})` . Default: ``None``.
        - **norm_pair_mask** (Tensor) - The mask applied to pair_act during layer
          normalization, with shape :math:`(N_{res}, N_{res})` . Default: ``None``.
        - **res_idx** (Tensor) - The residue index used to perform RoPE, with shape
          :math:`(N_{res})` . Default: ``None``.

    Outputs:
        Tensor, the float tensor of msa_act of the layer with shape
        :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import MSARowAttentionWithPairBias
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = MSARowAttentionWithPairBias(num_head=4, key_dim=4, gating=True,
        ...                                     msa_act_dim=64, pair_act_dim=128,
        ...                                     batch_size=None)
        >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
        >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
        >>> pair_act = Tensor(np.ones((256, 256, 128)), mstype.float32)
        >>> index = None
        >>> msa_out = model(msa_act, msa_mask, pair_act, index)
        >>> print(msa_out.shape)
        (4, 256, 64)
    """

    def __init__(self, num_head, key_dim, gating, msa_act_dim, pair_act_dim,
                 batch_size=None, slice_num=0):
        super(MSARowAttentionWithPairBias, self).__init__()
        self.num_head = num_head
        self.batch_size = batch_size
        self.matmul = P.MatMul(transpose_b=True)
        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim,
                                  msa_act_dim, batch_size)
        self.msa_act_dim = msa_act_dim
        self.pair_act_dim = pair_act_dim
        self.slice_num = slice_num
        self.idx = Tensor(0, mstype.int32)
        self.masked_layer_norm = MaskedLayerNorm()
        self._init_parameter()

    def construct(self, msa_act, msa_mask, pair_act, index=None, norm_msa_mask=None,
                  norm_pair_mask=None, res_idx=None):
        '''construct'''
        if self.batch_size:
            query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
            query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
            feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0)
            feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0)
            feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
        else:
            query_norm_gamma = self.query_norm_gammas
            query_norm_beta = self.query_norm_betas
            feat_2d_norm_gamma = self.feat_2d_norm_gammas
            feat_2d_norm_beta = self.feat_2d_norm_betas
            feat_2d_weight = self.feat_2d_weights
        q, k, _ = pair_act.shape
        input_bias = 1e9 * (msa_mask - 1.0)
        input_bias = P.ExpandDims()(P.ExpandDims()(input_bias, 1), 2)
        msa_act = self.masked_layer_norm(msa_act, query_norm_gamma, query_norm_beta,
                                         mask=norm_msa_mask)
        pair_act = self.masked_layer_norm(pair_act, feat_2d_norm_gamma, feat_2d_norm_beta,
                                          mask=norm_pair_mask)
        pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1]))
        nonbatched_bias = P.Transpose()(
            P.Reshape()(self.matmul(pair_act, feat_2d_weight), (q, k, self.num_head)),
            (2, 0, 1))
        batched_inputs = (msa_act, input_bias)
        if res_idx is not None:
            nonbatched_inputs = (nonbatched_bias, res_idx)
        else:
            nonbatched_inputs = (index, nonbatched_bias)
        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
        return msa_act

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.query_norm_gammas = Parameter(
                Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
            self.query_norm_betas = Parameter(
                Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
            self.feat_2d_norm_gammas = Parameter(
                Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
            self.feat_2d_norm_betas = Parameter(
                Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
            self.feat_2d_weights = Parameter(
                Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32))
        else:
            self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32))
            self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32))
            self.feat_2d_norm_gammas = Parameter(Tensor(np.ones([self.pair_act_dim]), mstype.float32))
            self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.pair_act_dim]), mstype.float32))
            self.feat_2d_weights = Parameter(
                Tensor(np.random.normal(scale=1 / np.sqrt(self.pair_act_dim),
                                        size=[self.num_head, self.pair_act_dim]),
                       mstype.float32))

    def _compute(self, msa_act, mask, index, nonbatched_bias):
        """
        compute.

        Args:
            msa_act (Tensor):         Tensor of msa_act.
            mask (Tensor):            The mask for the MSA row attention matrix.
            index (Tensor):           The index of the while loop, only used in case of
                                      while control flow. Default: ``None``.
            nonbatched_bias (Tensor): Tensor of the non-batched bias matrix.

        Outputs:
            - **msa_act** (Tensor) - The float tensor of msa_act of the attention layer.
        """
        msa_act = self.attn_mod(msa_act, msa_act, mask, index, nonbatched_bias)
        return msa_act
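

# A minimal NumPy sketch (illustration only, with a hypothetical helper name and toy
# shapes) of the mask handling used by the MSA attention modules in this file: mask
# entries of 1 (valid) map to an additive bias of 0 and entries of 0 (padding) map to
# -1e9, and the two ExpandDims calls insert singleton head and query axes so the bias
# broadcasts against per-sequence attention logits of shape (num_head, num_query, num_key).
def _sketch_mask_to_additive_bias():
    msa_mask = np.array([[1., 1., 0.],
                         [1., 0., 0.]], dtype=np.float32)   # (N_seqs, N_res)
    bias = 1e9 * (msa_mask - 1.0)                           # 0 where valid, -1e9 where masked
    bias = bias[:, None, None, :]                           # ExpandDims(1) then ExpandDims(2)
    logits = np.zeros((2, 4, 3, 3), dtype=np.float32)       # (N_seqs, num_head, N_res, N_res)
    return logits + bias                                    # masked keys are suppressed before softmax
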
class MSAColumnAttention(nn.Cell):
    """
    MSA column-wise gated self-attention. The column-wise attention lets elements that
    belong to the same target residue exchange information.

    Reference:
        `Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
        <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf>`_.

    Args:
        num_head (int):    The number of attention heads.
        key_dim (int):     The dimension of the input.
        gating (bool):     Whether the attention is gated.
        msa_act_dim (int): The dimension of msa_act, the intermediate variable after MSA
                           retrieval in AlphaFold.
        batch_size (int):  The batch size of parameters in MSAColumnAttention, used in
                           while control flow. Default: ``None``.
        slice_num (int):   The number of slices to be made to reduce memory. Default: ``0``.

    Inputs:
        - **msa_act** (Tensor) - Tensor of msa_act, the intermediate variable after MSA
          retrieval in AlphaFold, with shape :math:`(N_{seqs}, N_{res}, C_m)` .
        - **msa_mask** (Tensor) - The mask for the MSAColumnAttention matrix with shape
          :math:`(N_{seqs}, N_{res})` .
        - **index** (Tensor) - The index of the while loop, only used in case of while
          control flow. Default: ``None``.

    Outputs:
        Tensor, the float tensor of msa_act of the layer with shape
        :math:`(N_{seqs}, N_{res}, C_m)` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import MSAColumnAttention
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = MSAColumnAttention(num_head=8, key_dim=256, gating=True,
        ...                            msa_act_dim=256, batch_size=1, slice_num=0)
        >>> msa_act = Tensor(np.ones((512, 256, 256)), mstype.float32)
        >>> msa_mask = Tensor(np.ones((512, 256)), mstype.float32)
        >>> index = Tensor(0, mstype.int32)
        >>> attn_out = model(msa_act, msa_mask, index)
        >>> print(attn_out.shape)
        (512, 256, 256)
    """

    def __init__(self, num_head, key_dim, gating, msa_act_dim, batch_size=None, slice_num=0):
        super(MSAColumnAttention, self).__init__()
        self.query_norm = MaskedLayerNorm()
        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim,
                                  msa_act_dim, batch_size)
        self.batch_size = batch_size
        self.slice_num = slice_num
        self.msa_act_dim = msa_act_dim
        self.idx = Tensor(0, mstype.int32)
        self._init_parameter()

    def construct(self, msa_act, msa_mask, index=None):
        '''construct'''
        if self.batch_size:
            query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
            query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
        else:
            query_norm_gamma = self.query_norm_gammas
            query_norm_beta = self.query_norm_betas
        msa_act = P.Transpose()(msa_act, (1, 0, 2))
        msa_mask = P.Transpose()(msa_mask, (1, 0))
        input_mask = 1e9 * (msa_mask - 1.)
        input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2)
        msa_act = self.query_norm(msa_act, query_norm_gamma, query_norm_beta)
        batched_inputs = (msa_act, input_mask)
        nonbatched_inputs = (index,)
        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
        msa_act = P.Transpose()(msa_act, (1, 0, 2))
        return msa_act

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.query_norm_gammas = Parameter(
                Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
            self.query_norm_betas = Parameter(
                Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
        else:
            self.query_norm_gammas = Parameter(Tensor(np.ones([self.msa_act_dim]), mstype.float32))
            self.query_norm_betas = Parameter(Tensor(np.zeros([self.msa_act_dim]), mstype.float32))

    def _compute(self, msa_act, input_mask, index):
        '''compute'''
        msa_act = self.attn_mod(msa_act, msa_act, input_mask, index)
        return msa_act
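

# Illustration only (hypothetical helper, toy shapes): MSAColumnAttention above reuses
# row-wise attention by swapping the sequence and residue axes, so each "row" handed to
# Attention is one MSA column; the result is transposed back afterwards. A NumPy check of
# that shape bookkeeping:
def _sketch_column_attention_transpose():
    n_seqs, n_res, c_m = 4, 8, 16
    msa_act = np.zeros((n_seqs, n_res, c_m), dtype=np.float32)
    msa_act_t = np.transpose(msa_act, (1, 0, 2))    # (N_res, N_seqs, C_m): residues become the batch axis
    # ... attention runs over the N_seqs axis here, one residue column at a time ...
    msa_act_out = np.transpose(msa_act_t, (1, 0, 2))
    return msa_act_out.shape                        # (4, 8, 16), same layout as the input
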
class MSAColumnGlobalAttention(nn.Cell):
    r"""
    MSA column global attention. The MSA representation is transposed between the
    sequence and residue axes, and `GlobalAttention
    <https://www.mindspore.cn/mindsponge/docs/zh-CN/master/cell/mindsponge.cell.GlobalAttention.html>`_
    then attends across the input sequences without modelling the relationship between
    residues within a sequence. Compared with MSAColumnAttention, it uses GlobalAttention
    to handle a larger number of input sequences.

    Reference:
        `Jumper et al. (2021) Suppl. Alg. 19 'MSAColumnGlobalAttention'
        <https://www.nature.com/articles/s41586-021-03819-2>`_.

    Args:
        num_head (int):    The number of attention heads.
        gating (bool):     Whether the attention is gated.
        msa_act_dim (int): The dimension of msa_act.
        batch_size (int):  The batch size of parameters in MSAColumnGlobalAttention, used
                           in while control flow. Default: ``None``.
        slice_num (int):   The number of slices to be made to reduce memory. Default: ``0``.

    Inputs:
        - **msa_act** (Tensor) - Tensor of msa_act with shape
          :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .
        - **msa_mask** (Tensor) - The mask for the msa_act matrix with shape
          :math:`(N_{seqs}, N_{res})` .
        - **index** (Tensor) - The index of the while loop, only used in case of while
          control flow. Default: ``None``.

    Outputs:
        Tensor, the float tensor of msa_act of the layer with shape
        :math:`(N_{seqs}, N_{res}, msa\_act\_dim)` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import MSAColumnGlobalAttention
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None)
        >>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
        >>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
        >>> index = None
        >>> msa_out = model(msa_act, msa_mask, index)
        >>> print(msa_out.shape)
        (4, 256, 64)
    """

    def __init__(self, num_head, gating, msa_act_dim, batch_size=None, slice_num=0):
        super(MSAColumnGlobalAttention, self).__init__()
        self.attn_mod = GlobalAttention(num_head, gating, msa_act_dim, msa_act_dim, batch_size)
        self.query_norm = MaskedLayerNorm()
        self.batch_size = batch_size
        self.slice_num = slice_num
        self.msa_act_dim = msa_act_dim
        self.idx = Tensor(0, mstype.int32)
        self._init_parameter()

    def construct(self, msa_act, msa_mask, index=None):
        '''construct'''
        if self.batch_size:
            query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
            query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
        else:
            query_norm_gamma = self.query_norm_gammas
            query_norm_beta = self.query_norm_betas
        msa_act = P.Transpose()(msa_act, (1, 0, 2))
        msa_mask = P.Transpose()(msa_mask, (1, 0))
        input_mask = 1e9 * (msa_mask - 1.)
        input_mask = P.ExpandDims()(P.ExpandDims()(input_mask, 1), 2)
        msa_act = self.query_norm(msa_act, query_norm_gamma, query_norm_beta)
        msa_mask = P.ExpandDims()(msa_mask, -1)
        batched_inputs = (msa_act, msa_mask)
        nonbatched_inputs = (index,)
        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
        msa_act = P.Transpose()(msa_act, (1, 0, 2))
        return msa_act

    def _init_parameter(self):
        '''init parameter'''
        if self.batch_size:
            self.query_norm_gammas = Parameter(
                Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))
            self.query_norm_betas = Parameter(
                Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))
        else:
            self.query_norm_gammas = Parameter(Tensor(np.ones((self.msa_act_dim)), mstype.float32))
            self.query_norm_betas = Parameter(Tensor(np.zeros((self.msa_act_dim)), mstype.float32))

    def _compute(self, msa_act, msa_mask, index):
        """
        compute.

        Args:
            msa_act (Tensor):  Tensor of msa_act.
            msa_mask (Tensor): The mask for the msa_act matrix.
            index (Tensor):    The index of the while loop, only used in case of while
                               control flow. Default: ``None``.

        Outputs:
            - **msa_act** (Tensor) - The float tensor of msa_act of the attention layer.
        """
        msa_act = self.attn_mod(msa_act, msa_act, msa_mask, index)
        return msa_act
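

# Illustration only. GlobalAttention is defined elsewhere; assuming the AlphaFold-style
# formulation (Suppl. Alg. 19), a single query per residue column is a mask-weighted
# average over sequences, which is why construct() above expands msa_mask to
# (N_res, N_seqs, 1) before passing it on. Sketched with NumPy on toy shapes:
def _sketch_masked_mean_query():
    n_res, n_seqs, c_m = 8, 4, 16
    msa_act = np.random.randn(n_res, n_seqs, c_m).astype(np.float32)  # already transposed
    msa_mask = np.ones((n_res, n_seqs, 1), dtype=np.float32)          # ExpandDims(msa_mask, -1)
    # mask-weighted mean over the sequence axis gives one query vector per residue column,
    # so the cost grows linearly with the number of sequences
    q_avg = (msa_mask * msa_act).sum(axis=1) / (msa_mask.sum(axis=1) + 1e-10)
    return q_avg.shape                                                # (8, 16)
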
class MSARowAttentionWithPairBiasContact(nn.Cell):
    '''MSA row attention with pair bias and an additional contact-derived bias.'''

    def __init__(self, num_head, key_dim, gating, msa_act_dim, pair_act_dim,
                 batch_size=None, slice_num=0):
        super(MSARowAttentionWithPairBiasContact, self).__init__()
        self.num_head = num_head
        self.batch_size = batch_size
        self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
        self.matmul = P.MatMul(transpose_b=True)
        self.attn_mod = Attention(num_head, key_dim, gating, msa_act_dim, msa_act_dim,
                                  msa_act_dim, batch_size)
        self.msa_act_dim = msa_act_dim
        self.pair_act_dim = pair_act_dim
        self.slice_num = slice_num
        self.idx = Tensor(0, mstype.int32)
        self.masked_layer_norm = MaskedLayerNorm()
        self._init_parameter()

    def construct(self, msa_act, msa_mask, pair_act, contact_act, contact_info_mask, index):
        '''construct'''
        query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
        query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
        feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0)
        feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0)
        feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
        contact_norm_gamma = P.Gather()(self.contact_norm_gammas, index, 0)
        contact_norm_beta = P.Gather()(self.contact_norm_betas, index, 0)
        contact_weight = P.Cast()(P.Gather()(self.contact_weights, index, 0), mstype.float16)

        q, k, _ = pair_act.shape
        msa_mask = P.Cast()(msa_mask, mstype.float32)
        bias = 1e9 * (msa_mask - 1.0)
        bias = P.ExpandDims()(P.ExpandDims()(bias, 1), 2)
        msa_act, _, _ = self.norm(msa_act, query_norm_gamma, query_norm_beta)
        pair_act, _, _ = self.norm(pair_act, feat_2d_norm_gamma, feat_2d_norm_beta)
        pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1]))
        pair_act_bias = P.Transpose()(
            P.Reshape()(self.matmul(pair_act, feat_2d_weight), (q, k, self.num_head)),
            (2, 0, 1))

        contact_act = P.Cast()(contact_act, mstype.float32)
        contact_act, _, _ = self.norm(contact_act, contact_norm_gamma, contact_norm_beta)
        contact_act = P.Cast()(contact_act, mstype.float16)
        contact_act = P.Reshape()(contact_act, (-1, contact_act.shape[-1]))
        contact_act_bias = P.Transpose()(
            P.Reshape()(self.matmul(contact_act, contact_weight), (q, k, self.num_head)),
            (2, 0, 1))
        contact_act_bias = contact_act_bias * contact_info_mask[None, :, :]

        nonbatched_bias = pair_act_bias + contact_act_bias
        batched_inputs = (msa_act, bias)
        nonbatched_inputs = (index, nonbatched_bias)
        msa_act = _memory_reduce(self._compute, batched_inputs, nonbatched_inputs, self.slice_num)
        return msa_act

    def _init_parameter(self):
        '''init parameter'''
        self.query_norm_gammas = Parameter(
            Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
        self.query_norm_betas = Parameter(
            Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
        self.feat_2d_norm_gammas = Parameter(
            Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
        self.feat_2d_norm_betas = Parameter(
            Tensor(np.zeros([self.batch_size, self.pair_act_dim]), mstype.float32))
        self.feat_2d_weights = Parameter(
            Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32))
        self.contact_norm_gammas = Parameter(Tensor(np.ones([self.batch_size, 32]), mstype.float32))
        self.contact_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, 32]), mstype.float32))
        self.contact_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head, 32]), mstype.float32))

    def _compute(self, msa_act, mask, index, nonbatched_bias):
        """
        compute.

        Args:
            msa_act (Tensor):         Tensor of msa_act.
            mask (Tensor):            The mask for the MSA row attention matrix.
            index (Tensor):           The index of the while loop, only used in case of
                                      while control flow. Default: ``None``.
            nonbatched_bias (Tensor): Tensor of the non-batched bias matrix.

        Outputs:
            - **msa_act** (Tensor) - The float tensor of msa_act of the attention layer.
        """
        msa_act = self.attn_mod(msa_act, msa_act, mask, index, nonbatched_bias)
        return msa_act
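

# Illustration only (hypothetical helper, toy shapes) of the distinguishing step of
# MSARowAttentionWithPairBiasContact above: the contact-derived bias is zeroed wherever
# contact information is missing and then added to the ordinary pair bias, both of shape
# (num_head, N_res, N_res).
def _sketch_contact_bias_masking():
    num_head, n_res = 4, 8
    pair_act_bias = np.random.randn(num_head, n_res, n_res).astype(np.float32)
    contact_act_bias = np.random.randn(num_head, n_res, n_res).astype(np.float32)
    contact_info_mask = np.zeros((n_res, n_res), dtype=np.float32)        # 1 where contact info exists
    contact_info_mask[:4, :4] = 1.0
    contact_act_bias = contact_act_bias * contact_info_mask[None, :, :]   # broadcast over heads
    return pair_act_bias + contact_act_bias                               # nonbatched_bias fed to Attention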