mindsponge.cell.MSAColumnGlobalAttention

class mindsponge.cell.MSAColumnGlobalAttention(num_head, gating, msa_act_dim, batch_size=None, slice_num=0)[source]

MSA column global attention. It transposes the MSA representation along the sequence and residue axes, then applies GlobalAttention <https://www.mindspore.cn/mindsponge/docs/zh-CN/master/cell/mindsponge.cell.GlobalAttention.html> across the input sequences without modelling the relationship between residues within a sequence. Compared with MSAColumnAttention, it uses GlobalAttention and can therefore handle longer input sequences.
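The core pattern can be illustrated with a minimal NumPy sketch. This is illustrative only: the single-head, ungated simplification and the function name column_global_attention are assumptions, not the MindSPONGE implementation. The idea is to transpose the sequence and residue axes, build one masked-mean query per residue column, and attend over all sequences in that column.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def column_global_attention(msa_act, msa_mask):
    # msa_act:  (N_seqs, N_res, C); msa_mask: (N_seqs, N_res)
    # 1. Transpose sequence and residue axes so attention runs over sequences.
    act = msa_act.transpose(1, 0, 2)            # (N_res, N_seqs, C)
    mask = msa_mask.transpose(1, 0)[..., None]  # (N_res, N_seqs, 1)
    # 2. Global attention: one masked-mean query per residue column.
    q = (act * mask).sum(axis=1) / (mask.sum(axis=1) + 1e-10)   # (N_res, C)
    k = v = act                                 # keys/values are the activations
    logits = np.einsum('rc,rsc->rs', q, k) / np.sqrt(act.shape[-1])
    logits = np.where(mask[..., 0] > 0, logits, -1e9)
    weights = softmax(logits, axis=-1)          # (N_res, N_seqs)
    out = np.einsum('rs,rsc->rc', weights, v)   # (N_res, C)
    # 3. Broadcast back over sequences and undo the transpose; in this
    #    simplification every sequence in a column gets the same update.
    out = np.broadcast_to(out[:, None, :], act.shape)
    return out.transpose(1, 0, 2)               # (N_seqs, N_res, C)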

Reference:

Jumper et al. (2021) Suppl. Alg. 19 'MSAColumnGlobalAttention'.

Parameters
  • num_head (int) – The number of attention heads.

  • gating (bool) – Whether the attention is gated.

  • msa_act_dim (int) – The dimension of the msa_act.

  • batch_size (int) – The batch size of parameters in MSAColumnGlobalAttention, used when the layer is called inside a while control flow. Default: None.

  • slice_num (int) – The number of slices to be made to reduce memory. Default: 0.

Inputs:
  • msa_act (Tensor) - Tensor of msa_act with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

  • msa_mask (Tensor) - The mask for the msa_act matrix with shape \((N_{seqs}, N_{res})\) .

  • index (Tensor) - The index of the while loop, only used in the case of while control flow. Default: None.

Outputs:

Tensor, the float tensor of the output msa_act of the layer with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import MSAColumnGlobalAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None)
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> index = None
>>> msa_out = model(msa_act, msa_mask, index)
>>> print(msa_out.shape)
(4, 256, 64)
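
A hedged variant of the example above: here slice_num is assumed only to trade peak memory for extra computation, so the output shape is expected to be unchanged.

>>> model_sliced = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64,
...                                         batch_size=None, slice_num=2)
>>> msa_out = model_sliced(msa_act, msa_mask, index)
>>> print(msa_out.shape)
(4, 256, 64)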