Spatio-Temporal Graph Training Network


Overview

This example shows how to forecast traffic with Spatio-Temporal Graph Convolutional Networks.

Spatio-Temporal Graph Convolutional Networks (STGCN) address the time-series prediction problem in the traffic domain. Experiments show that STGCN effectively captures comprehensive spatio-temporal correlations by modeling multi-scale traffic networks.

METR-LA is a large-scale dataset collected from 1,500 traffic loop detectors in the Los Angeles County road network. It includes speed, road capacity, and occupancy data and covers approximately 3,420 miles. The road network is modeled as a graph and fed into the STGCN network, which predicts the traffic state for the next time steps from historical data.

The node features of a general graph have shape (number of nodes, feature dimension), whereas the features of a spatio-temporal graph are at least 3-dimensional, for example (number of nodes, feature dimension, time steps), so aggregating features from neighboring nodes is more complicated. In addition, because convolution is also applied along the time dimension, the number of time steps changes as features pass through the network. The output time length must therefore be calculated in advance before the loss can be computed.
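To make the change in time steps concrete, the following NumPy sketch applies a "valid" (unpadded) 1D convolution along the time axis, using the 4-dimensional (batch, time steps, nodes, features) layout that STConv expects. It is an illustration only: the helper name temporal_conv_valid is hypothetical, it computes a cross-correlation as deep learning frameworks do, and it omits the gating that STGCN's temporal layers add on top.

```python
import numpy as np


def temporal_conv_valid(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Hypothetical 'valid' 1D convolution along the time axis (axis 1).

    x:      (batch, time_steps, num_nodes, in_dim)
    kernel: (kernel_size, in_dim, out_dim)
    returns (batch, time_steps - kernel_size + 1, num_nodes, out_dim)
    """
    kernel_size = kernel.shape[0]
    out_steps = x.shape[1] - kernel_size + 1
    out = np.zeros((x.shape[0], out_steps, x.shape[2], kernel.shape[2]))
    for t in range(out_steps):
        window = x[:, t:t + kernel_size]  # (batch, kernel_size, nodes, in_dim)
        # Sum over kernel positions (k) and input features (c) for every node
        out[:, t] = np.einsum("bknc,kco->bno", window, kernel)
    return out


x = np.random.rand(2, 12, 5, 1)   # batch=2, 12 time steps, 5 nodes, 1 feature
kernel = np.random.rand(3, 1, 4)  # kernel_size=3, 1 -> 4 features
print(temporal_conv_valid(x, kernel).shape)  # (2, 10, 5, 4)
```

Each such convolution shortens the sequence by kernel_size - 1 steps, which is why the output length has to be known before the loss is computed.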

Download the complete sample code here: STGCN.

STGCN Principles

Paper: Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting

Graph Laplacian Normalization

Self-loops are first removed from the graph, and the graph is then normalized to obtain the new edge index and edge weights. mindspore_gl.graph implements norm, which can be used for Laplacian normalization. The code for normalizing the edge index and edge weights is as follows:

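import mindspore as ms
from mindspore_gl.graph import norm

# edge_index: (2, num_edges) NumPy array; edge_attr: (num_edges,) NumPy array
# Remove self-loops (edges whose source equals their target)
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask]

edge_index = ms.Tensor(edge_index, ms.int32)
edge_attr = ms.Tensor(edge_attr, ms.float32)
# Laplacian normalization; the scheme is selected by args.normalization
edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization)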

For details about Laplacian normalization, see the API code of mindspore_gl.graph.norm.
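As a reference point for what Laplacian normalization computes, the following NumPy sketch removes self-loops and builds the dense symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}. It is a sketch under stated assumptions: the helper name sym_normalized_laplacian is hypothetical, and mindspore_gl.graph.norm returns per-edge weights whose exact values depend on its normalization argument, so its output is not guaranteed to equal this matrix.

```python
import numpy as np


def sym_normalized_laplacian(edge_index: np.ndarray, edge_attr: np.ndarray,
                             num_nodes: int) -> np.ndarray:
    """Hypothetical dense L = I - D^{-1/2} A D^{-1/2} (illustration only).

    edge_index: (2, num_edges) int array of (src, dst) pairs
    edge_attr:  (num_edges,) float array of edge weights
    """
    # Drop self-loops, mirroring the mask step in the code above
    mask = edge_index[0] != edge_index[1]
    src, dst = edge_index[:, mask]
    weights = edge_attr[mask]

    adj = np.zeros((num_nodes, num_nodes))
    np.add.at(adj, (src, dst), weights)  # accumulate duplicate edges
    deg = adj.sum(axis=1)
    # Isolated nodes (degree 0) get a scale of 0 instead of inf
    with np.errstate(divide="ignore"):
        inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(deg), 0.0)
    return np.eye(num_nodes) - inv_sqrt[:, None] * adj * inv_sqrt[None, :]


# Undirected path 0-1-2 plus a self-loop on node 0 (which gets removed)
edge_index = np.array([[0, 1, 1, 2, 0], [1, 0, 2, 1, 0]])
edge_attr = np.ones(5)
print(np.round(sym_normalized_laplacian(edge_index, edge_attr, 3), 3))
```

For neighboring nodes i and j, the off-diagonal entry is -w_ij / sqrt(d_i * d_j), and the eigenvalues of this matrix lie in [0, 2].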

Defining a Network Model

mindspore_gl.nn implements STConv, which can be imported and used directly. Unlike a general graph convolution layer, STConv takes 4-dimensional input features of shape (batch size, time steps, number of nodes, feature dimension). The number of output time steps must be calculated from the size of the 1D temporal convolution kernel and the number of temporal convolutions applied.
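The output length can be checked with a small helper. This is a sketch assuming, as in the STGCN paper, that each STConv block applies two unpadded temporal convolutions (one before and one after the graph convolution), each shortening the sequence by kernel_size - 1 steps; the function name stconv_out_timesteps is hypothetical. It reproduces the formula given in the comments of the dataset code.

```python
def stconv_out_timesteps(in_timestep: int, kernel_size: int, layer_nums: int) -> int:
    """Hypothetical helper: time steps left after `layer_nums` STConv blocks.

    Assumes two 'valid' temporal convolutions per block, each removing
    kernel_size - 1 steps from the sequence.
    """
    out = in_timestep - 2 * (kernel_size - 1) * layer_nums
    if out <= 0:
        raise ValueError(
            f"in_timestep={in_timestep} is too short for {layer_nums} "
            f"blocks with kernel_size={kernel_size}")
    return out


# Two STConv layers, kernel_size 3, 12 input steps
print(stconv_out_timesteps(12, 3, 2))  # 4
```

Raising an error for a non-positive result catches configurations where the temporal convolutions would consume the entire input window.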

The code for implementing a two-layer STGCN network using STConv is as follows:

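import mindspore as ms
from mindspore_gl import Graph
from mindspore_gl.nn import GNNCell
from mindspore_gl.nn.temporal import STConv


class STGcnNet(GNNCell):
    """ STGCN Net: two stacked STConv blocks followed by a Dense projection """
    def __init__(self,
                 num_nodes: int,
                 in_channels: int,
                 hidden_channels_1st: int,
                 out_channels_1st: int,
                 hidden_channels_2nd: int,
                 out_channels_2nd: int,
                 out_channels: int,
                 kernel_size: int,
                 k: int,
                 bias: bool = True):
        super().__init__()
        # First ST-Conv block: in_channels -> out_channels_1st
        self.layer0 = STConv(num_nodes, in_channels,
                             hidden_channels_1st,
                             out_channels_1st,
                             kernel_size,
                             k, bias)
        # Second ST-Conv block: out_channels_1st -> out_channels_2nd
        self.layer1 = STConv(num_nodes, out_channels_1st,
                             hidden_channels_2nd,
                             out_channels_2nd,
                             kernel_size,
                             k, bias)
        self.relu = ms.nn.ReLU()
        # Projects the last (feature) axis to out_channels
        self.fc = ms.nn.Dense(out_channels_2nd, out_channels)

    def construct(self, x, edge_weight, g: Graph):
        # x: (batch, in_timestep, num_nodes, in_channels)
        x = self.layer0(x, edge_weight, g)
        x = self.layer1(x, edge_weight, g)
        x = self.relu(x)
        # Output: (batch, out_timestep, num_nodes, out_channels)
        x = self.fc(x)
        return x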

For details about STConv implementation, see the API code of mindspore_gl.nn.temporal.STConv.

Defining a Loss Function

Since this is a regression task, the mean squared error (MSE) can be used as the loss function. In this example, mindspore.nn.MSELoss is used to implement the MSE loss.

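import mindspore as ms
from mindspore import nn, ops
from mindspore_gl import Graph
from mindspore_gl.nn import GNNCell


class LossNet(GNNCell):
    """ LossNet definition """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss_fn = nn.loss.MSELoss()

    def construct(self, feat, edges, target, g: Graph):
        """STGCN Net with loss function; edges are the normalized edge weights"""
        predict = self.net(feat, edges, g)
        # Remove all size-1 axes (e.g. the trailing axis when out_channels == 1)
        predict = ops.Squeeze()(predict)
        loss = self.loss_fn(predict, target)
        # MSELoss averages by default; ReduceMean keeps the result a scalar
        return ops.ReduceMean()(loss)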
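The shape handling in LossNet can be illustrated with a NumPy sketch. It assumes out_channels == 1, so the network output has a trailing axis of size 1 that the labels lack; the helper name mse_loss is hypothetical and this is not the MindSpore implementation.

```python
import numpy as np


def mse_loss(predict: np.ndarray, target: np.ndarray) -> float:
    """Hypothetical MSE after dropping the trailing feature axis.

    predict: (batch, out_timestep, num_nodes, 1)  network output
    target:  (batch, out_timestep, num_nodes)     labels
    """
    # Squeeze only the trailing axis so the batch axis survives when batch == 1
    predict = np.squeeze(predict, axis=-1)
    if predict.shape != target.shape:
        raise ValueError(f"shape mismatch: {predict.shape} vs {target.shape}")
    return float(np.mean((predict - target) ** 2))


predict = np.zeros((2, 4, 5, 1))
target = np.ones((2, 4, 5))
print(mse_loss(predict, target))  # 1.0
```

Checking the shapes explicitly surfaces a wrong out_timestep as an error instead of letting broadcasting silently produce a loss over mismatched tensors.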

Constructing a Dataset

The input features have shape (batch size, time steps, number of nodes, feature dimension). Because temporal convolution shortens the time series, both the input and the output time lengths must be specified when features and labels are obtained from the dataset. Otherwise, the shape of the predictions will not match the shape of the labels.

For the constraint relating these lengths, see the code comments.

from mindspore_gl.dataset import MetrLa

# args holds the script's command-line arguments
metr = MetrLa(args.data_path)
# out_timestep must equal the output length of the STConv layers:
# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums)
# For example, with layer_nums = 2, kernel_size = 3 and in_timestep = 12,
# out_timestep = 12 - (2 * 2 * 2) = 4
features, labels = metr.get_data(args.in_timestep, args.out_timestep)

Download the METR-LA data and decompress it into args.data_path.
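The following NumPy sketch shows one common way to cut a (time steps, nodes) series into input and label windows. The helper make_windows is hypothetical, and the assumption that each label window is the out_timestep steps immediately following its input window is not a statement about how MetrLa.get_data slices the data. With in_timestep = 12 and out_timestep = 4, as in the comments above, the label windows have the same time length as the STGCN output.

```python
import numpy as np


def make_windows(series: np.ndarray, in_timestep: int, out_timestep: int):
    """Hypothetical sliding-window split, NOT the MetrLa.get_data implementation.

    series:   (total_steps, num_nodes) readings for one feature (e.g. speed)
    features: (num_samples, in_timestep, num_nodes, 1)
    labels:   (num_samples, out_timestep, num_nodes)
    """
    num_samples = series.shape[0] - in_timestep - out_timestep + 1
    if num_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window is followed immediately by its label window
    features = np.stack([series[i:i + in_timestep] for i in range(num_samples)])
    labels = np.stack([series[i + in_timestep:i + in_timestep + out_timestep]
                       for i in range(num_samples)])
    # Add a trailing feature axis to match the 4-D STConv input layout
    return features[..., np.newaxis], labels


series = np.random.rand(100, 5)  # 100 time steps, 5 nodes
features, labels = make_windows(series, in_timestep=12, out_timestep=4)
print(features.shape, labels.shape)  # (85, 12, 5, 1) (85, 4, 5)
```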

Network Training and Validation

Setting Environment Variables

Environment variables are set in the same way as in the GCN example.

Defining a Training Network

Instantiate the model body STGcnNet, the LossNet, and the optimizer. The implementation is similar to that of the GCN example.

Network Training and Validation

The training and validation loop is implemented in the same way as in the GCN example.

Executing Jobs and Viewing Results

Running Process

When the program runs, the code is first translated and training then starts.

Execution Results

Run the trainval_metr.py script to start training.

cd model_zoo/stgcn
python trainval_metr.py --data-path={path} --fuse=True

{path} indicates the dataset storage path.

The training result is as follows:

...
Iteration/Epoch: 600:199 loss: 0.21488506
Iteration/Epoch: 700:199 loss: 0.21441595
Iteration/Epoch: 800:199 loss: 0.21243602
Time 13.162885904312134 Epoch loss 0.21053028
eval MSE: 0.2060675

MSE on MetrLa: 0.206