mindspore_gl.nn.GATv2Conv
- class mindspore_gl.nn.GATv2Conv(in_feat_size: int, out_size: int, num_attn_head: int, input_drop_out_rate: float = 0.0, attn_drop_out_rate: float = 0.0, leaky_relu_slope: float = 0.2, activation=None, add_norm=False)
Graph Attention Network v2, from the paper "How Attentive Are Graph Attention Networks?", which fixes the static attention problem of the original GAT.
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}\]
where \(\alpha_{i,j}\) represents the attention score between node \(i\) and node \(j\), normalized with a softmax over the neighbors \(j \in \mathcal{N}(i)\):
\[\begin{split}\alpha_{i,j} = \mathrm{softmax}_i (e_{i,j}) \\ e_{i,j} = \vec{a}^T \mathrm{LeakyReLU}\left(W [h_{i} \| h_{j}]\right)\end{split}\]
- Parameters
in_feat_size (int) – Input node feature size.
out_size (int) – Output node feature size.
num_attn_head (int) – Number of attention heads used in GATv2.
input_drop_out_rate (float, optional) – Dropout rate applied to the input features. Default: 0.0.
attn_drop_out_rate (float, optional) – Dropout rate applied to the attention scores. Default: 0.0.
leaky_relu_slope (float, optional) – Negative slope of the LeakyReLU. Default: 0.2.
activation (Cell, optional) – Activation function. Default: None.
add_norm (bool, optional) – Whether the edge information needs normalization. Default: False.
- Inputs:
x (Tensor) - The input node features. The shape is \((N,D_{in})\) where \(N\) is the number of nodes and \(D_{in}\) is the input feature size, which should equal in_feat_size.
g (Graph) - The input graph.
- Outputs:
Tensor, the output feature of shape \((N,D_{out})\) where \(D_{out}\) is equal to \(out\_size * num\_attn\_head\) (the outputs of all heads are concatenated).
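The attention and aggregation formulas, together with the output shape, can be checked against a small pure-Python reference. This is a minimal sketch, not the MindSpore GL implementation. It assumes that node i aggregates over its incoming edges (src_idx to dst_idx, as in the Examples section), uses a separate message matrix `w_msg` next to the attention projection `w_att` (whether the library shares these weights is not stated here), concatenates the heads, and omits dropout, `activation` and `add_norm`. All function and weight names are hypothetical.

```python
import math

def leaky_relu(x, slope=0.2):
    # Elementwise LeakyReLU; `slope` plays the role of leaky_relu_slope.
    return [v if v >= 0 else slope * v for v in x]

def matvec(mat, vec):
    # Row-major matrix (list of rows) times a vector.
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]

def gatv2_head(h, src, dst, w_att, a, w_msg, slope=0.2):
    """One GATv2 head over a COO edge list (hypothetical helper).

    Edge e goes src[e] -> dst[e]; node i = dst[e] attends to j = src[e].
    w_att: F' x 2F, a: length F', w_msg: F' x F (assumed separate matrix).
    Returns (per-node outputs, per-edge attention weights alpha).
    """
    n, num_edges = len(h), len(src)
    # e_ij = a^T LeakyReLU(W [h_i || h_j]); `h[i] + h[j]` concatenates lists.
    scores = []
    for e in range(num_edges):
        i, j = dst[e], src[e]
        z = leaky_relu(matvec(w_att, h[i] + h[j]), slope)
        scores.append(sum(ak * zk for ak, zk in zip(a, z)))
    # alpha_ij = softmax of e_ij over the incoming edges of node i.
    alpha = [0.0] * num_edges
    for i in range(n):
        in_edges = [e for e in range(num_edges) if dst[e] == i]
        if not in_edges:
            continue  # a node without in-neighbours keeps a zero output
        m = max(scores[e] for e in in_edges)  # shift for numerical stability
        exps = [math.exp(scores[e] - m) for e in in_edges]
        total = sum(exps)
        for e, x in zip(in_edges, exps):
            alpha[e] = x / total
    # h_i' = sum over j of alpha_ij * (w_msg h_j)
    out = [[0.0] * len(w_msg) for _ in range(n)]
    for e in range(num_edges):
        i, msg = dst[e], matvec(w_msg, h[src[e]])
        out[i] = [o + alpha[e] * mk for o, mk in zip(out[i], msg)]
    return out, alpha

def gatv2_multihead(h, src, dst, heads, slope=0.2):
    # heads: one (w_att, a, w_msg) tuple per attention head.
    per_head = [gatv2_head(h, src, dst, *p, slope=slope)[0] for p in heads]
    # Concatenate the heads per node: length out_size * num_attn_head.
    return [sum((ho[i] for ho in per_head), []) for i in range(len(h))]

def toy_matrix(rows, cols, seed):
    # Deterministic toy weights, for illustration only.
    return [[((r * cols + c + seed) % 5 - 2) * 0.1 for c in range(cols)]
            for r in range(rows)]

# Same graph and all-ones features as in the Examples section.
src = [0, 1, 1, 2, 2, 3, 3]
dst = [0, 0, 2, 1, 3, 0, 1]
feat = [[1.0] * 4 for _ in range(4)]
in_size, out_size, num_heads = 4, 2, 3
heads = [(toy_matrix(out_size, 2 * in_size, k), [0.5, -0.5],
          toy_matrix(out_size, in_size, k + 7)) for k in range(num_heads)]
res = gatv2_multihead(feat, src, dst, heads)
print(len(res), len(res[0]))  # 4 6
```

With identical node features every edge into a node gets the same score, so the attention is uniform (1/3 on each of node 0's three incoming edges). This makes a quick sanity check of the softmax normalization.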
- Raises
TypeError – If in_feat_size, out_size, or num_attn_head is not an int.
TypeError – If input_drop_out_rate, attn_drop_out_rate, or leaky_relu_slope is not a float.
TypeError – If activation is not a Cell.
ValueError – If input_drop_out_rate or attn_drop_out_rate is not in range [0.0, 1.0).
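Read literally, the Raises list implies argument checks like the following. This is a hypothetical sketch: the `check_gatv2_args` name is made up, the `activation` check is left out because it needs MindSpore's `Cell` type, and it is not the library's own validation code. Under this literal reading, an integer such as `0` for a dropout rate fails the float check, so writing `0.0` is the safer choice.

```python
def check_gatv2_args(in_feat_size, out_size, num_attn_head,
                     input_drop_out_rate=0.0, attn_drop_out_rate=0.0,
                     leaky_relu_slope=0.2):
    """Hypothetical mirror of the documented Raises conditions."""
    ints = {"in_feat_size": in_feat_size, "out_size": out_size,
            "num_attn_head": num_attn_head}
    for name, value in ints.items():
        # TypeError if a size or head count is not an int.
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    floats = {"input_drop_out_rate": input_drop_out_rate,
              "attn_drop_out_rate": attn_drop_out_rate,
              "leaky_relu_slope": leaky_relu_slope}
    for name, value in floats.items():
        # TypeError if a rate or the slope is not a float.
        if not isinstance(value, float):
            raise TypeError(f"{name} must be a float, got {type(value).__name__}")
    for name in ("input_drop_out_rate", "attn_drop_out_rate"):
        # ValueError if a dropout rate lies outside the half-open range [0.0, 1.0).
        if not 0.0 <= floats[name] < 1.0:
            raise ValueError(f"{name} must be in [0.0, 1.0), got {floats[name]}")

check_gatv2_args(4, 2, 3)  # the defaults satisfy every documented condition
try:
    check_gatv2_args(4, 2, 3, attn_drop_out_rate=1.0)
except ValueError as err:
    print(err)  # attn_drop_out_rate must be in [0.0, 1.0), got 1.0
```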
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> from mindspore_gl.nn import GATv2Conv
>>> from mindspore_gl import GraphField
>>> n_nodes = 4
>>> n_edges = 7
>>> feat_size = 4
>>> src_idx = ms.Tensor([0, 1, 1, 2, 2, 3, 3], ms.int32)
>>> dst_idx = ms.Tensor([0, 0, 2, 1, 3, 0, 1], ms.int32)
>>> ones = ms.ops.Ones()
>>> feat = ones((n_nodes, feat_size), ms.float32)
>>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> gatv2conv = GATv2Conv(in_feat_size=4, out_size=2, num_attn_head=3)
>>> res = gatv2conv(feat, *graph_field.get_graph())
>>> print(res.shape)
(4, 6)