mindspore.nn.Dropout

class mindspore.nn.Dropout(keep_prob=0.5, p=None, dtype=mstype.float32)

Dropout layer for the input.

Dropout is a regularization method. During training, the operator randomly sets some neuron outputs to 0, each with the dropout probability. During inference, this layer returns the input x unchanged.
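The training/inference behavior described above can be sketched in plain Python. This is not MindSpore's implementation: `dropout_sketch` and its `training` flag are hypothetical names, the input is a flat list instead of a Tensor, and rescaling the kept elements by 1/(1 - p) follows the common "inverted dropout" convention, which is an assumption of this sketch.

```python
import random
from typing import List, Optional


def dropout_sketch(x: List[float], p: float, training: bool,
                   rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical element-wise dropout on a flat list (not MindSpore's API).

    p is the probability that an element is zeroed, like nn.Dropout's `p`.
    """
    if not training or p == 0.0:
        # Inference (or p == 0): the input is returned unchanged.
        return list(x)
    rng = rng or random.Random()
    # Assumed inverted-dropout scaling, so the expected value of each element is preserved.
    scale = 1.0 / (1.0 - p)
    # A fresh Bernoulli mask is drawn on every call: each element is kept with probability 1 - p.
    return [v * scale if rng.random() >= p else 0.0 for v in x]


# Usage: with a seeded generator the mask is reproducible; every element here is 0.0 or 2.0.
demo = dropout_sketch([1.0] * 8, p=0.5, training=True, rng=random.Random(0))
```

Passing `training=False` plays the role that switching the cell out of training mode plays for the real layer.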

This technique was proposed in the paper "Dropout: A Simple Way to Prevent Neural Networks from Overfitting" and has been shown to reduce overfitting and to prevent co-adaptation of neurons. See more details in "Improving neural networks by preventing co-adaptation of feature detectors".

Note

  • Each channel will be zeroed out independently on every construct call.

  • Parameter keep_prob will be removed in a future version; please use parameter p instead. Parameter p is the probability that an element of the input tensor is zeroed.

  • Parameter dtype will be removed in a future version. It is not recommended to define this parameter.
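Because keep_prob is the keep rate and p is the drop rate (keep_prob=0.9 drops 10% of the input, p=0.9 drops 90%), code that still passes keep_prob can be migrated with p = 1 - keep_prob. The helper below is a hypothetical sketch of that conversion and is not part of MindSpore; it also enforces the documented range (0, 1] for keep_prob.

```python
def keep_prob_to_p(keep_prob: float) -> float:
    """Hypothetical migration helper: convert a deprecated keep_prob into p.

    keep_prob is the probability of keeping an element and p the probability
    of zeroing it, so p = 1 - keep_prob.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError(f"keep_prob must be in (0, 1], got {keep_prob}")
    # keep_prob in (0, 1] maps onto p in [0, 1), the range documented for p.
    return 1.0 - keep_prob


# Old style: nn.Dropout(keep_prob=0.5)  ->  new style: nn.Dropout(p=keep_prob_to_p(0.5))
new_p = keep_prob_to_p(0.5)
```

Note that floating-point subtraction can leave a tiny residue (1.0 - 0.9 is not exactly 0.1), which does not matter for a dropout rate.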

Parameters
  • keep_prob (float) – Deprecated. The keep rate, greater than 0 and less than or equal to 1. E.g. keep_prob=0.9 drops out 10% of input neurons. Default: 0.5.

  • p (Union[float, int, None]) – The dropout rate, greater than or equal to 0 and less than 1. E.g. p=0.9 drops out 90% of input neurons. Default: None.

  • dtype (mindspore.dtype) – Data type of input. Default: mstype.float32.

Inputs:
  • x (Tensor) - The input of Dropout, with data type float16 or float32.

Outputs:

Tensor, output tensor with the same shape as x.

Raises
  • TypeError – If keep_prob is not a float.

  • TypeError – If the type of p is neither float nor int.

  • TypeError – If the dtype of x is neither float16 nor float32.

  • ValueError – If keep_prob is not in range (0, 1].

  • ValueError – If p is not in range [0, 1).

  • ValueError – If the length of the shape of x is less than 1.
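The conditions listed under Raises can be mirrored in a plain-Python sketch, which is handy for validating configuration before building a network. `check_dropout_args` is a hypothetical name; the input is described only by a dtype string and a shape tuple, and the order and wording of MindSpore's own checks are not reproduced here.

```python
from typing import Optional, Sequence, Union


def check_dropout_args(keep_prob: float = 0.5,
                       p: Optional[Union[float, int]] = None,
                       x_dtype: str = "float32",
                       x_shape: Sequence[int] = (1,)) -> None:
    """Hypothetical mirror of the documented Raises conditions (not MindSpore code)."""
    # keep_prob: must be a float in (0, 1].
    if not isinstance(keep_prob, float):
        raise TypeError("keep_prob must be a float")
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in range (0, 1]")
    if p is not None:
        # bool is a subclass of int in Python; this sketch rejects it explicitly.
        if isinstance(p, bool) or not isinstance(p, (float, int)):
            raise TypeError("p must be a float or int")
        if not 0 <= p < 1:
            raise ValueError("p must be in range [0, 1)")
    # x: dtype must be float16 or float32, and the shape must have length >= 1.
    if x_dtype not in ("float16", "float32"):
        raise TypeError("x must be float16 or float32")
    if len(x_shape) < 1:
        raise ValueError("the length of the shape of x must be at least 1")


check_dropout_args(p=0.2, x_shape=(2, 2, 3))  # valid arguments: no exception
```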

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> net = nn.Dropout(p=0.2)
>>> net.set_train()
>>> output = net(x)
>>> print(output.shape)
(2, 2, 3)