mindspore.ops.bidense

mindspore.ops.bidense(input1, input2, weight, bias=None)[源代码]

对输入 input1input2 应用双线性全连接操作。双线性全连接函数定义如下,

\[output = x_{1}^{T}Ax_{2} + b\]

其中, \(x_{1}\) 代表 input1\(x_{2}\) 代表 input2\(A\) 代表 weight\(b\) 代表 bias

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • input1 (Tensor) - 输入Tensor,shape是 \((*, in1\_channels)\) ,其中 \(*\) 表示任意的附加维度,除最后一维外的维度与 input2 保持一致。

  • input2 (Tensor) - 输入Tensor,shape是 \((*, in2\_channels)\) ,其中 \(*\) 表示任意的附加维度,除最后一维外的维度与 input1 保持一致。

  • weight (Tensor) - 输入Tensor的权重,shape是 \((out\_channels, in1\_channels, in2\_channels)\)

  • bias (Tensor,可选) - 添加在输出结果的偏差,shape是 \((out\_channels)\)\(()\) 。默认值:None ,偏差为0。

返回:

Tensor,shape是 \((*, out\_channels)\) ,其中 \(*\) 表示任意的附加维度。输出Tensor除最后一维外其他维度与所有输入Tensor保持一致。

异常:
  • TypeError - input1 不是Tensor。

  • TypeError - input2 不是Tensor。

  • TypeError - weight 不是Tensor。

  • TypeError - bias 不是Tensor。

  • ValueError - 如果除了最后一维,input1 其他维度与 input2 有不同。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input1 = mindspore.Tensor([[-1.1283, 1.2603],
...                            [0.0214, 0.7801],
...                            [-1.2086, 1.2849]], mindspore.float32)
>>> input2 = mindspore.Tensor([[-0.4631, 0.3238, 0.4201],
...                            [0.6215, -1.0910, -0.5757],
...                            [-0.7788, -0.0706, -0.7942]], mindspore.float32)
>>> weight = mindspore.Tensor([[[-0.3132, 0.9271, 1.1010],
...                             [0.6555, -1.2162, -0.2987]],
...                            [[1.0458, 0.5886, 0.2523],
...                             [-1.3486, -0.8103, -0.2080]],
...                            [[1.1685, 0.5569, -0.3987],
...                             [-0.4265, -2.6295, 0.8535]],
...                            [[0.6948, -1.1288, -0.6978],
...                             [0.3511, 0.0609, -0.1122]]], mindspore.float32)
>>> output = ops.bidense(input1, input2, weight)
>>> print(output)
[[-2.0612743 0.5581219 0.22383511 0.8667302]
 [1.4476739 0.12626505 1.6552988 0.21297503]
 [0.6003161 2.912046 0.5590313 -0.35449564]]