mindsponge.cell.GlobalAttention

class mindsponge.cell.GlobalAttention(num_head, gating, input_dim, output_dim, batch_size=None)[source]

This is an implementation of the global gated self-attention described in the paper Highly accurate protein structure prediction with AlphaFold. For this attention, the query, key and value tensors must all have the same shape.
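
The computation can be sketched as follows. This is a minimal NumPy sketch of the gated case, assuming the global-attention formulation from AlphaFold: the query is pooled over the sequence with a masked mean, keys and values are shared across heads, and a per-position sigmoid gate modulates the pooled result. The weight names (w_q, w_k, w_v, w_gate, w_out) are illustrative only, bias terms are omitted, and none of these correspond to the cell's actual parameter names.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def global_attention_sketch(q_data, m_data, q_mask, w_q, w_k, w_v, w_gate, w_out):
    # q_data, m_data: (batch, seq, input_dim); q_mask: (batch, seq, 1)
    # w_q, w_gate: (input_dim, num_head, key_dim); w_k, w_v: (input_dim, key_dim)
    # w_out: (num_head, key_dim, output_dim)
    key_dim = w_q.shape[-1]
    # Pool the query over the sequence with a masked mean (one query per head).
    q_avg = (q_mask * q_data).sum(axis=1) / (q_mask.sum(axis=1) + 1e-10)
    q = np.einsum('ba,ahc->bhc', q_avg, w_q) / np.sqrt(key_dim)
    # Keys and values are shared across heads.
    k = np.einsum('bka,ac->bkc', m_data, w_k)
    v = np.einsum('bka,ac->bkc', m_data, w_v)
    # Attention logits over the sequence, with masked positions suppressed.
    logits = np.einsum('bhc,bkc->bhk', q, k) + 1e9 * (q_mask[:, None, :, 0] - 1.0)
    weights = softmax(logits, axis=-1)
    weighted_avg = np.einsum('bhk,bkc->bhc', weights, v)
    # Per-position sigmoid gate modulates the pooled result.
    gate = 1.0 / (1.0 + np.exp(-np.einsum('bqa,ahc->bqhc', q_data, w_gate)))
    gated = gate * weighted_avg[:, None]              # (batch, seq, num_head, key_dim)
    return np.einsum('bqhc,hco->bqo', gated, w_out)   # (batch, seq, output_dim)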

Parameters
  • num_head (int) – The number of attention heads.

  • gating (bool) – Whether the attention is gated.

  • input_dim (int) – The last dimension length of the input tensor.

  • output_dim (int) – The last dimension length of the output tensor.

  • batch_size (int) – The batch size of parameters in attention, used in while control flow (see the second example below). Default: None.

Inputs:
  • q_data (Tensor) - The query tensor with shape (batch_size, seq_length, input_dim), where seq_length is the sequence length.

  • m_data (Tensor) - The key/value tensor with shape (batch_size, seq_length, input_dim).

  • q_mask (Tensor) - A binary mask for q_data of shape (batch_size, seq_length, 1).

  • bias (Tensor) - Bias for the attention matrix. Default: None.

  • index (Tensor) - The index of the while loop, only used when the cell is called with while control flow. Default: None.

Outputs:

Tensor, output tensor of the GlobalAttention layer, with shape (batch_size, seq_length, output_dim).

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import GlobalAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256)
>>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
>>> m_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
>>> q_mask = Tensor(np.ones((32, 128, 1)), mstype.float32)
>>> attn_out = model(q_data, m_data, q_mask)
>>> print(attn_out.shape)
(32, 128, 256)
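
When batch_size is given, it typically corresponds to stacked copies of the attention parameters used inside a while loop, with index selecting the copy for the current iteration (this is an interpretation of the parameter description above, not confirmed here). The snippet below is a hedged sketch of that calling pattern, reusing the tensors from the example above; the argument order follows the Inputs list, and passing bias positionally as None is an assumption.

>>> model = GlobalAttention(num_head=4, input_dim=64, gating=True, output_dim=256, batch_size=2)
>>> index = Tensor(0, mstype.int32)
>>> attn_out = model(q_data, m_data, q_mask, None, index)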