mindspore.nn.HSwish
- class mindspore.nn.HSwish[源代码]
Hard Swish激活函数。
对输入的每个元素计算Hard Swish。
Hard Swish定义如下:
\[\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},\]其中, \(x_i\) 是输入的元素。
输入:
x (Tensor) - 用于计算Hard Swish的Tensor。数据类型必须是float16或float32。shape为 \((N,*)\) ,其中 \(*\) 表示任意的附加维度数。
输出:
Tensor,具有与 x 相同的数据类型和shape。
异常:
TypeError - x 的数据类型既不是float16也不是float32。
- 支持平台:
GPU
CPU
样例:
>>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16) >>> hswish = nn.HSwish() >>> result = hswish(x) >>> print(result) [-0.3333 -0.3333 0. 1.667 0.6665]