mindspore.ops.swiglu

mindspore.ops.swiglu(input, dim=- 1)[源代码]

计算Swish门线性单元函数(Swish Gated Linear Unit function)。 SwiGLU是 mindspore.ops.GLU 激活函数的变体,定义为:

警告

这是一个实验性API,后续可能修改或删除。

\[{SwiGLU}(a, b)= Swish(a) \otimes b\]

其中,\(a\) 表示输入 input 拆分后 Tensor的前一半元素,\(b\) 表示输入拆分Tensor的另一半元素, Swish(a)=a \(\sigma\) (a),\(\sigma\)mindspore.ops.sigmoid() 函数, \(\otimes\) 是Hadamard乘积。

参数:
  • input (Tensor) - 被分Tensor,shape为 \((\ast_1, N, \ast_2)\) ,其中 * 为任意额外维度。 \(N\) 必须能被2整除。

  • dim (int,可选) - 指定分割轴。数据类型为整型,默认值: -1 ,输入input的最后一维。

返回:

Tensor,数据类型与输入 input 相同,shape为 \((\ast_1, M, \ast_2)\),其中 \(M=N/2\)

异常:
  • TypeError - input 数据类型不是float16、float32或bfloat16。

  • TypeError - input 不是Tensor。

  • RuntimeError - dim 指定维度不能被2整除。

支持平台:

Ascend

样例:

>>> from mindspore import Tensor, ops
>>> input = Tensor([[-0.12, 0.123, 31.122], [2.1223, 4.1212121217, 0.3123]], dtype=mindspore.float32)
>>> output = ops.swiglu(input, 0)
>>> print(output)
[[-0.11970687 0.2690224 9.7194 ]]