mindspore.dataset.transforms.TypeCast
- class mindspore.dataset.transforms.TypeCast(data_type)
Casts the input Tensor to the specified data type.
Note
This operation runs on the CPU by default. It also supports heterogeneous acceleration, so it can be executed on GPU or Ascend.
- Parameters:
data_type (Union[mindspore.dtype, numpy.dtype]) - The data type to cast the input to.
- Raises:
TypeError - If data_type is not of type mindspore.dtype or numpy.dtype.
- Supported Platforms:
CPU
GPU
Ascend
Examples:
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms as transforms
>>> from mindspore import dtype as mstype
>>>
>>> # Use the transform in dataset pipeline mode
>>> # Generate 1d int numpy array from 0 - 63
>>> def generator_1d():
...     for i in range(64):
...         yield (np.array([i]),)
>>>
>>> generator_dataset = ds.GeneratorDataset(generator_1d, column_names='col')
>>> type_cast_op = transforms.TypeCast(mstype.int32)
>>> generator_dataset = generator_dataset.map(operations=type_cast_op)
>>> for item in generator_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col"].shape, item["col"].dtype)
...     break
(1,) int32
>>>
>>> # Use the transform in eager mode
>>> data = np.array([2.71606445312564e-03, 6.3476562564e-03]).astype(np.float64)
>>> output = transforms.TypeCast(np.float16)(data)
>>> print(output.shape, output.dtype)
(2,) float16
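The eager-mode example casts float64 values down to float16, which can discard precision. The sketch below uses plain NumPy (`ndarray.astype`) as a stand-in to show how such a narrowing cast might be checked. It does not call TypeCast, and the assumption that TypeCast rounds the same way as NumPy is not confirmed by this page. The name `round_trip_error` is illustrative only.

```python
import numpy as np

# Same input values as the eager-mode example above.
data = np.array([2.71606445312564e-03, 6.3476562564e-03], dtype=np.float64)

# Narrowing cast with NumPy, used here only as a stand-in for TypeCast(np.float16).
as_half = data.astype(np.float16)

# Cast back up to float64 to measure how far each value moved during the down-cast.
round_trip_error = np.abs(as_half.astype(np.float64) - data)

print(as_half.shape, as_half.dtype)  # (2,) float16
print(round_trip_error.max())        # small but non-zero for these inputs
```

If the error is too large for a given use case, a wider target type such as float32 may be the safer choice.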