Differences with torch.nn.Dropout

torch.nn.Dropout

torch.nn.Dropout(p=0.5, inplace=False)

For more information, see torch.nn.Dropout.

mindspore.nn.Dropout

mindspore.nn.Dropout(keep_prob=0.5, p=None, dtype=mstype.float32)

For more information, see mindspore.nn.Dropout.

Differences

PyTorch: Dropout is a regularization technique. During training, the operator randomly sets some neuron outputs to 0 according to the dropout probability p, which reduces overfitting by limiting correlation between neurons.

MindSpore: The MindSpore API provides essentially the same functionality as PyTorch. keep_prob is the retention rate of input neurons (the probability of keeping an element, i.e. 1 - p); it is deprecated and will be removed in a future version. dtype sets the data type of the output Tensor and is also deprecated.
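The behavior described above can be sketched in plain Python. This is only an illustration, not either framework's implementation: the function name `dropout_sketch` and its `training` and `rng` arguments are hypothetical. It assumes the usual "inverted dropout" convention, in which surviving elements are scaled by 1 / (1 - p) during training, and for simplicity it restricts p to [0, 1).

```python
import random


def dropout_sketch(values, p=0.5, training=True, rng=None):
    """Illustrative inverted dropout over a flat list of floats.

    Hypothetical helper, not part of PyTorch or MindSpore.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("this sketch requires 0 <= p < 1")
    # Outside training (or with p == 0) the input passes through unchanged.
    if not training or p == 0.0:
        return list(values)
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)  # keeps the expected value of each element unchanged
    # Each element is zeroed with probability p, otherwise scaled up.
    return [0.0 if rng.random() < p else v * scale for v in values]


if __name__ == "__main__":
    row = [float(i) for i in range(1, 11)]
    print(dropout_sketch(row, p=0.2, rng=random.Random(0)))
```

Each output element is either 0 or the input multiplied by 1 / (1 - p), so the output shape always matches the input shape even though the values are random.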

| Categories | Subcategories | PyTorch | MindSpore | Difference |
| ---------- | ------------- | ------- | --------- | ---------- |
| Parameters | Parameter 1 | - | keep_prob | Deprecated in MindSpore |
| | Parameter 2 | p | p | Same parameter name and function |
| | Parameter 3 | inplace | - | MindSpore does not have this parameter |
| | Parameter 4 | - | dtype | Deprecated in MindSpore |
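When migrating code that still passes the deprecated keep_prob, the value has to be converted rather than copied, because keep_prob is a retention rate while p is a drop probability. A minimal sketch of that conversion follows; the helper name `keep_prob_to_p` is hypothetical, and the accepted range (0, 1] for keep_prob is an assumption that should be checked against the MindSpore version in use.

```python
def keep_prob_to_p(keep_prob):
    """Convert a retention rate (keep_prob) into a dropout probability (p).

    Hypothetical migration helper, not part of MindSpore.
    """
    # Assumed valid range for keep_prob: (0, 1].
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob is expected to be in (0, 1]")
    return 1.0 - keep_prob


if __name__ == "__main__":
    # keep_prob=0.8 keeps 80% of elements, i.e. drops roughly 20% (p is about 0.2).
    print(keep_prob_to_p(0.8))
```

Under these assumptions, a call such as mindspore.nn.Dropout(keep_prob=0.8) would be rewritten as mindspore.nn.Dropout(p=keep_prob_to_p(0.8)).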

Code Example

When inplace is False in PyTorch, the two APIs provide the same functionality. Because dropout is random, the examples print only the output shape.

# PyTorch
import torch
from torch import tensor
input = tensor([[1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
                [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
                [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
                [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
                [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00]])
output = torch.nn.Dropout(p=0.2, inplace=False)(input)
print(output.shape)
# torch.Size([5, 10])

# MindSpore
import mindspore
from mindspore import Tensor
x = Tensor([[1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
            [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
            [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
            [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00],
            [1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00]], mindspore.float32)
output = mindspore.nn.Dropout(p=0.2)(x)
print(output.shape)
# (5, 10)