Comparing the Functional Differences Between torch.nn.Dropout and mindspore.nn.Dropout
torch.nn.Dropout
class torch.nn.Dropout(
p=0.5,
inplace=False
)
For more information, see torch.nn.Dropout.
mindspore.nn.Dropout
class mindspore.nn.Dropout(
keep_prob=0.5,
dtype=mstype.float32
)
For more information, see mindspore.nn.Dropout.
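Usage
In PyTorch, the p parameter is the probability that an input element is dropped (set to zero).
In MindSpore, the keep_prob parameter is the probability that an input element is kept, so keep_prob = 1 - p gives the same behavior.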
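The relationship between the two parameters can be sketched in plain Python without either framework. The helper names keep_prob_from_p and dropout_reference below are hypothetical and are not part of PyTorch or MindSpore. The sketch assumes the usual inverted-dropout rule, in which surviving elements are scaled by 1 / keep_prob, and it rejects p = 1 on the assumption that keep_prob must be greater than 0.

```python
import random


def keep_prob_from_p(p):
    """Convert a PyTorch-style drop probability into a MindSpore-style keep probability.

    Hypothetical helper for illustration; assumes keep_prob must lie in (0, 1].
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1) so that keep_prob stays in (0, 1]")
    return 1.0 - p


def dropout_reference(values, p, seed=None):
    """Reference inverted dropout on a flat list (illustrative only).

    Each element is zeroed with probability p; survivors are scaled by 1 / (1 - p).
    """
    rng = random.Random(seed)
    keep_prob = keep_prob_from_p(p)
    scale = 1.0 / keep_prob
    return [v * scale if rng.random() < keep_prob else 0.0 for v in values]


# 25 ones, dropped with p=0.9 (that is, keep_prob=0.1).
out = dropout_reference([1.0] * 25, p=0.9, seed=0)
print(keep_prob_from_p(0.9))
print(out)
```

Every element of the result is either 0 or approximately 10, which is the same pattern as the framework outputs in the code example on this page.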
Code Example
# The following runs Dropout in PyTorch and then the equivalent in MindSpore.
import numpy as np
import torch
import mindspore

# PyTorch: p=0.9 drops each element with probability 0.9.
m = torch.nn.Dropout(p=0.9)
input = torch.tensor(np.ones([5, 5]), dtype=torch.float32)
output = m(input)
print(output)
# Example output (random, varies per run; kept elements are scaled by 1 / (1 - p) = 10):
# [[0 10 0 0 0]
# [0 0 0 0 0]
# [0 0 10 0 0]
# [0 10 0 0 0]
# [0 0 0 0 10]]
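# MindSpore: keep_prob=0.1 keeps each element with probability 0.1 (equivalent to p=0.9).
input = mindspore.Tensor(np.ones([5, 5]), mindspore.float32)
net = mindspore.nn.Dropout(keep_prob=0.1)
net.set_train()  # Dropout is only applied in training mode
output = net(input)
print(output)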
# Example output (random, varies per run; kept elements are scaled by 1 / keep_prob = 10):
# [[0 10 0 0 0]
# [0 0 0 10 0]
# [0 0 0 0 0]
# [0 10 10 0 0]
# [0 0 10 0 0]]