
class mindsponge.cell.MSARowAttentionWithPairBias(num_head, key_dim, gating, msa_act_dim, pair_act_dim, batch_size=None, slice_num=0)

MSA row attention. Information from the pair activations (pair_act) is used as a bias on the attention matrix of MSA row attention, so that the MSA representation is updated with pair information.


Jumper et al. (2021) Suppl. Alg. 7 'MSARowAttentionWithPairBias'.
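The computation described above can be sketched in NumPy. This is a minimal, framework-free sketch of Suppl. Alg. 7 under stated assumptions, not the MindSponge implementation: the hypothetical weight dictionary `w` and its keys, the omission of LayerNorm scale/offset and linear biases, and the `1e9` mask constant are illustrative choices. It shows the two additive biases on the attention logits: the per-head pair bias projected from pair_act, and a large negative bias on masked key positions.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the channel axis; learnable scale/offset omitted (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, w,
                                     num_head, key_dim, gating=True):
    """msa_act: (S, R, C_m), msa_mask: (S, R), pair_act: (R, R, C_z)."""
    m = layer_norm(msa_act)
    z = layer_norm(pair_act)
    S, R, _ = m.shape
    H, D = num_head, key_dim
    # Per-head queries (pre-scaled by 1/sqrt(D)), keys and values.
    q = (m @ w["q"]).reshape(S, R, H, D) / np.sqrt(D)
    k = (m @ w["k"]).reshape(S, R, H, D)
    v = (m @ w["v"]).reshape(S, R, H, D)
    # Pair bias: one scalar per head for every residue pair (i, j) -> (H, R, R).
    pair_bias = np.transpose(z @ w["b"], (2, 0, 1))
    # Mask bias over key positions j: 0 where valid, -1e9 where masked -> (S, 1, 1, R).
    mask_bias = 1e9 * (msa_mask[:, None, None, :] - 1.0)
    logits = np.einsum("sihd,sjhd->shij", q, k) + pair_bias[None] + mask_bias
    weights = softmax(logits, axis=-1)  # attend over residues j within each row
    out = np.einsum("shij,sjhd->sihd", weights, v)
    if gating:
        # Sigmoid gate computed from the normalized MSA input.
        gate = 1.0 / (1.0 + np.exp(-(m @ w["g"])))
        out = out * gate.reshape(S, R, H, D)
    return out.reshape(S, R, H * D) @ w["o"]


rng = np.random.default_rng(0)
S, R, C_m, C_z, H, D = 4, 16, 8, 6, 2, 4
w = {name: rng.normal(size=(C_m, H * D)) for name in ("q", "k", "v", "g")}
w["b"] = rng.normal(size=(C_z, H))
w["o"] = rng.normal(size=(H * D, C_m))
msa = rng.normal(size=(S, R, C_m))
mask = np.ones((S, R))
pair = rng.normal(size=(R, R, C_z))
out = msa_row_attention_with_pair_bias(msa, mask, pair, w, H, D)
print(out.shape)  # (4, 16, 8)
```

As in the class above, the output keeps the shape of msa_act, and a masked key position receives zero attention weight.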

Parameters:

  • num_head (int) – The number of attention heads.

  • key_dim (int) – The dimension of the attention hidden layer.

  • gating (bool) – Whether the attention output is gated.

  • msa_act_dim (int) – The dimension (last axis) of msa_act.

  • pair_act_dim (int) – The dimension (last axis) of pair_act.

  • batch_size (int) – The batch size of the parameters in MSA row attention, used with while control flow. Default: None.

  • slice_num (int) – The number of slices to be made to reduce memory. Default: 0.
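Slicing works because each MSA row attends only within itself, so rows are independent and can be processed in chunks and concatenated. The sketch below assumes slices are taken along the N_seqs axis and that a slice_num of 0 or 1 means no slicing; the helpers `row_attention` and `apply_in_slices` are hypothetical, and the attention is a toy parameter-free version rather than the full layer.

```python
import numpy as np


def row_attention(x, bias):
    """Toy parameter-free row attention. x: (S, R, C), bias: (R, R)."""
    logits = np.einsum("sic,sjc->sij", x, x) / np.sqrt(x.shape[-1]) + bias
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("sij,sjc->sic", weights, x)


def apply_in_slices(fn, x, bias, slice_num):
    # Assumption: slice_num <= 1 means "do not slice".
    if slice_num <= 1:
        return fn(x, bias)
    # Split the sequence axis into slice_num chunks (sizes may differ by one),
    # run each chunk separately, and stitch the results back together.
    chunks = np.array_split(x, slice_num, axis=0)
    return np.concatenate([fn(chunk, bias) for chunk in chunks], axis=0)


rng = np.random.default_rng(1)
x = rng.normal(size=(6, 5, 3))   # (N_seqs, N_res, channels)
bias = rng.normal(size=(5, 5))   # shared (N_res, N_res) pair bias
full = row_attention(x, bias)
sliced = apply_in_slices(row_attention, x, bias, 3)
print(np.allclose(full, sliced))  # True
```

Only one chunk's (rows, N_res, N_res) logits are alive at a time, trading extra loop iterations for lower peak memory while giving the same result.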

Inputs:

  • msa_act (Tensor) - The MSA activations, with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

  • msa_mask (Tensor) - The mask for MSA row attention matrix with shape \((N_{seqs}, N_{res})\) .

  • pair_act (Tensor) - Tensor of pair_act with shape \((N_{res}, N_{res}, pair\_act\_dim)\) . Data type is float.

  • index (Tensor) - The index of the while loop, used only with while control flow. Default: None.

  • norm_msa_mask (Tensor) - The mask applied to msa_act during layer normalization, with shape \((N_{seqs}, N_{res})\). Default: None.

  • norm_pair_mask (Tensor) - The mask applied to pair_act during layer normalization, with shape \((N_{res}, N_{res})\). Default: None.

  • res_idx (Tensor) - The residue index used to apply RoPE (rotary position embedding), with shape \((N_{res})\). Default: None.
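The documentation only states that res_idx drives RoPE. The sketch below shows one common RoPE formulation, assumed here for illustration: adjacent channel pairs are rotated by angles proportional to the residue index, with a frequency base of 10000. Which tensors MindSponge rotates and with what constants is not stated in this page, so treat the function `rope` as hypothetical. Its key property is that the dot product of a rotated query and key depends only on the difference of their residue indices.

```python
import numpy as np


def rope(x, res_idx, base=10000.0):
    """Rotate channel pairs of x (..., R, D) by angles set by res_idx (R,). D must be even."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)                       # (D/2,)
    angles = res_idx[:, None].astype(np.float64) * inv_freq[None, :]   # (R, D/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    # 2-D rotation of each (even, odd) channel pair.
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(2)
q, k = rng.normal(size=8), rng.normal(size=8)
x = rng.normal(size=(4, 10, 8))   # e.g. (N_seqs, N_res, key_dim)
idx = np.arange(10)               # residue indices
rot = rope(x, idx)
print(rot.shape)  # (4, 10, 8)
```

Because each channel pair is rotated rather than shifted, vector norms are preserved and only relative residue offsets affect attention scores.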


Outputs:

Tensor, the updated msa_act of this layer, a float tensor with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

Supported Platforms:

Ascend GPU


Examples:

>>> import numpy as np
>>> from mindsponge.cell import MSARowAttentionWithPairBias
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MSARowAttentionWithPairBias(num_head=4, key_dim=4, gating=True,
...                                     msa_act_dim=64, pair_act_dim=128,
...                                     batch_size=None)
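>>> # MSA activations: 4 sequences x 256 residues x 64 channels
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> # All-ones mask: every MSA position is valid
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> # Pair activations: 256 x 256 residue pairs x 128 channels
>>> pair_act = Tensor(np.ones((256, 256, 128)), mstype.float32)
>>> # No while control flow here, so no loop index is passed
>>> index = None
>>> msa_out = model(msa_act, msa_mask, pair_act, index)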
>>> print(msa_out.shape)
(4, 256, 64)