mindspore.ops.threshold

查看源文件
mindspore.ops.threshold(input, thr, value)[源代码]

使用阈值 thr 参数对 input 逐元素阈值化,并将其结果作为Tensor返回。

threshold定义为:

y={input, if input>thrvalue, otherwise 
参数:
  • input (Tensor) - 输入Tensor,数据类型为float16或float32。

  • thr (Union[int, float]) - 阈值。

  • value (Union[int, float]) - 输入Tensor中element小于阈值时的填充值。

返回:

Tensor,数据类型和shape与 input 的相同。

异常:
  • TypeError - input 不是Tensor。

  • TypeError - thr 不是浮点数或整数。

  • TypeError - value 不是浮点数或整数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> inputs = mindspore.Tensor([0.0, 2, 3], mindspore.float32)
>>> outputs = ops.threshold(inputs, 1, 100)
>>> print(outputs)
[100.   2.   3.]