mindsponge.cell.MSAColumnGlobalAttention
- class mindsponge.cell.MSAColumnGlobalAttention(num_head, gating, msa_act_dim, batch_size=None, slice_num=0)[source]
MSA column global attention. Transposes the MSA representation between the sequence axis and the residue axis, then uses GlobalAttention <https://www.mindspore.cn/mindsponge/docs/zh-CN/master/cell/mindsponge.cell.GlobalAttention.html> to perform attention between the input sequences without modeling the relationship between residues within a sequence. Compared with MSAColumnAttention, it uses GlobalAttention to handle longer input sequences.
- Parameters
num_head (int) – The number of the attention heads.
gating (bool) – Indicator of whether the attention is gated.
msa_act_dim (int) – The dimension of the msa_act.
batch_size (int) – The batch size of parameters in MSAColumnGlobalAttention, used in while control flow. Default: None.
slice_num (int) – The number of slices to be made to reduce memory. Default: 0.
- Inputs:
msa_act (Tensor) - Tensor of msa_act with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .
msa_mask (Tensor) - The mask for the msa_act matrix with shape \((N_{seqs}, N_{res})\) .
index (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.
- Outputs:
Tensor, the float tensor of the msa_act of the layer with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> from mindsponge.cell import MSAColumnGlobalAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None)
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> index = None
>>> msa_out = model(msa_act, msa_mask, index)
>>> print(msa_out.shape)
(4, 256, 64)
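The following is a minimal sketch of the while-control-flow variant. It assumes (not stated on this page) that when batch_size is set the cell stores that many copies of its parameters and that index is an integer Tensor selecting which copy to use on the current iteration; the input shapes follow the example above.

>>> import numpy as np
>>> from mindsponge.cell import MSAColumnGlobalAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> # Assumption: batch_size=1 stacks one copy of the parameters; index picks it.
>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=1)
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> index = Tensor(0, mstype.int32)
>>> msa_out = model(msa_act, msa_mask, index)
>>> print(msa_out.shape)
(4, 256, 64)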