mindspore.nn.Threshold

class mindspore.nn.Threshold(threshold, value)

Thresholds each element of the input Tensor.

The formula is defined as follows:

\[\begin{split}y = \begin{cases} x, &\text{ if } x > \text{threshold} \\ \text{value}, &\text{ otherwise } \end{cases}\end{split}\]
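The rule can be checked with a minimal pure-Python sketch of the formula, assuming a flat list of numbers as input. `threshold_reference` is a hypothetical helper written for illustration, not part of MindSpore. It uses a strict `>` comparison, so an element exactly equal to `threshold` is also replaced by `value`.

```python
def threshold_reference(xs, threshold, value):
    """Hypothetical reference for y = x if x > threshold else value (not a MindSpore API)."""
    # Strict comparison: elements equal to `threshold` fall into the "otherwise" branch.
    return [x if x > threshold else value for x in xs]


# Same inputs as the Examples section: 0.1 is not > 0.1, so it becomes 20.
print(threshold_reference([0.1, 0.2, 0.3], 0.1, 20))
```

Unlike the Tensor returned by `nn.Threshold`, which keeps the input's data type, this list sketch keeps whatever Python type `value` has (here the int `20`).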

Parameters
  • threshold (Union[int, float]) – The value to threshold at.

  • value (Union[int, float]) – The value that replaces each element less than or equal to threshold.

Inputs:
  • input_x (Tensor) - The input Tensor, with data type float16 or float32.

Outputs:

Tensor, with the same shape and data type as input_x.

Raises
  • TypeError – If threshold is not a float or an int.

  • TypeError – If value is not a float or an int.
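The documented TypeError conditions can be illustrated with a small argument check. This is a sketch under assumptions: `check_threshold_args` is a hypothetical helper, not MindSpore's internal validator, and the exact error message MindSpore raises may differ.

```python
def check_threshold_args(threshold, value):
    """Hypothetical check mirroring the documented TypeError conditions."""
    for name, arg in (("threshold", threshold), ("value", value)):
        # bool is a subclass of int in Python, so True/False pass this check;
        # whether MindSpore's own validator accepts bools is an open assumption here.
        if not isinstance(arg, (int, float)):
            raise TypeError(f"{name} must be an int or a float, but got {type(arg).__name__}")


check_threshold_args(0.1, 20)  # valid: float threshold, int value
try:
    check_threshold_args("0.1", 20)
except TypeError as e:
    print(e)  # threshold must be an int or a float, but got str
```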

Supported Platforms:

Ascend CPU GPU

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> m = nn.Threshold(0.1, 20)
>>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
>>> outputs = m(inputs)
>>> print(outputs)
[ 20.0     0.2      0.3]