mindspore.dataset.transforms.Mask
- class mindspore.dataset.transforms.Mask(operator, constant, dtype=mstype.bool_)[源代码]
用给条件判断输入Tensor的内容,并返回一个掩码Tensor。Tensor中任何符合条件的元素都将被标记为True,否则为False。
- 参数:
operator (
Relational
) - 关系操作符,可以取值为Relational.EQ
、Relational.NE
、Relational.LT
、Relational.GT
、Relational.LE
、Relational.GE
。以Relational.EQ
为例,将找出Tensor中与 constant 相等的元素。constant (Union[str, int, float, bool]) - 与输入Tensor进行比较的基准值。
dtype (
mindspore.dtype
, 可选) - 生成的掩码Tensor的数据类型。默认值:mstype.bool_
。
- 异常:
TypeError - 参数 operator 类型不为
mindspore.dataset.transforms.Relational
。TypeError - 参数 constant 类型不为str、int、float或bool。
TypeError - 参数 dtype 类型不为
mindspore.dtype
。
- 支持平台:
CPU
样例:
>>> import mindspore.dataset as ds >>> import mindspore.dataset.transforms as transforms >>> from mindspore.dataset.transforms import Relational >>> >>> # Use the transform in dataset pipeline mode >>> # Data before >>> # | col | >>> # +---------+ >>> # | [1,2,3] | >>> # +---------+ >>> data = [[1, 2, 3]] >>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["col"]) >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms.Mask(Relational.EQ, 2)) >>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True): ... print(item["col"].shape, item["col"].dtype) (3,) bool >>> # Data after >>> # | col | >>> # +--------------------+ >>> # | [False,True,False] | >>> # +--------------------+ >>> >>> # Use the transform in eager mode >>> data = [1, 2, 3] >>> output = transforms.Mask(Relational.EQ, 2)(data) >>> print(output.shape, output.dtype) (3,) bool