mindspore.nn.Threshold

class mindspore.nn.Threshold(threshold, value)[源代码]

Threshold激活函数,按元素计算输出。

Threshold定义为:

\[\begin{split}y = \begin{cases} x, &\text{ if } x > \text{threshold} \\ \text{value}, &\text{ otherwise } \end{cases}\end{split}\]

参数:

  • threshold (Union[int, float]) - 阈值。

  • value (Union[int, float]) - 输入Tensor中element小于阈值时的填充值。

输入:

  • input_x (Tensor) - 输入Tensor,数据类型为float16或float32。

输出:

Tensor,数据类型和shape与 input_x 的相同。

异常:

  • TypeError - threshold 不是浮点数或整数。

  • TypeError - value 不是浮点数或整数。

支持平台:

Ascend CPU GPU

样例:

>>> import mindspore
>>> import mindspore.nn as nn
>>> m = nn.Threshold(0.1, 20)
>>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
>>> outputs = m(inputs)
>>> print(outputs)
[ 20.0     0.2      0.3]