mindspore.nn.Threshold
- class mindspore.nn.Threshold(threshold, value)[源代码]
Threshold激活函数,按元素计算输出。
Threshold定义为:
\[\begin{split}y = \begin{cases} x, &\text{ if } x > \text{threshold} \\ \text{value}, &\text{ otherwise } \end{cases}\end{split}\]- 参数:
threshold (Union[int, float]) - 阈值。
value (Union[int, float]) - 输入Tensor中element小于阈值时的填充值。
- 输入:
input_x (Tensor) - 输入Tensor,数据类型为float16或float32。
- 输出:
Tensor,数据类型和shape与 input_x 的相同。
- 异常:
TypeError - threshold 不是浮点数或整数。
TypeError - value 不是浮点数或整数。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> import mindspore.nn as nn >>> m = nn.Threshold(0.1, 20) >>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32) >>> outputs = m(inputs) >>> print(outputs) [ 20.0 0.2 0.3]