mindspore.mint.where

mindspore.mint.where(condition, input, other) → Tensor

Selects elements from input or other based on condition and returns a tensor.

\[\begin{split}output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}\end{split}\]
Parameters
  • condition (Tensor[bool]) – Where true, the element is taken from input; otherwise it is taken from other.

  • input (Union[Tensor, Scalar]) – Values to select where condition is true.

  • other (Union[Tensor, Scalar]) – Values to select where condition is false.

Returns

Tensor whose elements are selected from input and other according to condition.

Raises
  • TypeError – If condition is not a tensor.

  • TypeError – If both input and other are scalars.

  • ValueError – If condition, input and other cannot be broadcast to a common shape.
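
The broadcast requirement behind the ValueError can be sketched in plain Python. This is only an illustration and assumes NumPy-style broadcasting rules (shapes aligned on trailing dimensions, a size of 1 stretching to match); `broadcast_shape` is a hypothetical helper, not part of MindSpore.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Hypothetical helper: return the common broadcast shape or raise ValueError."""
    result = []
    # Align shapes on their trailing dimensions; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            # Two different sizes, neither equal to 1, cannot be reconciled.
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# A scalar operand behaves like shape ().
print(broadcast_shape((2, 2), (2, 2), ()))    # (2, 2)
print(broadcast_shape((2, 1), (1, 3), (3,)))  # (2, 3)
```

Under the same assumption, a call such as `broadcast_shape((2, 3), (4, 3))` raises ValueError, since 2 and 4 cannot be reconciled.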

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> from mindspore import tensor, mint
>>> from mindspore import dtype as mstype
>>> a = tensor(np.arange(4).reshape((2, 2)), mstype.float32)
>>> b = tensor(np.ones((2, 2)), mstype.float32)
>>> condition = a < 3
>>> output = mint.where(condition, a, b)
>>> print(output)
[[0. 1.]
 [2. 1.]]
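
For readers without an Ascend device, the selection rule from the formula can be mirrored in plain Python. The sketch below is not MindSpore's implementation: `where_reference` is a hypothetical helper that works on equally shaped nested lists and reuses a scalar operand at every position. It reproduces the values from the example above.

```python
def where_reference(condition, input_vals, other_vals):
    """Hypothetical reference: pick input_vals where condition is true, else other_vals."""
    if isinstance(condition, list):
        # Recurse one nesting level; a scalar operand is passed through unchanged.
        return [
            where_reference(
                c,
                input_vals[i] if isinstance(input_vals, list) else input_vals,
                other_vals[i] if isinstance(other_vals, list) else other_vals,
            )
            for i, c in enumerate(condition)
        ]
    return input_vals if condition else other_vals


a = [[0.0, 1.0], [2.0, 3.0]]
b = [[1.0, 1.0], [1.0, 1.0]]
condition = [[x < 3 for x in row] for row in a]

result = where_reference(condition, a, b)
print(result)  # [[0.0, 1.0], [2.0, 1.0]]

# One operand may be a scalar; it is used wherever the other branch is not chosen.
scalar_result = where_reference(condition, a, -1.0)
print(scalar_result)  # [[0.0, 1.0], [2.0, -1.0]]
```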
mindspore.mint.where(condition) → Tensor

Equivalent to calling mindspore.ops.nonzero() on condition with as_tuple set to True.

Supported Platforms:

Ascend
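
What the single-argument form returns can be pictured with a plain-Python sketch. It assumes the usual as_tuple=True convention of one index sequence per dimension; `nonzero_as_tuple` is a hypothetical helper limited to 2-D nested lists, not a MindSpore function.

```python
def nonzero_as_tuple(condition):
    """Hypothetical helper: row and column indices of the truthy entries of a 2-D nested list."""
    rows, cols = [], []
    for i, row in enumerate(condition):
        for j, value in enumerate(row):
            if value:
                rows.append(i)
                cols.append(j)
    return rows, cols


condition = [[True, True], [True, False]]
rows, cols = nonzero_as_tuple(condition)
print(rows, cols)  # [0, 0, 1] [0, 1, 0]
```

Pairing the two sequences element by element gives the coordinates (0, 0), (0, 1) and (1, 0) of the true entries.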