Differences with torch.nn.MultiheadAttention
torch.nn.MultiheadAttention
class torch.nn.MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,
    vdim=None
)(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
For more information, see torch.nn.MultiheadAttention.
mindspore.nn.MultiheadAttention
class mindspore.nn.MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=0.0,
    has_bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,
    vdim=None,
    batch_first=False,
    dtype=mstype.float32
)(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)
Differences
The usage of torch.nn.MultiheadAttention and mindspore.nn.MultiheadAttention is basically the same.
| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | embed_dim | embed_dim | Consistent function |
| | Parameter 2 | num_heads | num_heads | Consistent function |
| | Parameter 3 | dropout | dropout | Consistent function |
| | Parameter 4 | bias | has_bias | Consistent function |
| | Parameter 5 | add_bias_kv | add_bias_kv | Consistent function |
| | Parameter 6 | add_zero_attn | add_zero_attn | Consistent function |
| | Parameter 7 | kdim | kdim | Consistent function |
| | Parameter 8 | vdim | vdim | Consistent function |
| | Parameter 9 | - | batch_first | MindSpore can configure whether the first dimension of the inputs and outputs is the batch dimension; PyTorch does not have this parameter (see the sketch after this table). |
| | Parameter 10 | - | dtype | MindSpore can configure the dtype of the network parameters; PyTorch does not have this parameter. |
| Inputs | Input 1 | query | query | Consistent function |
| | Input 2 | key | key | Consistent function |
| | Input 3 | value | value | Consistent function |
| | Input 4 | key_padding_mask | key_padding_mask | In MindSpore, the dtype can be a float or bool Tensor; in PyTorch, it can be a byte or bool Tensor. |
| | Input 5 | need_weights | need_weights | Consistent function |
| | Input 6 | attn_mask | attn_mask | In MindSpore, the dtype can be a float or bool Tensor; in PyTorch, it can be a float, byte, or bool Tensor. |
| | Input 7 | - | average_attn_weights | If True, the returned attn_output_weights are averaged over the attention heads; if False, attn_output_weights are returned separately for each head. PyTorch does not have this argument (see the sketch after the code examples). |
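As noted in the table, batch_first and dtype exist only on the MindSpore side. The following is a minimal sketch of how they can be passed, reusing the dimensions from the code example below; the batch-first layout and the explicitly passed float32 dtype are illustrative assumptions rather than part of the original example.

# MindSpore (sketch): batch_first and dtype
import mindspore as ms
import numpy as np
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 8
# With batch_first=True, inputs and outputs use the (batch_size, seq_length, embed_dim) layout.
query = ms.Tensor(np.random.randn(batch_size, seq_length, embed_dim), ms.float32)
key = ms.Tensor(np.random.randn(batch_size, seq_length, embed_dim), ms.float32)
value = ms.Tensor(np.random.randn(batch_size, seq_length, embed_dim), ms.float32)
# dtype sets the dtype of the layer's own weights; float32 is the default,
# lower-precision dtypes depend on backend support.
multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dtype=ms.float32)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)
# (8, 10, 128)
print(attn_output_weights.shape)
# (8, 10, 10)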
Code Example
# PyTorch
import torch
from torch import nn
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 8
query = torch.rand(seq_length, batch_size, embed_dim)
key = torch.rand(seq_length, batch_size, embed_dim)
value = torch.rand(seq_length, batch_size, embed_dim)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)
#torch.Size([10, 8, 128])
print(attn_output_weights.shape)
#torch.Size([8, 10, 10])
# MindSpore
import mindspore as ms
import numpy as np
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 8
query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)
#(10, 8, 128)
print(attn_output_weights.shape)
#(8, 10, 10)
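The input-side differences from the table can be sketched as follows: in MindSpore, key_padding_mask and attn_mask may be bool Tensors (True masks a position), and average_attn_weights=False returns per-head weights. The all-False padding mask and the upper-triangular attention mask below are illustrative assumptions.

# MindSpore (sketch): bool masks and per-head attention weights
import mindspore as ms
import numpy as np
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 8
query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
# key_padding_mask: (batch_size, seq_length) bool Tensor; True marks keys to ignore.
key_padding_mask = ms.Tensor(np.zeros((batch_size, seq_length), dtype=bool))
# attn_mask: (seq_length, seq_length) bool Tensor; True blocks attention to that position.
attn_mask = ms.Tensor(np.triu(np.ones((seq_length, seq_length), dtype=bool), k=1))
multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
# average_attn_weights=False returns one weight matrix per attention head.
attn_output, attn_output_weights = multihead_attn(query, key, value,
                                                  key_padding_mask=key_padding_mask,
                                                  attn_mask=attn_mask,
                                                  average_attn_weights=False)
print(attn_output_weights.shape)
# (8, 8, 10, 10)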