mindsponge.cell.TriangleAttention
- class mindsponge.cell.TriangleAttention(orientation, num_head, key_dim, gating, layer_norm_dim, batch_size=None, slice_num=0)
Triangle attention. For the detailed implementation process, refer to TriangleAttention.
Triangle attention updates the representation of each amino acid pair ij using the information of the three edges ij, ik and jk of the triangle formed by residues i, j and k. The computation has three parts: projection, self-attention and output. First, the pair representation is projected to obtain the queries, keys and values (q, k, v). Next, the classic multi-head self-attention mechanism is applied, with a bias term added to the attention logits that encodes the relationship between the edges of the triangle ijk. Finally, the result is projected to the output.
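To make the three stages concrete, here is a minimal NumPy sketch of row-wise ("per_row") triangle attention on a single, unbatched pair tensor. It follows the AlphaFold2-style formulation as an assumption and is not the MindSPONGE source: the weight names (w_q, w_k, w_v, w_b, w_g, w_o), the random weights from the hypothetical init_params helper and the split of key_dim across heads are all illustrative, layer normalization has no learned parameters, and masking is omitted.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the channel axis; learned scale and offset are omitted here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def init_params(c, num_head, key_dim, seed=0):
    # Hypothetical random weights standing in for learned parameters.
    # Assumes key_dim is split evenly across heads (key_dim // num_head per head).
    rng = np.random.default_rng(seed)
    shapes = {"w_q": (c, key_dim), "w_k": (c, key_dim), "w_v": (c, key_dim),
              "w_b": (c, num_head), "w_g": (c, key_dim), "w_o": (key_dim, c)}
    return {name: rng.normal(scale=c ** -0.5, size=s) for name, s in shapes.items()}


def triangle_attention_row(z, params, num_head, gating=True):
    """Triangle attention around the starting node for z of shape (N, N, c)."""
    n = z.shape[0]
    z = layer_norm(z)
    # 1) Projection: queries, keys and values for every edge, split into heads.
    q = (z @ params["w_q"]).reshape(n, n, num_head, -1)  # used as [i, j, h, d]
    k = (z @ params["w_k"]).reshape(n, n, num_head, -1)  # used as [i, k, h, d]
    v = (z @ params["w_v"]).reshape(n, n, num_head, -1)
    d = q.shape[-1]
    # Triangle bias: one scalar per head computed from the third edge jk.
    bias = (z @ params["w_b"]).transpose(2, 0, 1)        # (h, j, k)
    # 2) Self-attention: edge ij attends over all edges ik that share node i.
    logits = np.einsum("ijhd,ikhd->ihjk", q, k) / np.sqrt(d) + bias
    out = np.einsum("ihjk,ikhd->ijhd", softmax(logits), v)
    if gating:
        # Sigmoid gate computed from another projection of the same edge ij.
        gate = 1.0 / (1.0 + np.exp(-(z @ params["w_g"])))
        out = out * gate.reshape(n, n, num_head, d)
    # 3) Output: merge the heads and project back to c channels.
    return out.reshape(n, n, num_head * d) @ params["w_o"]


z = np.random.default_rng(1).normal(size=(8, 8, 16))
params = init_params(c=16, num_head=4, key_dim=16)
out = triangle_attention_row(z, params, num_head=4)
print(out.shape)  # (8, 8, 16)
```

As in the documented layer, the output has the same shape as the input pair tensor, so the block can be stacked residually.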
- Parameters
orientation (str) – The orientation of triangle attention, which decides the starting and ending edges of self-attention: "per_row" attends around the starting node of each edge, "per_column" around the ending node.
num_head (int) – The number of attention heads.
key_dim (int) – The dimension of the hidden layer.
gating (bool) – Whether the attention output is gated.
layer_norm_dim (int) – The dimension of the layer normalization, equal to the last dimension of pair_act.
batch_size (int) – The batch size of triangle attention. Default: None.
slice_num (int) – The number of slices used to reduce memory usage. Default: 0.
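The orientation and slice_num parameters can be illustrated with a simplified NumPy sketch. It assumes that "per_column" is equivalent to swapping the two residue axes, running the row-wise computation and swapping back, and that slicing splits the query rows into slice_num chunks whose results are concatenated. Both are assumptions about the behaviour, not a copy of the MindSPONGE implementation; layer normalization, gating and the output projection are left out, and the weight dictionary w is hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def attend_rows(z_rows, z_full, w, num_head):
    """Simplified row-wise triangle attention for a block of rows i.

    z_rows: (r, N, c) rows that supply queries, keys and values.
    z_full: (N, N, c) full pair tensor, needed for the jk-edge bias.
    """
    r, n, _ = z_rows.shape
    q = (z_rows @ w["q"]).reshape(r, n, num_head, -1)
    k = (z_rows @ w["k"]).reshape(r, n, num_head, -1)
    v = (z_rows @ w["v"]).reshape(r, n, num_head, -1)
    bias = (z_full @ w["b"]).transpose(2, 0, 1)  # (h, j, k)
    logits = np.einsum("ijhd,ikhd->ihjk", q, k) / np.sqrt(q.shape[-1]) + bias
    out = np.einsum("ihjk,ikhd->ijhd", softmax(logits), v)
    return out.reshape(r, n, -1)


def triangle_attention(z, w, num_head, orientation="per_row", slice_num=0):
    if orientation == "per_column":
        z = z.transpose(1, 0, 2)  # ending node: swap the two residue axes
    if slice_num > 0:
        # Split the rows into slice_num chunks to bound peak memory.
        chunks = np.array_split(z, slice_num, axis=0)
        out = np.concatenate([attend_rows(ch, z, w, num_head) for ch in chunks], axis=0)
    else:
        out = attend_rows(z, z, w, num_head)
    if orientation == "per_column":
        out = out.transpose(1, 0, 2)  # restore the original axis order
    return out


rng = np.random.default_rng(0)
c, h = 8, 2
w = {name: rng.normal(scale=c ** -0.5, size=(c, h if name == "b" else c)) for name in "qkvb"}
z = rng.normal(size=(6, 6, c))
full = triangle_attention(z, w, h)
sliced = triangle_attention(z, w, h, slice_num=3)
print(np.allclose(full, sliced))  # True
```

Because every row i is computed independently once the jk-edge bias is available, slicing in this sketch trades peak memory for extra sequential steps without changing the result.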
- Inputs:
pair_act (Tensor) - The pair activations, with shape \((N_{res}, N_{res}, layer\_norm\_dim)\).
pair_mask (Tensor) - The mask for the triangle attention matrix, with shape \((N_{res}, N_{res})\).
index (Tensor) - The index of the while loop, only used with while control flow. Default: None.
mask (Tensor) - The mask of pair_act used in layer normalization, with shape \((N_{res}, N_{res})\). Default: None.
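A common way to apply a pair mask, and the one used in AlphaFold2, is to turn it into a large negative bias on the attention logits so that padded keys receive (almost) zero weight after the softmax. The sketch below assumes that convention for pair_mask; the function names and the 1e9 constant are illustrative, and the logits are plain zeros so the effect of the mask is easy to see.

```python
import numpy as np


def mask_to_bias(pair_mask, large=1e9):
    # Valid pairs (1) map to 0; padded pairs (0) map to -large.
    return large * (pair_mask - 1.0)


def masked_row_softmax(logits, pair_mask):
    """logits: (N, N, N) indexed [i, j, k]; edge ij attends to edge ik.

    pair_mask: (N, N), 1 for valid residue pairs and 0 for padding.
    Key ik is kept only where pair_mask[i, k] == 1.
    """
    x = logits + mask_to_bias(pair_mask)[:, None, :]  # broadcast over j
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


n = 4
pair_mask = np.ones((n, n))
pair_mask[3, :] = 0.0  # residue 3 is padding
pair_mask[:, 3] = 0.0
weights = masked_row_softmax(np.zeros((n, n, n)), pair_mask)
print(weights[0, 0])  # the three valid keys share the weight; key 3 gets 0
```

Every row of weights still sums to 1; a fully padded row (here i = 3) falls back to uniform weights because all of its logits receive the same bias.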
- Outputs:
Tensor, the float tensor of pair_act output by the layer, with shape \((N_{res}, N_{res}, layer\_norm\_dim)\).
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> from mindsponge.cell import TriangleAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = TriangleAttention(orientation="per_row", num_head=4, key_dim=64, gating=True, layer_norm_dim=64)
>>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
>>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
>>> out = model(input_0, input_1, index=0)
>>> print(out.shape)
(256, 256, 64)