mindspore.dataset.transforms.TypeCast

class mindspore.dataset.transforms.TypeCast(data_type)

Cast the input Tensor to the specified data type.

Note

This operation runs on the CPU by default. It also supports heterogeneous acceleration on GPU or Ascend.

Parameters:
  • data_type (Union[mindspore.dtype, numpy.dtype]) - The data type to cast the input to.

Raises:
Supported Platforms:

CPU GPU Ascend

Examples:

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms as transforms
>>> from mindspore import dtype as mstype
>>>
>>> # Use the transform in dataset pipeline mode
>>> # Generate 1-D int numpy arrays holding the values 0 to 63
>>> def generator_1d():
...     for i in range(64):
...         yield (np.array([i]),)
>>>
>>> generator_dataset = ds.GeneratorDataset(generator_1d, column_names='col')
>>> type_cast_op = transforms.TypeCast(mstype.int32)
>>> generator_dataset = generator_dataset.map(operations=type_cast_op)
>>> for item in generator_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col"].shape, item["col"].dtype)
...     break
(1,) int32
>>>
>>> # Use the transform in eager mode
>>> data = np.array([2.71606445312564e-03, 6.3476562564e-03]).astype(np.float64)
>>> output = transforms.TypeCast(np.float16)(data)
>>> print(output.shape, output.dtype)
(2,) float16