mindspore.dataset.transforms.OneHot

class mindspore.dataset.transforms.OneHot(num_classes, smoothing_rate=0.0)

Apply One-Hot encoding to the input labels.

For a 1-D input of shape \((*)\), an output of shape \((*, \text{num\_classes})\) is returned. In it, the element whose index equals the input label value is set to 1, and all other elements are set to 0. If smoothing_rate is nonzero, label smoothing is applied: the targets are softened away from exact 0s and 1s, which can help the trained model generalize.
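To illustrate the encoding described above without smoothing, here is a pure-Python sketch. It is not MindSpore's implementation. The function name one_hot is hypothetical, and the sketch assumes every label is an integer in [0, num_classes).

```python
from typing import List, Sequence


def one_hot(labels: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical illustration of one-hot encoding for a 1-D list of labels.

    The result has shape (len(labels), num_classes). In each row, the position
    equal to the label value is 1 and every other position is 0.
    """
    rows = []
    for label in labels:
        row = [0] * num_classes
        row[label] = 1  # the index equal to the label value becomes 1
        rows.append(row)
    return rows


encoded = one_hot([1, 3, 0], num_classes=4)
print(encoded)  # [[0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]
```

Each input label produces one output row, so a 1-D input of length N becomes an N x num_classes result.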

Parameters
  • num_classes (int) – Total number of classes. Must be greater than the maximum value of the input labels.

  • smoothing_rate (float, optional) – The amount of label smoothing. Must be in the range [0.0, 1.0]. Default: 0.0, which means no label smoothing is applied.
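This page does not state the exact smoothing arithmetic. The sketch below assumes the common label-smoothing formulation: every class receives smoothing_rate / num_classes, and the correct class additionally receives 1 - smoothing_rate. MindSpore's actual computation may differ. The helper smooth_one_hot is hypothetical.

```python
from typing import List


def smooth_one_hot(label: int, num_classes: int, smoothing_rate: float) -> List[float]:
    """Hypothetical smoothed one-hot vector for a single label.

    ASSUMPTION: uses (1 - smoothing_rate) * one_hot + smoothing_rate / num_classes,
    a common label-smoothing formula that may not match MindSpore's arithmetic.
    """
    off_value = smoothing_rate / num_classes          # value for every wrong class
    on_value = 1.0 - smoothing_rate + off_value       # value for the correct class
    row = [off_value] * num_classes
    row[label] = on_value
    return row


row = smooth_one_hot(1, num_classes=4, smoothing_rate=0.1)
print([round(v, 3) for v in row])  # [0.025, 0.925, 0.025, 0.025]
```

Under this assumed formula, smoothing_rate=0.0 reduces to a plain one-hot vector (as floats). For any rate in [0.0, 1.0], each row still sums to 1.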

Raises
  • TypeError – If num_classes is not of type int.

  • TypeError – If smoothing_rate is not of type float.

  • ValueError – If smoothing_rate is not in the range [0.0, 1.0].

  • RuntimeError – If the input label is not of an integer type.

  • RuntimeError – If the dimension of the input label is not 1.
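The checks listed under Raises can be mirrored in a small stand-alone sketch that operates on plain Python values. The helpers check_onehot_args and check_onehot_input are hypothetical. They only approximate the documented conditions and are not MindSpore's validators. The error messages are also invented for illustration.

```python
def check_onehot_args(num_classes, smoothing_rate=0.0):
    """Hypothetical mirror of the documented parameter checks."""
    if not isinstance(num_classes, int):
        raise TypeError("num_classes must be of type int")
    if not isinstance(smoothing_rate, float):
        raise TypeError("smoothing_rate must be of type float")
    if not 0.0 <= smoothing_rate <= 1.0:
        raise ValueError("smoothing_rate must be in the range [0.0, 1.0]")


def check_onehot_input(labels):
    """Hypothetical mirror of the documented input checks on a Python list."""
    # Reject nested sequences: only 1-D input is accepted.
    if not isinstance(labels, (list, tuple)) or any(
        isinstance(x, (list, tuple)) for x in labels
    ):
        raise RuntimeError("input label must be 1-D")
    # Reject non-integer labels such as floats or strings.
    if not all(isinstance(x, int) for x in labels):
        raise RuntimeError("input label must be of an integer type")


check_onehot_args(10, 0.1)      # valid arguments: no exception
check_onehot_input([0, 3, 9])   # valid 1-D integer labels: no exception
```

The sketch omits the requirement that num_classes exceed the largest label. This page does not say which error that violation raises.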

Supported Platforms:

CPU

Examples

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms as transforms
>>>
>>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
>>> mnist_dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)
>>>
>>> # Assume the dataset has 10 classes, so the labels range from 0 to 9
>>> onehot_op = transforms.OneHot(num_classes=10)
>>> mnist_dataset = mnist_dataset.map(operations=onehot_op, input_columns=["label"])