mindspore.mint.nn.GELU

查看源文件
class mindspore.mint.nn.GELU[源代码]

高斯误差线性单元激活函数(Gaussian Error Linear Unit)。

GELU的描述可以在 Gaussian Error Linear Units (GELUs) 这篇文章中找到。 也可以去查询 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

更多参考详见 mindspore.mint.nn.functional.gelu()

GELU函数图:

../../_images/GELU.png
支持平台:

Ascend

样例:

>>> import mindspore
>>> from mindspore import Tensor, mint
>>> import numpy as np
>>> input = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> gelu = mint.nn.GELU()
>>> output = gelu(input)
>>> print(output)
[[-1.5880802e-01  3.9999299e+00 -3.1077917e-21]
 [ 1.9545976e+00 -2.2918017e-07  9.0000000e+00]]
>>> gelu = mint.nn.GELU(approximate=False)
>>> # CPU not support "approximate=False", using "approximate=True" instead
>>> output = gelu(input)
>>> print(output)
[[-1.5865526e-01  3.9998732e+00 -0.0000000e+00]
 [ 1.9544997e+00 -1.4901161e-06  9.0000000e+00]]