mindspore.mint.masked_select

mindspore.mint.masked_select(input, mask)[源代码]

返回一个一维Tensor,其中的内容是 input 中对应于 mask 中True位置的值。mask 的shape与 input 的shape不需要一样,但必须符合广播规则。

参数:
  • input (Tensor) - 它的shape是 \((x_1, x_2, ..., x_R)\)

  • mask (Tensor[bool]) - 其为True的位置对应的 input 值将被保留。它的shape是 \((x_1, x_2, ..., x_R)\)

返回:

一个一维Tensor,类型与 input 相同。

异常:
  • TypeError - inputmask 不是Tensor。

  • TypeError - mask 不是bool类型的Tensor。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, mint
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
>>> output = mint.masked_select(x, mask)
>>> print(output)
[1 3]