mindspore.ops.Select
- class mindspore.ops.Select[源代码]
根据条件判断Tensor中的元素的值,决定输出中的相应元素是从 x (如果元素值为True)还是从 y (如果元素值为False)中选择。
该算法可以被定义为:
\[\begin{split}out_i = \begin{cases} x_i, & \text{if } condition_i \\ y_i, & \text{otherwise} \end{cases}\end{split}\]- 输入:
condition (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 \((x_1, x_2, ..., x_N, ..., x_R)\)。
x (Tensor) - 第一个被选择的Tensor,shape是 \((x_1, x_2, ..., x_N, ..., x_R)\)。
y (Tensor) - 第二个被选择的Tensor,shape是 \((x_1, x_2, ..., x_N, ..., x_R)\)。
- 输出:
Tensor,具有与输入 condition 相同的shape。
- 异常:
TypeError - 如果 x 或者 y 不是Tensor。
ValueError - 如果三个输入的shape不一致。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> select = ops.Select() >>> input_cond = Tensor([True, False]) >>> input_x = Tensor([2,3], mindspore.float32) >>> input_y = Tensor([1,2], mindspore.float32) >>> output = select(input_cond, input_x, input_y) >>> print(output) [2. 2.]