mindspore.dataset.transforms.Mask

查看源文件
class mindspore.dataset.transforms.Mask(operator, constant, dtype=mstype.bool_)[源代码]

用给条件判断输入Tensor的内容,并返回一个掩码Tensor。Tensor中任何符合条件的元素都将被标记为True,否则为False。

参数:
  • operator (Relational) - 关系操作符,可以取值为 Relational.EQRelational.NERelational.LTRelational.GTRelational.LERelational.GE 。以 Relational.EQ 为例,将找出Tensor中与 constant 相等的元素。

  • constant (Union[str, int, float, bool]) - 与输入Tensor进行比较的基准值。

  • dtype (mindspore.dtype, 可选) - 生成的掩码Tensor的数据类型。默认值: mstype.bool_

异常:
支持平台:

CPU

样例:

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms as transforms
>>> from mindspore.dataset.transforms import Relational
>>>
>>> # Use the transform in dataset pipeline mode
>>> # Data before
>>> # |  col   |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------+
>>> data = [[1, 2, 3]]
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["col"])
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms.Mask(Relational.EQ, 2))
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col"].shape, item["col"].dtype)
(3,) bool
>>> # Data after
>>> # |       col         |
>>> # +--------------------+
>>> # | [False,True,False] |
>>> # +--------------------+
>>>
>>> # Use the transform in eager mode
>>> data = [1, 2, 3]
>>> output = transforms.Mask(Relational.EQ, 2)(data)
>>> print(output.shape, output.dtype)
(3,) bool