# Source code for mindspore.dataset.engine.graphdata

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data.
"""
import numpy as np
from mindspore._c_dataengine import Graph
from mindspore._c_dataengine import Tensor

from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \
    check_gnn_get_node_feature


class GraphData:
    """
    Reads the graph dataset used for GNN training from the shared file and database.

    Args:
        dataset_file (str): One of file names in dataset.
        num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
            (default=None, in which case 1 worker is used).
    """

    @check_gnn_graphdata
    def __init__(self, dataset_file, num_parallel_workers=None):
        self._dataset_file = dataset_file
        # The underlying Graph always receives a concrete worker count, so
        # an unspecified value falls back to a single worker.
        if num_parallel_workers is None:
            num_parallel_workers = 1
        self._graph = Graph(dataset_file, num_parallel_workers)

    @check_gnn_get_all_nodes
    def get_all_nodes(self, node_type):
        """
        Get all nodes in the graph.

        Args:
            node_type (int): Specify the type of node.

        Returns:
            numpy.ndarray: array of nodes.

        Examples:
            >>> import mindspore.dataset as ds
            >>> data_graph = ds.GraphData('dataset_file', 2)
            >>> nodes = data_graph.get_all_nodes(0)

        Raises:
            TypeError: If `node_type` is not integer.
        """
        # NOTE(review): the second argument (-1) presumably means "no limit on
        # the number of nodes returned" -- confirm against the Graph binding.
        return self._graph.get_nodes(node_type, -1).as_array()

    @check_gnn_get_all_neighbors
    def get_all_neighbors(self, node_list, neighbor_type):
        """
        Get `neighbor_type` neighbors of the nodes in `node_list`.

        Args:
            node_list (list or numpy.ndarray): The given list of nodes.
            neighbor_type (int): Specify the type of neighbor.

        Returns:
            numpy.ndarray: array of nodes.

        Examples:
            >>> import mindspore.dataset as ds
            >>> data_graph = ds.GraphData('dataset_file', 2)
            >>> nodes = data_graph.get_all_nodes(0)
            >>> neighbors = data_graph.get_all_neighbors(nodes, 0)

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `neighbor_type` is not integer.
        """
        return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()

    @check_gnn_get_node_feature
    def get_node_feature(self, node_list, feature_types):
        """
        Get `feature_types` feature of the nodes in `node_list`.

        Args:
            node_list (list or numpy.ndarray): The given list of nodes.
            feature_types (list or ndarray): The given list of feature types.

        Returns:
            list[numpy.ndarray]: one array for each tensor returned by the underlying
            graph (presumably one per requested feature type -- confirm against the
            Graph binding).

        Examples:
            >>> import mindspore.dataset as ds
            >>> data_graph = ds.GraphData('dataset_file', 2)
            >>> nodes = data_graph.get_all_nodes(0)
            >>> features = data_graph.get_node_feature(nodes, [1])

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `feature_types` is not list or ndarray.
        """
        # A plain Python list is converted to an int32 array before being
        # wrapped in a Tensor; an ndarray argument is passed through unchanged.
        if isinstance(node_list, list):
            node_list = np.array(node_list, dtype=np.int32)
        return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]