mindspore_gl.HeterGraph

View Source On Gitee
class mindspore_gl.HeterGraph[source]

The heterogeneous graph.

This class is used as the type annotation of the graph argument in the ‘construct’ function of a GNNCell subclass. The last argument of the ‘construct’ function is resolved into a ‘mindspore_gl.HeterGraph’ heterogeneous graph object.

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
>>> class SrcIdx(GNNCell):
...     def construct(self, bg: HeterGraph):
...         return bg.src_idx
>>> ret = SrcIdx()(*heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8]),
Tensor(shape=[1], dtype=Int32, value= [0])]
property dst_idx

A list of tensors, one per edge type, each of shape \((N\_EDGES)\) for its edge type, representing the destination node indices of the COO edge matrix.

Returns

  • List[Tensor], a list of destination vertex index tensors, one per edge type.

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
...
>>> class DstIdx(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.dst_idx
>>> ret = DstIdx()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8]),
Tensor(shape=[1], dtype=Int32, value= [1])]
get_homo_graph(etype)[source]

Get the nodes and edges of edge type `etype` as a homogeneous graph.

Parameters

etype (int) – The edge type.

Returns

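List[Tensor], the homogeneous graph for edge type `etype`. It can be unpacked as the `Graph` argument of another GNNCell, as shown in the example below.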

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
...
>>> class TestSum(GNNCell):
...     def construct(self, x, g: Graph):
...         g.set_vertex_attr({"x": x})
...         for v in g.dst_vertex:
...             v.h = g.sum([u.x for u in v.innbs])
...         return [v.h for v in g.dst_vertex]
...
>>> class TestHeterGraph(GNNCell):
...     def __init__(self):
...         super().__init__()
...         self.sum = TestSum()
...
...     def construct(self, x, hg: HeterGraph):
...         return self.sum(x, *hg.get_homo_graph(0))
...
>>> ret = TestHeterGraph()(node_feat, *heter_graph_field.get_heter_graph()).asnumpy().tolist()
>>> print(ret)
[[1.0], [2.0], [0.0], [0.0], [3.0], [2.0], [1.0], [0.0], [3.0]]
property n_edges

A list of integers representing the number of edges of the graph, one entry per edge type.

Returns

  • List[int], a list of edge counts of the graph.

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
...
>>> class NEdges(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.n_edges
>>> ret = NEdges()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[11, 1]
property n_nodes

A list of integers representing the number of nodes of the graph, one entry per node type.

Returns

  • List[int], a list of node counts of the graph.

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
...
>>> class NNodes(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.n_nodes
>>> ret = NNodes()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[9, 2]
property src_idx

A list of tensors, one per edge type, each of shape \((N\_EDGES)\) for its edge type, representing the source node indices of the COO edge matrix.

Returns

  • List[Tensor], a list of source vertex index tensors, one per edge type.

Examples

>>> import mindspore as ms
>>> from mindspore_gl import Graph, HeterGraph, HeterGraphField
>>> from mindspore_gl.nn import GNNCell
>>> n_nodes = [9, 2]
>>> n_edges = [11, 1]
>>> src_idx = [ms.Tensor([0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8], ms.int32), ms.Tensor([0], ms.int32)]
>>> dst_idx = [ms.Tensor([1, 0, 1, 5, 3, 4, 6, 4, 8, 8, 8], ms.int32), ms.Tensor([1], ms.int32)]
>>> heter_graph_field = HeterGraphField(src_idx, dst_idx, n_nodes, n_edges)
>>> node_feat = ms.Tensor([[1], [2], [1], [2], [0], [1], [2], [3], [1]], ms.float32)
...
>>> class SrcIdx(GNNCell):
...     def construct(self, x, bg: HeterGraph):
...         return bg.src_idx
>>> ret = SrcIdx()(node_feat, *heter_graph_field.get_heter_graph())
>>> print(ret)
[Tensor(shape=[11], dtype=Int32, value= [0, 2, 2, 3, 4, 5, 5, 6, 8, 8, 8]),
Tensor(shape=[1], dtype=Int32, value= [0])]