mindspore_gl.nn.GMMConv
- class mindspore_gl.nn.GMMConv(in_feat_size: int, out_feat_size: int, coord_dim: int, n_kernels: int, residual=False, bias=False, aggregator_type='sum')[源代码]
高斯混合模型卷积层。 来自论文 Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs 。
\[\begin{split}u_{ij} = f(x_i, x_j), x_j \in \mathcal{N}(i) \\ w_k(u) = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right) \\ h_i^{l+1} = \mathrm{aggregate}\left(\left\{\frac{1}{K} \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right)\end{split}\]其中 \(u\) 表示顶点与它其中一个邻居之间的伪坐标,使用 函数 \(f\) ,其中 \(\Sigma_k^{-1}\) 和 \(\mu_k\) 是协方差的可学习参数矩阵和高斯核的均值向量。
- 参数:
in_feat_size (int) - 输入节点特征大小。
out_feat_size (int) - 输出节点特征大小。
coord_dim (int) - 伪坐标的维度。
n_kernels (int) - 内核数。
residual (bool, 可选) - 是否使用残差。默认值:False。
bias (bool, 可选) - 是否使用偏置。默认值:False。
aggregator_type (str, 可选) - 聚合器的类型。默认值:’sum’。
- 输入:
x (Tensor) - 输入节点特征。Shape为 \((N, D_{in})\) ,其中 \(N\) 是节点数, \(D_{in}\) 应等于参数中的 in_feat_size 。
pseudo (Tensor) - 伪坐标张量。
g (Graph) - 输入图。
- 输出:
Tensor,Shape为 \((N, D_{out})\) 应等于参数中的 out_size。
- 异常:
SyntaxError - 当 aggregation_type 不等于sum时。
TypeError - 如果 in_feat_size 或 out_feat_size 或 coord_dim 或 n_kernels 不是int。
TypeError - 如果 bias 或 resual 不是bool。
- 支持平台:
Ascend
GPU
样例:
>>> import mindspore as ms >>> from mindspore_gl.nn import GMMConv >>> from mindspore_gl import GraphField >>> n_nodes = 4 >>> n_edges = 7 >>> node_feat_size = 7 >>> src_idx = ms.Tensor([0, 1, 1, 2, 2, 3, 3], ms.int32) >>> dst_idx = ms.Tensor([0, 0, 2, 1, 3, 0, 1], ms.int32) >>> ones = ms.ops.Ones() >>> node_feat = ones((n_nodes, node_feat_size), ms.float32) >>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges) >>> meanconv = GMMConv(in_feat_size=node_feat_size, out_feat_size=2, coord_dim=3, n_kernels=2) >>> pseudo = ones((7, 3), ms.float32) >>> res = meanconv(node_feat, pseudo, *graph_field.get_graph()) >>> print(res.shape) (4, 2)