# Source code for mindspore_gl.nn.conv.edgeconv

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EDGEConv Layer"""
import mindspore as ms
from mindspore._checkparam import Validator
from mindspore_gl import Graph
from .. import GNNCell


class EDGEConv(GNNCell):
    r"""
    EdgeConv layer. From the paper `Dynamic Graph CNN for Learning on Point Clouds
    <https://arxiv.org/pdf/1801.07829>`_ .

    .. math::
        h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} (
        \Theta \cdot (h_j^{(l)} - h_i^{(l)}) + \Phi \cdot h_i^{(l)})

    :math:`\mathcal{N}(i)` represents the neighbour nodes of :math:`i`.
    :math:`\Theta` and :math:`\Phi` represent linear layers.

    Args:
        in_feat_size (int): Input node feature size.
        out_feat_size (int): Output node feature size.
        batch_norm (bool): Whether to use batch norm.
        bias (bool): Whether to use bias. Default: True.

    Inputs:
        - **x** (Tensor): The input node features. The shape is :math:`(N,*)`
          where :math:`N` is the number of nodes, and :math:`*` could be of
          any shape.
        - **g** (Graph): The input graph.

    Outputs:
        - Tensor, output node features. The shape is
          :math:`(N, out\_feat\_size)`.

    Raises:
        TypeError: If 'in_feat_size' is not a positive int.
        TypeError: If 'out_feat_size' is not a positive int.
        TypeError: If 'batch_norm' is not a bool.
        TypeError: If 'bias' is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore_gl.nn import EDGEConv
        >>> from mindspore_gl import GraphField
        >>> n_nodes = 4
        >>> n_edges = 8
        >>> feat_size = 16
        >>> src_idx = ms.Tensor([0, 0, 0, 1, 1, 1, 2, 3], ms.int32)
        >>> dst_idx = ms.Tensor([0, 1, 3, 1, 2, 3, 3, 2], ms.int32)
        >>> ones = ms.ops.Ones()
        >>> nodes_feat = ones((n_nodes, feat_size), ms.float32)
        >>> graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
        >>> out_size = 4
        >>> conv = EDGEConv(feat_size, out_size, batch_norm=True)
        >>> ret = conv(nodes_feat, *graph_field.get_graph())
        >>> print(ret.shape)
        (4, 4)
    """

    # pylint: disable=arguments-differ
    def __init__(self,
                 in_feat_size: int,
                 out_feat_size: int,
                 batch_norm: bool,
                 bias: bool = True):
        """
        Validate the arguments and build the sub-layers.

        The sizes are checked with ``Validator.check_positive_int`` and the
        flags with ``Validator.check_bool``. Two ``Dense`` layers (``theta``
        and ``phi``) mapping ``in_feat_size`` to ``out_feat_size`` are always
        created; the ``BatchNorm1d`` layer ``bn`` is only created when
        ``batch_norm`` is True.
        """
        super().__init__()
        in_feat_size = Validator.check_positive_int(in_feat_size, "in_feat_size", self.cls_name)
        out_feat_size = Validator.check_positive_int(out_feat_size, "out_feat_size", self.cls_name)
        batch_norm = Validator.check_bool(batch_norm, 'batch_norm', self.cls_name)
        bias = Validator.check_bool(bias, 'bias', self.cls_name)
        self.batch_norm = batch_norm
        self.theta = ms.nn.Dense(in_feat_size, out_feat_size, has_bias=bias)
        self.phi = ms.nn.Dense(in_feat_size, out_feat_size, has_bias=bias)
        # `bn` exists only when requested; `construct` reads it solely on the
        # batch-norm branch, guarded by the same `self.batch_norm` flag.
        if batch_norm:
            self.bn = ms.nn.BatchNorm1d(out_feat_size)

    def construct(self, x, g: Graph):
        """
        Construct function for EDGEConv.

        Stores ``x`` and ``self.phi(x)`` as the vertex attributes ``"x"`` and
        ``"phi"``. Then, for every destination vertex ``v``, evaluates
        ``self.theta(u.x - v.x) + v.phi`` for each in-neighbour ``u`` (passed
        through ``self.bn`` first when ``self.batch_norm`` is set) and reduces
        the results with ``g.max`` into ``v.h``. Returns ``v.h`` collected
        over the destination vertices.

        NOTE(review): this body appears to be written in the mindspore_gl
        vertex-centric style, which is presumably rewritten by the framework
        before execution - keep the statement shapes as they are unless that
        is confirmed not to matter.
        """
        g.set_vertex_attr({"x": x, "phi": self.phi(x)})
        for v in g.dst_vertex:
            if not self.batch_norm:
                v.h = g.max([self.theta(u.x - v.x) + v.phi for u in v.innbs])
            else:
                v.h = g.max([self.bn(self.theta(u.x - v.x) + v.phi) for u in v.innbs])
        return [v.h for v in g.dst_vertex]