mindsponge.cell.TriangleMultiplication
- class mindsponge.cell.TriangleMultiplication(num_intermediate_channel, equation, layer_norm_dim, batch_size=None)[源代码]
三角乘法层。详细实现过程参考 TriangleMultiplication 。 氨基酸对ij之间的信息通过ij,ik,jk三条边的信息整合,将ik和jk的点乘结果信息添加到ij边。
- 参数:
num_intermediate_channel (float) - 中间通道的数量。
equation (str) - 三角形边顺序的爱因斯坦算符表示,分别对应于"incoming"和"outgoing"的边更新形式。 \((ikc,jkc->ijc, kjc,kic->ijc)\)。
layer_norm_dim (int) - 归一层的最后一维的长度。
batch_size (int) - 三角乘法中的batch size。默认值:
None
。
- 输入:
pair_act (Tensor) - pair_act。氨基酸对之间的信息,shape为 \((N_{res}, N_{res}, layer\_norm\_dim)\) 。
pair_mask (Tensor) - 三角乘法层矩阵的mask。shape为 \((N_{res}, N_{res})\) 。
index (Tensor) - 在循环中的索引,只会在有控制流的时候使用。
- 输出:
Tensor。三角乘法层中的pair_act。shape为 \((N_{res}, N_{res}, layer\_norm\_dim)\) 。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> from mindsponge.cell import TriangleMultiplication >>> from mindspore import dtype as mstype >>> from mindspore import Tensor >>> model = TriangleMultiplication(num_intermediate_channel=64, ... equation="ikc,jkc->ijc", layer_norm_dim=64, batch_size=0) >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32) >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32) >>> out = model(input_0, input_1, index=0) >>> print(out.shape) (256, 256, 64)