mindspore_gl.nn.SAGPooling
- class mindspore_gl.nn.SAGPooling(in_channels: int, GNN=GCNConv2, activation=ms.nn.Tanh, multiplier=1.0)[源代码]
基于self-attention的池化操作。来自 Self-Attention Graph Pooling 和 Understanding Attention and Generalization in Graph Neural Networks 。
\[ \begin{align}\begin{aligned}\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]- 参数:
in_channels (int) - 每个输入样本的大小。
GNN (GNNCell) - 用于计算投影分数的图神经网络层,仅支持GCNConv2。默认值:mindspore_gl.nn.con.GCNConv2。
activation (Cell) - 要使用的非线性。默认值:mindspore.nn.Tanh。
multiplier (Float) - 用于缩放节点功能的标量。默认值:1.0。
- 输入:
x (Tensor) - 要更新的输入节点特征。Shape为 \((N, D)\) 其中 \(N\) 是节点数, \(D\) 是节点的特征大小,当 attn==None 时,D 应等于 Args 中的 in_feat_size 。
attn (Tensor) - 用于计算投影分数的输入节点特征。Shape为 \((N,D_{in})\) 其中 \(N\) 是节点数, \(D_{in}\) 应等于 Args 中的 in_feat_size 。 如果用 x 计算投影分数, attn 可以为None。
node_num (Int) - 以图g中的节点总数。
perm_num (Int) - Topk个节点过滤中k值。
g (BatchedGraph) - 输入图。
- 输出:
x (Tensor) - 更新的节点特征。Shape为 \(2, M, D_{out}\) 其中 \(M\) 等于 Inputs 中的 perm_num 和 \(D_{out}\) 等于 Inputs 中的 D 。
src_perm (Tensor) - 更新的src节点。
dst_perm (Tensor) - 更新的dst节点。
perm (Tensor) - 更新节点索引之前topk节点的节点索引。Shape为 \(M\),其中 \(M\) 等于 Inputs 中的 perm_num 。
perm_score (Tensor) - 更新节点的投影分数。
- 异常:
TypeError - 如果 in_feat_size 或 out_size 不是int。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore_gl.nn import SAGPooling >>> from mindspore_gl import BatchedGraphField >>> node_feat = ms.Tensor([[1, 2, 3, 4], [2, 4, 1, 3], [1, 3, 2, 4], ... [9, 7, 5, 8], [8, 7, 6, 5], [8, 6, 4, 6], [1, 2, 1, 1]], ... ms.float32) >>> n_nodes = 7 >>> n_edges = 8 >>> src_idx = ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6], ms.int32) >>> dst_idx = ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4], ms.int32) >>> ver_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1], ms.int32) >>> edge_subgraph_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 1, 1], ms.int32) >>> graph_mask = ms.Tensor([0, 1], ms.int32) >>> batched_graph_field = BatchedGraphField(src_idx, dst_idx, n_nodes, n_edges, ver_subgraph_idx, ... edge_subgraph_idx, graph_mask) >>> net = SAGPooling(4) >>> feature, src, dst, ver_subgraph, edge_subgraph, perm, perm_score = net(node_feat, None, 2, ... *batched_graph_field.get_batched_graph()) >>> print(feature.shape) (2, 2, 4)