
mindsponge.cell.TriangleAttention

class mindsponge.cell.TriangleAttention(orientation, num_head, key_dim, gating, layer_norm_dim, batch_size=None, slice_num=0)

Triangle attention. For the detailed implementation process, refer to TriangleAttention.

The information between amino acid pairs is integrated through the three edges ij, ik, and jk. The computation is divided into three parts: projection, self-attention, and output. First, the amino acid pair representation is projected to obtain q, k, and v; then the classic multi-head self-attention mechanism is applied, with the relationship among the triangle edges i, j, and k added in; finally, the result is output, as sketched below.
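The three stages can be illustrated with the following minimal NumPy sketch. It is a simplified approximation rather than the MindSponge implementation: it assumes the per-row orientation and a single attention head, and it omits layer normalization, gating, batching, and slicing; all variable names are illustrative only.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

n_res, dim, key_dim = 8, 16, 16
rng = np.random.default_rng(0)
pair_act = rng.normal(size=(n_res, n_res, dim))   # pair representation, indexed (i, j, c)
pair_mask = np.ones((n_res, n_res))               # 1 for valid residue pairs

# 1. Projection: obtain q, k, v from the pair representation,
#    plus a scalar bias per pair (the third triangle edge).
w_q = rng.normal(size=(dim, key_dim))
w_k = rng.normal(size=(dim, key_dim))
w_v = rng.normal(size=(dim, key_dim))
w_b = rng.normal(size=(dim,))
q = pair_act @ w_q                                # queries for edge (i, j)
k = pair_act @ w_k                                # keys for edge (i, k)
v = pair_act @ w_v                                # values for edge (i, k)
bias = pair_act @ w_b                             # bias from edge (j, k)

# 2. Self-attention along each row i: edge (i, j) attends to edges (i, k),
#    and the third edge (j, k) enters as an additive attention bias.
logits = np.einsum('ijc,ikc->ijk', q, k) / np.sqrt(key_dim)
logits = logits + bias[None, :, :]                       # add edge (j, k)
logits = logits + (pair_mask[:, None, :] - 1.0) * 1e9    # mask invalid key edges
weights = softmax(logits, axis=-1)

# 3. Output: weighted sum of values written back onto edge (i, j).
out = np.einsum('ijk,ikc->ijc', weights, v)
print(out.shape)   # (8, 8, 16)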

Parameters
  • orientation (str) – Decides the orientation of the triangle attention, i.e., whether self-attention is applied along rows or columns of the pair representation (around the starting or the ending edge).

  • num_head (int) – The number of attention heads.

  • key_dim (int) – The dimension of the hidden layer.

  • gating (bool) – Indicator of whether the attention is gated.

  • layer_norm_dim (int) – The dimension of the layer_norm.

  • batch_size (int) – The batch size of triangle attention. Default: None.

  • slice_num (int) – The number of slices to be made to reduce memory. Default: 0.

Inputs:
  • pair_act (Tensor) - The pair activation tensor, with shape (N_{res}, N_{res}, layer_norm_dim).

  • pair_mask (Tensor) - The mask for the TriangleAttention matrix, with shape (N_{res}, N_{res}).

  • index (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.

  • mask (Tensor) - The mask of pair_act when performing layer normalization, with shape (N_{res}, N_{res}). Default: None.

Outputs:

Tensor, the float tensor of pair_act for this layer, with shape (N_{res}, N_{res}, layer_norm_dim).

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import TriangleAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = TriangleAttention(orientation="per_row", num_head=4, key_dim=64, gating=True, layer_norm_dim=64)
>>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
>>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
>>> out = model(input_0, input_1, index=0)
>>> print(out.shape)
(256, 256, 64)
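For reference, a per-column variant would presumably be constructed the same way. The value "per_column" below is an assumption inferred from the corresponding AlphaFold-style triangle attention and is not confirmed by this page; the output shape follows from the Outputs description above.

>>> # Assumed orientation value "per_column" (attention around the ending edge); not confirmed here.
>>> model_col = TriangleAttention(orientation="per_column", num_head=4, key_dim=64, gating=True, layer_norm_dim=64)
>>> out_col = model_col(input_0, input_1, index=0)
>>> print(out_col.shape)
(256, 256, 64)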