mindspore_gl.HeterGraph
- class mindspore_gl.HeterGraph
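The heterogeneous graph.
Use this class to annotate the last argument of the ‘construct’ function of a GNNCell subclass. That argument is resolved into a ‘mindspore_gl.HeterGraph’ heterogeneous graph object. When calling the cell, pass the unpacked output of HeterGraphField.get_heter_graph() in the position of that argument, as in the examples below.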
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class SrcIdx(GNNCell):
...     def construct(self, bg: HeterGraph):
...         return bg.src_idx
>>> ret = SrcIdx()(*heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8]), Tensor(shape=[1], dtype=Int32, value= [0])]
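To make the per-edge-type layout concrete, the sketch below models the four fields with plain Python lists and no MindSpore dependency. HeterCOO and edges_of_type are hypothetical names introduced only for illustration; they are not part of mindspore_gl. The sketch assumes, consistent with the example above, that entry t of src_idx and dst_idx holds the COO indices of edge type t, so edge k of type t runs from src_idx[t][k] to dst_idx[t][k].

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class HeterCOO:
    """Hypothetical plain-Python stand-in for the fields of a HeterGraph."""
    src_idx: List[List[int]]  # src_idx[t][k]: source node of edge k of type t
    dst_idx: List[List[int]]  # dst_idx[t][k]: destination node of edge k of type t
    n_nodes: List[int]
    n_edges: List[int]


def edges_of_type(g: HeterCOO, etype: int) -> List[Tuple[int, int]]:
    """Return the (src, dst) pairs of one edge type by zipping its two COO rows."""
    return list(zip(g.src_idx[etype], g.dst_idx[etype]))


# Same data as the example above, with Tensors replaced by Python lists.
example = HeterCOO(
    src_idx=[[0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], [0]],
    dst_idx=[[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]],
    n_nodes=[9, 2],
    n_edges=[11, 1],
)

print(edges_of_type(example, 1))  # [(0, 1)]
```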
- property dst_idx
A list of tensors, one per edge type. Each tensor has shape \((N\_EDGES)\) and holds the destination node indices of the COO edge matrix of that edge type.
- Returns
List[Tensor], the destination vertex indices of each edge type.
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class DstIdx(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.dst_idx
>>> ret = DstIdx()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8]), Tensor(shape=[1], dtype=Int32, value= [1])]
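Because each entry of dst_idx lists the head of every edge of one type, counting how often a node appears gives its in-degree for that type. The library-free sketch below shows this on the example data; in_degrees is a hypothetical helper, and the node count of 9 for edge type 0 is taken from the example's n_nodes rather than from the dst_idx property itself.

```python
from collections import Counter
from typing import List


def in_degrees(dst: List[int], num_nodes: int) -> List[int]:
    """Count incoming edges per node from one COO destination-index list."""
    counts = Counter(dst)
    # Nodes that never appear as a destination have in-degree 0.
    return [counts.get(v, 0) for v in range(num_nodes)]


dst_idx = [[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]]

print(in_degrees(dst_idx[0], 9))  # [1, 2, 0, 1, 2, 1, 1, 0, 3]
```

The degrees of one edge type always sum to the number of edges of that type (11 here).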
- get_homo_graph(etype)
Get the nodes and edges of edge type etype as a homogeneous graph.
- Parameters
etype (int) – The edge type.
- Returns
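List[Tensor], the homogeneous graph of edge type etype. It can be unpacked as the graph argument of a GNNCell whose ‘construct’ takes a Graph, as in the example below.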
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class TestSum(GNNCell):
...     def construct(self, x, g: Graph):
...         g.set_vertex_attr({"x": x})
...         for v in g.dst_vertex:
...             v.h = g.sum([u.x for u in v.innbs])
...         return [v.h for v in g.dst_vertex]
...
>>> class TestHeterGraph(GNNCell):
...     def __init__(self):
...         super().__init__()
...         self.sum = TestSum()
...
...     def construct(self, x, hg: HeterGraph):
...         return self.sum(x, *hg.get_homo_graph(0))
...
>>> ret = TestHeterGraph()(node_feat, *heter_graph_field.get_heter_graph()).asnumpy().tolist()
>>> print(ret)
[[1.0], [2.0], [0.0], [0.0], [3.0], [2.0], [1.0], [0.0], [3.0]]
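The TestSum cell above sums, for every destination vertex, the features of its in-neighbours over the edges of type 0. The plain-Python sketch below reproduces that printed result without MindSpore, which may help when checking what get_homo_graph(0) selects. neighbor_sum is a hypothetical helper, and num_nodes=9 is passed explicitly as an assumption taken from the example rather than read from the graph.

```python
from typing import List


def neighbor_sum(src: List[int], dst: List[int],
                 x: List[List[float]], num_nodes: int) -> List[List[float]]:
    """For every node v, sum x[u] over all edges u -> v of one edge type.

    Nodes without incoming edges keep a zero vector, matching the zeros
    in the GNNCell output above.
    """
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        for d in range(dim):
            out[v][d] += x[u][d]
    return out


src_idx = [[0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], [0]]
dst_idx = [[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]]
node_feat = [[1.0], [2.0], [1.0], [2.0], [0.0], [1.0], [2.0], [3.0], [1.0]]

etype = 0  # same edge type as get_homo_graph(0)
print(neighbor_sum(src_idx[etype], dst_idx[etype], node_feat, num_nodes=9))
# [[1.0], [2.0], [0.0], [0.0], [3.0], [2.0], [1.0], [0.0], [3.0]]
```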
- property n_edges
A list of integers giving the number of edges of each edge type.
- Returns
List[int], the number of edges of each edge type.
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class NEdges(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.n_edges
>>> ret = NEdges()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[11, 1]
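In the example, n_edges[t] equals the length of src_idx[t] and dst_idx[t]. A quick consistency check of that relationship, useful before building a HeterGraphField from hand-made data, might look like the sketch below. check_edge_counts is a hypothetical helper, not a mindspore_gl function.

```python
from typing import List, Sequence


def check_edge_counts(
    src_idx: Sequence[Sequence[int]],
    dst_idx: Sequence[Sequence[int]],
    n_edges: List[int],
) -> bool:
    """Return True if every edge type t has exactly n_edges[t] source and destination indices."""
    if not (len(src_idx) == len(dst_idx) == len(n_edges)):
        return False  # expect one index list and one count per edge type
    return all(len(s) == len(d) == n for s, d, n in zip(src_idx, dst_idx, n_edges))


src_idx = [[0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], [0]]
dst_idx = [[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]]
n_edges = [11, 1]

print(check_edge_counts(src_idx, dst_idx, n_edges))  # True
print(sum(n_edges))  # 12 edges across all edge types
```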
- property n_nodes
A list of integers representing the node counts of the graph.
- Returns
List[int], the node counts of the graph.
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class NNodes(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.n_nodes
>>> ret = NNodes()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[9, 2]
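Node counts are mainly useful for bounds-checking the index lists. The sketch below assumes that entry t of n_nodes bounds the node indices of edge type t. That pairing is our assumption, inferred from the example data (indices of type 0 stay below 9, those of type 1 below 2); this reference does not state how the entries of n_nodes map to types. indices_in_range is a hypothetical helper.

```python
from typing import List, Sequence


def indices_in_range(
    src_idx: Sequence[Sequence[int]],
    dst_idx: Sequence[Sequence[int]],
    n_nodes: List[int],
) -> bool:
    """Check that every node index of edge type t lies in range(n_nodes[t]).

    ASSUMPTION: entry t of n_nodes bounds the indices of edge type t. This is
    inferred from the example data, not taken from the mindspore_gl API.
    """
    return all(
        0 <= i < n
        for s, d, n in zip(src_idx, dst_idx, n_nodes)
        for i in list(s) + list(d)
    )


src_idx = [[0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], [0]]
dst_idx = [[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]]

print(indices_in_range(src_idx, dst_idx, [9, 2]))  # True
print(indices_in_range(src_idx, dst_idx, [8, 2]))  # False: index 8 needs at least 9 nodes
```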
- property src_idx
A list of tensors, one per edge type. Each tensor has shape \((N\_EDGES)\) and holds the source node indices of the COO edge matrix of that edge type.
- Returns
List[Tensor], the source vertex indices of each edge type.
Examples
>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class SrcIdx(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.src_idx
>>> ret = SrcIdx()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8]), Tensor(shape=[1], dtype=Int32, value= [0])]
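Read together with dst_idx, the source indices of one edge type define that type's out-neighbour lists. The library-free sketch below groups destinations by source for the example data; out_neighbors is a hypothetical helper, and duplicate edges (such as the three 8 -> 8 self-loops) are kept rather than merged.

```python
from collections import defaultdict
from typing import Dict, List


def out_neighbors(src: List[int], dst: List[int]) -> Dict[int, List[int]]:
    """Group the destinations of one edge type by their source node, in edge order."""
    adj: Dict[int, List[int]] = defaultdict(list)
    for u, v in zip(src, dst):
        adj[u].append(v)
    return dict(adj)  # nodes with no outgoing edges are simply absent


src_idx = [[0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], [0]]
dst_idx = [[1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], [1]]

print(out_neighbors(src_idx[0], dst_idx[0]))
# {0: [1], 2: [0, 1], 3: [5], 4: [3], 5: [4, 6], 6: [4], 8: [8, 8, 8]}
```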