mindsponge.cell.OuterProductMean

查看源文件
class mindsponge.cell.OuterProductMean(num_outer_channel, act_dim, num_output_channel, batch_size=None, slice_num=0)[源代码]

通过外积平均计算输入Tensor(act)在第二维上的相关性,得到的相关性可以用于更新相关特征(如Pair特征)。

\[OuterProductMean(\mathbf{act}) = Linear(flatten(mean(\mathbf{act}\otimes\mathbf{act})))\]
参数:
  • num_outer_channel (float) - OuterProductMean中间层的通道数量。

  • act_dim (int) - 输入act的最后一维的长度。

  • num_output_channel (int) - 输出的通道数量。

  • batch_size (int) - OuterProductMean中的参数的batch size,应用while控制流时需要设置该变量, 默认值: None

  • slice_num (int) - 当内存超出上限时使用的切分数量。默认值: 0

输入:
  • act (Tensor) - 维度为 \((dim_1, dim_2, act\_dim)\)

  • mask (Tensor) - OuterProductMean的mask,shape为 \((dim_1, dim_2)\)

  • mask_norm (Tensor) - mask沿第一根轴的L2-norm的平方,预先计算避免在循环重复计算。shape为 \((dim_2, dim_2, 1)\)

  • index (Tensor) - 在循环中的索引。默认值: None

输出:

Tensor。OuterProductMean的输出,shape是 \((dim_2, dim_2, num\_output\_channel)\)

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> from mindsponge.cell import OuterProductMean
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> model = OuterProductMean(num_outer_channel=32, act_dim=128, num_output_channel=256)
>>> act = Tensor(np.ones((32, 64, 128)), mstype.float32)
>>> mask = Tensor(np.ones((32, 64)), mstype.float32)
>>> mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(mask, mask), -1)
>>> output= model(act, mask, mask_norm)
>>> print(output.shape)
(64, 64, 256)