Spatio-Temporal Graph Training Network
Overview
In this example, it will show how to forecast the traffic by Spatio-temporal Graph Convolutional Networks.
Spatio-Temporal Graph Convolutional Networks (STGCN) can tackle the time series prediction problem in traffic domain. Experiments show that STGCN effectively captures comprehensive spatio-temporal correlations through modeling multi-scale traffic networks.
METR-LA is a large-scale data set collected from 1,500 traffic loop detectors in the Los Angeles rural road network. This data set includes speed, road capacity, and occupancy data and covers approximately 3,420 miles. The road network is constructed into a graph and input to the STGCN network. The road network information in the next time phase is predicted based on the historical data.
The node feature shape of a general graph is (nodes number, feature dimension)
, but the feature shape of a spatio-temporal graph is usually at least 3-dimensional (nodes number, feature dimension, time step)
, and the feature fusion processing of neighbor nodes will be more complicated. And due to the convolution in the time dimension, the time step
will also change. When calculating the loss, it is necessary to calculate the output time length in advance.
Download the complete sample code here: STGCN.
STGCN Principles
Graph Laplacian Normalization
The self-loop of the graph is deleted, and the graph is normalized to obtain the new edge index and edge weight. mindspore_gl.graph implements norm, which can be used for laplacian normalization. The code for normalization of edge index and edge weight is as follows:
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask]
edge_index = ms.Tensor(edge_index, ms.int32)
edge_attr = ms.Tensor(edge_attr, ms.float32)
edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization)
For details about laplacian normalization, see the API code of mindspore_gl.graph.norm.
Defining a Network Model
mindspore_gl.nn implements STConv, which can be directly imported for use. Different from the general graph convolution layer, the input features of STConv are 4-dimensional, that is, (batch graphs number, time step, nodes number, feature dimension)
.
The time step
of the output feature needs to be calculated according to the size of the 1D convolution kernel and the times of convolutions.
The code for implementing a two-layer STGCN network using STConv is as follows:
class STGcnNet(GNNCell):
""" STGCN Net """
def __init__(self,
num_nodes: int,
in_channels: int,
hidden_channels_1st: int,
out_channels_1st: int,
hidden_channels_2nd: int,
out_channels_2nd: int,
out_channels: int,
kernel_size: int,
k: int,
bias: bool = True):
super().__init__()
self.layer0 = STConv(num_nodes, in_channels,
hidden_channels_1st,
out_channels_1st,
kernel_size,
k, bias)
self.layer1 = STConv(num_nodes, out_channels_1st,
hidden_channels_2nd,
out_channels_2nd,
kernel_size,
k, bias)
self.relu = ms.nn.ReLU()
self.fc = ms.nn.Dense(out_channels_2nd, out_channels)
def construct(self, x, edge_weight, g: Graph):
x = self.layer0(x, edge_weight, g)
x = self.layer1(x, edge_weight, g)
x = self.relu(x)
x = self.fc(x)
return x
For details about STConv implementation, see the API code of mindspore_gl.nn.temporal.STConv.
Defining a Loss Function
Since this task is a regression task, the minimum mean square error can be used as the loss function. In this example, mindspore.nn.MSELoss is used to implement a mean square error loss.
class LossNet(GNNCell):
""" LossNet definition """
def __init__(self, net):
super().__init__()
self.net = net
self.loss_fn = nn.loss.MSELoss()
def construct(self, feat, edges, target, g: Graph):
"""STGCN Net with loss function"""
predict = self.net(feat, edges, g)
predict = ops.Squeeze()(predict)
loss = self.loss_fn(predict, target)
return ms.ops.ReduceMean()(loss)
Constructing a Dataset
Input feature is (batch graphs number, time step, nodes number, feature dimension)
. The length of the time series changed after time convolution. Therefore, the input and output timestamps must be specified when features and tags are obtained from datasets. Otherwise, the shape of the predicted value is inconsistent with that of the label value.
For details about the restriction specifications, see the code comments.
from mindspore_gl.dataset import MetrLa
metr = MetrLa(args.data_path)
# out_timestep setting
# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums)
# such as: layer_nums = 2, kernel_size = 3, in_timestep = 12,
# out_timestep = 4
features, labels = metr.get_data(args.in_timestep, args.out_timestep)
The MetrLa data can be downloaded and decompressed to args.data_path.
Network Training and Validation
Setting Environment Variables
The method of setting environment variables is similar to that of setting GCN.
Defining a Training Network
Instantiation of the model body STGcnNet and LossNet and optimizer. The implementation method is similar to that of the GCN.
Network Training and Validation
The implementation method is similar to that of the GCN.
Executing Jobs and Viewing Results
Running Process
After running the program, translate the code and start training.
Execution Results
Run the trainval_metr.py script to start training.
cd model_zoo/stgcn
python trainval_metr.py --data-path={path} --fuse=True
{path}
indicates the dataset storage path.
The training result is as follows:
...
Iteration/Epoch: 600:199 loss: 0.21488506
Iteration/Epoch: 700:199 loss: 0.21441595
Iteration/Epoch: 800:199 loss: 0.21243602
Time 13.162885904312134 Epoch loss 0.21053028
eval MSE: 0.2060675
MSE on MetrLa: 0.206