mindspore.ops.select

mindspore.ops.select(cond, x, y)

Returns the selected elements, either from input \(x\) or input \(y\), depending on the condition cond.

The call is invalid when both \(x\) and \(y\) are None. The shape of the output does not depend on how many elements of cond are true: each element of cond yields exactly one output element, taken from either \(x\) or \(y\).

The condition tensor acts as a mask: the value of each of its elements determines whether the corresponding element or row of the output is taken from \(x\) (if true) or from \(y\) (if false).

It can be defined as:

\[out_i = \begin{cases} x_i, & \text{if } cond_i \\ y_i, & \text{otherwise} \end{cases}\]
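The formula can be checked against a small plain-Python model. This is a sketch, not MindSpore code: `select_elementwise` is a hypothetical helper that works on flat Python lists instead of Tensors and only covers the 1-D element-wise case.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_elementwise(cond: Sequence[bool], x: Sequence[T], y: Sequence[T]) -> List[T]:
    """Hypothetical 1-D reference model of out_i = x_i if cond_i else y_i."""
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("cond, x and y must have the same length")
    # One output element per position of cond, so the output length never
    # depends on how many entries of cond are True.
    return [xi if ci else yi for ci, xi, yi in zip(cond, x, y)]


# Same values as the first example on this page.
print(select_elementwise([True, False], [2.0, 3.0], [1.0, 2.0]))  # [2.0, 2.0]
```

The list comprehension maps directly onto the two branches of the case expression, which makes it easy to compare with the Tensor results in the examples.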

If cond is a vector and \(x\) and \(y\) are higher-dimensional tensors, each element of cond selects a whole row (a slice along the outermost dimension) from \(x\) or \(y\). If cond has the same shape as \(x\) and \(y\), the selection is made element by element.
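The two cases above can be modeled with nested Python lists. In this sketch, `select_rows` and `select_elements` are hypothetical helpers, not MindSpore APIs, and the inputs are assumed to be rectangular 2-D lists. Whether a particular MindSpore version accepts a 1-D cond together with higher-rank \(x\) and \(y\) should be checked against that version.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def select_rows(cond: Sequence[bool], x: Matrix, y: Matrix) -> Matrix:
    """Row-wise model: cond[i] picks the whole row i from x (True) or y (False)."""
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("cond must have one entry per row of x and y")
    # Copy the chosen row so the result does not alias the inputs.
    return [list(xr) if c else list(yr) for c, xr, yr in zip(cond, x, y)]


def select_elements(cond: List[List[bool]], x: Matrix, y: Matrix) -> Matrix:
    """Element-wise model; assumes cond, x and y have identical shapes
    (zip would silently truncate mismatched rows)."""
    return [
        [xv if c else yv for c, xv, yv in zip(cr, xr, yr)]
        for cr, xr, yr in zip(cond, x, y)
    ]


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[10.0, 20.0], [30.0, 40.0]]
print(select_rows([True, False], x, y))                       # [[1.0, 2.0], [30.0, 40.0]]
print(select_elements([[True, False], [False, True]], x, y))  # [[1.0, 20.0], [30.0, 4.0]]
```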

Inputs:
  • cond (Tensor[bool]) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The condition tensor, which decides which element is chosen.

  • x (Union[Tensor, int, float]) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The first input tensor. If x is an int or a float, it is cast to int32 or float32 respectively and broadcast to the shape of y. At least one of x and y must be a Tensor.

  • y (Union[Tensor, int, float]) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The second input tensor. If y is an int or a float, it is cast to int32 or float32 respectively and broadcast to the shape of x. At least one of x and y must be a Tensor.
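The scalar handling described for x and y can be sketched in plain Python. This is a hypothetical model, not MindSpore code: a (values, dtype name) tuple stands in for a 1-D Tensor, `as_tensor_like` and `select_with_scalars` are illustrative helpers, the dtype names are plain strings, and the dtype of the final output is not modeled.

```python
from typing import List, Tuple, Union

# Hypothetical stand-in for a 1-D Tensor: (values, dtype name as a string).
FakeTensor = Tuple[List[Union[int, float]], str]


def as_tensor_like(value: Union[FakeTensor, int, float], other: FakeTensor) -> FakeTensor:
    """Return value unchanged if it is already a FakeTensor; otherwise cast the
    scalar (int -> "int32", float -> "float32") and repeat it to len(other)."""
    if isinstance(value, tuple):
        return value
    if not isinstance(value, (int, float)):
        raise TypeError("x and y must be a Tensor, int or float")
    dtype = "int32" if isinstance(value, int) else "float32"
    return [value] * len(other[0]), dtype


def select_with_scalars(
    cond: List[bool],
    x: Union[FakeTensor, int, float],
    y: Union[FakeTensor, int, float],
) -> List[Union[int, float]]:
    """Element-wise select where at most one of x and y is a Python scalar."""
    if not isinstance(x, tuple) and not isinstance(y, tuple):
        # Assumption: the page only says one of x and y must be a Tensor;
        # the exception type used here is chosen for illustration.
        raise TypeError("one of x and y must be a Tensor")
    x_t = as_tensor_like(x, y)    # if x is a scalar, y is a FakeTensor here
    y_t = as_tensor_like(y, x_t)  # broadcast y (if scalar) to the length of x
    return [a if c else b for c, a, b in zip(cond, x_t[0], y_t[0])]


# Same values as the second example on this page (y is the float 2.0).
print(select_with_scalars([True, False], ([2.0, 3.0], "float32"), 2.0))  # [2.0, 2.0]
```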

Outputs:

Tensor, has the same shape as cond. The shape is \((x_1, x_2, ..., x_N, ..., x_R)\).

Raises
  • TypeError – If x or y is not a Tensor, int or float.

  • ValueError – If the shapes of the inputs are not equal.
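The checks listed under Raises can be sketched in plain Python. This is an assumption-laden model, not MindSpore's actual validation code: nested lists stand in for Tensors, and `shape_of` and `check_select_inputs` are hypothetical helpers.

```python
from typing import Any, Tuple


def shape_of(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a Python scalar has shape ()."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def check_select_inputs(cond: Any, x: Any, y: Any) -> Tuple[int, ...]:
    """Hypothetical pre-check mirroring the Raises section; returns the output shape."""
    for name, v in (("x", x), ("y", y)):
        # Nested lists stand in for Tensors in this sketch.
        if not isinstance(v, (list, int, float)):
            raise TypeError(f"{name} must be a Tensor, int or float, got {type(v).__name__}")
    # Scalars are broadcast, so only list-valued inputs take part in the shape check.
    shapes = {shape_of(v) for v in (cond, x, y) if isinstance(v, list)}
    if len(shapes) != 1:
        raise ValueError(f"the shapes of inputs are not equal: {sorted(shapes)}")
    # The output has the same shape as cond.
    return shape_of(cond)


print(check_select_inputs([True, False], [2.0, 3.0], 2.0))  # (2,)
```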

Supported Platforms:

Ascend GPU CPU

Examples

>>> # 1) Both inputs are Tensor
>>> import mindspore
>>> from mindspore import Tensor, ops
>>>
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = Tensor([1,2], mindspore.float32)
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]
>>> # 2) y is a float
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = 2.0
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]