mindspore.nn.GLU

查看源文件
class mindspore.nn.GLU(axis=- 1)[源代码]

门线性单元函数(Gated Linear Unit function)。

\[{GLU}(a, b)= a \otimes \sigma(b)\]

其中,\(a\) 表示输入Tensor的前一半元素,\(b\) 表示输入Tensor的另一半元素。 这里 \(\sigma\) 为sigmoid函数,\(\otimes\) 是Hadamard乘积。

参数:
  • axis (int) - 指定分割轴。数据类型为整型,默认值: -1 ,输入x的最后一维。

输入:
  • x (Tensor) - Tensor的shape为 \((\ast_1, N, \ast_2)\)* 表示任意数量的维度。

输出:

Tensor,数据类型与输入 x 相同,shape为 \((\ast_1, M, \ast_2)\),其中 \(M=N/2\)

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> m = ms.nn.GLU()
>>> input = ms.Tensor([[0.1,0.2,0.3,0.4],[0.5,0.6,0.7,0.8]])
>>> output = m(input)
>>> print(output)
[[0.05744425 0.11973753]
 [0.33409387 0.41398472]]