mindspore.ops.Einsum
- class mindspore.ops.Einsum(equation)[源代码]
此算子使用爱因斯坦求和约定(Einsum)进行Tensor计算。支持对角线、约和、转置、矩阵乘、乘积、内积运算等。
输入必须是Tensor的tuple。当输入只有一个Tensor时,可以输入(Tensor, ),支持数据类型float16、float32、float64。
- 参数:
equation (str) - 属性,表示要执行的计算。该值只能使用letter([a-z][A-Z])、commas(,)、ellipsis(…)和arrow(->)。letter([a-z][A-Z])表示输入的Tensor的维度,commas(,)表示Tensor维度之间的分隔符,ellipsis(…)表示不关心的Tensor维度,arrow(->)的左侧表示输入Tensor,右侧表示所需的输出维度。
- 输入:
x (Tuple) - 用于计算的输入Tensor,Tensor的数据类型必须相同。
- 输出:
Tensor,shape可以从方程中获得,数据类型与输入Tensor相同。
- 异常:
TypeError - 如果 equation 本身无效,或者 equation 与输入Tensor不匹配。
- 支持平台:
GPU
样例:
>>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32) >>> equation = "i->" >>> einsum = ops.Einsum(equation) >>> output = einsum([x]) >>> print(output) [7.] >>> >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32) >>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32) >>> equation = "i,i->i" >>> einsum = ops.Einsum(equation) >>> output = einsum((x, y)) >>> print(output) [ 2. 8. 12.] >>> >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) >>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32) >>> equation = "ij,jk->ik" >>> einsum = ops.Einsum(equation) >>> output = einsum((x, y)) >>> print(output) [[16. 22.] [37. 52.]] >>> >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) >>> equation = "ij->ji" >>> einsum = ops.Einsum(equation) >>> output = einsum((x,)) >>> print(output) [[1. 4.] [2. 5.] [3. 6.]] >>> >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) >>> equation = "ij->j" >>> einsum = ops.Einsum(equation) >>> output = einsum((x,)) >>> print(output) [5. 7. 9.] >>> >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) >>> equation = "...->" >>> einsum = ops.Einsum(equation) >>> output = einsum((x,)) >>> print(output) [21.] >>> >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) >>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32) >>> equation = "j,i->ji" >>> einsum = ops.Einsum(equation) >>> output = einsum((x, y)) >>> print(output) [[ 2. 4. 1.] [ 4. 8. 2.] [ 6. 12. 3.]]