mindspore_gl.nn.GMMConv
- class mindspore_gl.nn.GMMConv(in_feat_size: int, out_feat_size: int, coord_dim: int, n_kernels: int, residual=False, bias=False, aggregator_type='sum')
Gaussian mixture model convolutional layer, from the paper Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs.
\[\begin{split}u_{ij} = f(x_i, x_j), x_j \in \mathcal{N}(i) \\ w_k(u) = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right) \\ h_i^{l+1} = \mathrm{aggregate}\left(\left\{\frac{1}{K} \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right)\end{split}\]
where \(u\) is the pseudo-coordinate between a vertex and one of its neighbors, computed by the function \(f\); \(\Sigma_k^{-1}\) and \(\mu_k\) are the learnable inverse covariance matrix and mean vector of the \(k\)-th Gaussian kernel; and \(K\) is the number of kernels (n_kernels).
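To make the formula concrete, the following minimal pure-Python sketch evaluates it literally for a single node with the 'sum' aggregator. It assumes each \(\Sigma_k^{-1}\) is diagonal (stored as a list of per-dimension entries) and that edges are stored as directed (src, dst) pairs, so the neighbors j of node i are the sources of edges ending at i. The names gaussian_weight and aggregate_node are hypothetical and not part of MindSpore GL. Because the displayed equation does not show how neighbor features enter, this sketch computes only the kernel-weight term, which is one scalar per node.

```python
import math
from typing import List, Sequence, Tuple


def gaussian_weight(u: Sequence[float], mu: Sequence[float],
                    inv_cov_diag: Sequence[float]) -> float:
    """w_k(u) = exp(-1/2 (u - mu)^T Sigma^{-1} (u - mu)), Sigma^{-1} assumed diagonal."""
    # With a diagonal Sigma^{-1}, the quadratic form is a weighted sum of squares.
    q = sum(s * (ui - mi) ** 2 for ui, mi, s in zip(u, mu, inv_cov_diag))
    return math.exp(-0.5 * q)


def aggregate_node(i: int, edges: List[Tuple[int, int]],
                   pseudo: List[Sequence[float]],
                   mus: List[Sequence[float]],
                   inv_covs: List[Sequence[float]]) -> float:
    """Sum over neighbors j of (1/K) * sum_k w_k(u_ij), i.e. the 'sum' aggregator."""
    n_kernels = len(mus)
    total = 0.0
    for e, (src, dst) in enumerate(edges):
        if dst != i:  # only edges j -> i contribute to node i
            continue
        u_ij = pseudo[e]  # pseudo-coordinate attached to edge e
        kernel_sum = sum(gaussian_weight(u_ij, mus[k], inv_covs[k])
                         for k in range(n_kernels))
        total += kernel_sum / n_kernels  # the 1/K average over kernels
    return total


if __name__ == "__main__":
    # Same topology as the Examples section: node 0 has 3 incoming edges.
    edges = list(zip([0, 1, 1, 2, 2, 3, 3], [0, 0, 2, 1, 3, 0, 1]))
    pseudo = [[1.0, 1.0, 1.0]] * len(edges)
    mus = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
    inv_covs = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    print(aggregate_node(0, edges, pseudo, mus, inv_covs))
```

A kernel whose mean coincides with the pseudo-coordinate contributes a weight of 1, and the weight decays as the pseudo-coordinate moves away from \(\mu_k\) at a rate set by \(\Sigma_k^{-1}\).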
- Parameters
in_feat_size (int) – Input node feature size.
out_feat_size (int) – Output node feature size.
coord_dim (int) – Dimension of pseudo-coordinates.
n_kernels (int) – Number of kernels.
residual (bool, optional) – Whether to use a residual connection. Default: False.
bias (bool, optional) – Whether to use a bias. Default: False.
aggregator_type (str, optional) – Type of aggregator; only 'sum' is supported. Default: 'sum'.
- Inputs:
x (Tensor) - The input node features. The shape is \((N, D_{in})\) where \(N\) is the number of nodes, and \(D_{in}\) should be equal to in_feat_size in Args.
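pseudo (Tensor) - The pseudo-coordinate tensor, one row per edge. The shape is \((E, D_{coord})\), where \(E\) is the number of edges and \(D_{coord}\) should be equal to coord_dim in Args.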
g (Graph) - The input graph.
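The layer takes the pseudo-coordinates as an input of shape (E, coord_dim), so their construction is up to the caller. One common choice for graphs without geometry, used for citation graphs in the MoNet paper cited above, is degree-based: \(u_{ij} = (1/\sqrt{\deg(i)}, 1/\sqrt{\deg(j)})\), which gives coord_dim=2. The pure-Python sketch below computes this for a directed edge list. Using in-degree and clamping zero degrees to 1 are assumptions of this sketch, and degree_pseudo_coords is a hypothetical helper, not a MindSpore GL function.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple


def degree_pseudo_coords(src: Sequence[int],
                         dst: Sequence[int]) -> List[Tuple[float, float]]:
    """Return one (1/sqrt(deg(i)), 1/sqrt(deg(j))) pair per edge j -> i."""
    # In-degree of each node = number of edges that end at it.
    in_deg = Counter(dst)
    coords = []
    for j, i in zip(src, dst):
        # Clamp to 1 so nodes with no incoming edges do not divide by zero.
        deg_i = max(in_deg[i], 1)
        deg_j = max(in_deg[j], 1)
        coords.append((1.0 / math.sqrt(deg_i), 1.0 / math.sqrt(deg_j)))
    return coords


if __name__ == "__main__":
    src = [0, 1, 1, 2, 2, 3, 3]
    dst = [0, 0, 2, 1, 3, 0, 1]
    for (j, i), u in zip(zip(src, dst), degree_pseudo_coords(src, dst)):
        print(f"edge {j}->{i}: u = ({u[0]:.3f}, {u[1]:.3f})")
```

To use such coordinates with GMMConv, the list would be converted to a float32 tensor of shape (E, 2) and the layer constructed with coord_dim=2.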
- Outputs:
Tensor, output node features with shape \((N, D_{out})\), where \(D_{out}\) should be the same as out_feat_size in Args.
- Raises
SyntaxError – If aggregator_type is not 'sum'.
TypeError – If in_feat_size, out_feat_size, coord_dim or n_kernels is not an int.
TypeError – If bias or residual is not a bool.
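The documented argument checks can be summarized by the following pure-Python sketch. check_gmmconv_args is a hypothetical helper that mirrors the contract listed above; it is not MindSpore GL's actual validation code, and its check order and error messages are illustrative assumptions.

```python
from typing import Any


def check_gmmconv_args(in_feat_size: Any, out_feat_size: Any, coord_dim: Any,
                       n_kernels: Any, residual: Any = False, bias: Any = False,
                       aggregator_type: Any = "sum") -> None:
    """Raise the exception types documented for GMMConv on invalid arguments."""
    int_args = {"in_feat_size": in_feat_size, "out_feat_size": out_feat_size,
                "coord_dim": coord_dim, "n_kernels": n_kernels}
    for name, value in int_args.items():
        # Note: bool is a subclass of int in Python, so True/False pass this check.
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    for name, value in {"residual": residual, "bias": bias}.items():
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool, got {type(value).__name__}")
    if aggregator_type != "sum":
        # The documentation lists SyntaxError (not ValueError) for this case.
        raise SyntaxError(f"aggregator_type must be 'sum', got {aggregator_type!r}")


if __name__ == "__main__":
    check_gmmconv_args(7, 2, 3, 2)  # valid arguments: returns None
    try:
        check_gmmconv_args(7, 2, 3, 2, aggregator_type="mean")
    except SyntaxError as err:
        print("SyntaxError:", err)
```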
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> from mindspore_gl.nn import GMMConv
>>> from mindspore_gl import GraphField
>>> n_nodes = 4
>>> n_edges = 7
>>> node_feat_size = 7
>>> src_idx = ms.Tensor([0, 1, 1, 2, 2, 3, 3], ms.int32)
>>> dst_idx = ms.Tensor([0, 0, 2, 1, 3, 0, 1], ms.int32)
>>> ones = ms.ops.Ones()
>>> node_feat = ones((n_nodes, node_feat_size), ms.float32)
>>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> gmmconv = GMMConv(in_feat_size=node_feat_size, out_feat_size=2, coord_dim=3, n_kernels=2)
>>> pseudo = ones((n_edges, 3), ms.float32)
>>> res = gmmconv(node_feat, pseudo, *graph_field.get_graph())
>>> print(res.shape)
(4, 2)