时空图训练网络
概述
在本例中将展示如何利用时空图卷积网络进行交通预测。
时空图卷积网络(STGCN)可以解决交通域的时间序列预测问题。实验表明,STGCN通过建模多尺度交通网络,有效地捕获了综合的时空相关性。
METR-LA是一个大规模数据集,从洛杉矶乡村公路网的1500个交通环路探测器收集。此数据集包括速度、道路容量和占用数据,覆盖约3,420英里。将路网构建成图,输入到STGCN网络中,根据历史数据来预测下个时间段的路网信息。
一般图的节点特征形状为(节点数量, 特征维度)
,时空图中输入的特征形状通常至少为三维(节点数量, 特征维度, 时间步)
,邻居节点的特征融合处理会更加复杂。并且由于时间维度上进行卷积,时间步
也会发生变化,计算loss时,需要提前计算好输出时间长度。
下载完整的样例STGCN代码。
STGCN原理
图拉普拉斯归一化
将图的自环删除,对图进行归一化,得到新的边索引与边权重。 mindspore_gl.graph提供norm的API可以被用于拉普拉斯归一化。边缘索引和边缘权重归一化的代码如下所示:
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask]
edge_index = ms.Tensor(edge_index, ms.int32)
edge_attr = ms.Tensor(edge_attr, ms.float32)
edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization)
关于拉普拉斯归一化的更多细节,可以看mindspore_gl.graph.norm的API.
定义网络结构
mindspore_gl.nn提供了STConv的API可以直接调用。与一般的图卷积层不同,STConv的输入特征为四维,即(批次内图数量, 时间步, 节点数量, 特征维度)
。输出特征的时间步
需要根据1D卷积核尺寸、卷积次数来计算。
使用STConv实现一个两层的STGCN网络代码如下:
class STGcnNet(GNNCell):
""" STGCN Net """
def __init__(self,
num_nodes: int,
in_channels: int,
hidden_channels_1st: int,
out_channels_1st: int,
hidden_channels_2nd: int,
out_channels_2nd: int,
out_channels: int,
kernel_size: int,
k: int,
bias: bool = True):
super().__init__()
self.layer0 = STConv(num_nodes, in_channels,
hidden_channels_1st,
out_channels_1st,
kernel_size,
k, bias)
self.layer1 = STConv(num_nodes, out_channels_1st,
hidden_channels_2nd,
out_channels_2nd,
kernel_size,
k, bias)
self.relu = ms.nn.ReLU()
self.fc = ms.nn.Dense(out_channels_2nd, out_channels)
def construct(self, x, edge_weight, g: Graph):
x = self.layer0(x, edge_weight, g)
x = self.layer1(x, edge_weight, g)
x = self.relu(x)
x = self.fc(x)
return x
STConv执行的更多细节可以看mindspore_gl.nn.temporal.STConv的API代码。
定义loss函数
由于本次任务为回归任务,可以采用最小均方差来作为损失函数。这里调用mindspore.nn.MSELoss实现最小均方差loss。
class LossNet(GNNCell):
""" LossNet definition """
def __init__(self, net):
super().__init__()
self.net = net
self.loss_fn = nn.loss.MSELoss()
def construct(self, feat, edges, target, g: Graph):
"""STGCN Net with loss function"""
predict = self.net(feat, edges, g)
predict = ops.Squeeze()(predict)
loss = self.loss_fn(predict, target)
return ms.ops.ReduceMean()(loss)
构造数据集
输入特征为(批次内图数量, 时间步, 节点数量, 特征维度)
。在时序卷积上时间序列的长度将会发生变化。因此,从数据集获取特征和标签时,输入和输出时间步有相应规范,否则会出现预测值与标签值形状不一致。
限制规范可以参考代码注释。
from mindspore_gl.dataset import MetrLa
metr = MetrLa(args.data_path)
# out_timestep setting
# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums)
# such as: layer_nums = 2, kernel_size = 3, in_timestep = 12,
# out_timestep = 4
features, labels = metr.get_data(args.in_timestep, args.out_timestep)
其中MetrLa数据下载后,解压路径即为args.data_path。
网络训练和验证
设置环境变量
环境变量设置方法可以参考GCN。
定义训练网络
实例化模型主体以及LossNet和优化器。 实现方法可以参考GCN。
网络训练及验证
实现方法可以参考GCN。
执行并查看结果
运行过程
运行程序后,翻译代码并开始训练。
执行结果
执行脚本trainval_metr.py启动训练。
cd model_zoo/stgcn
python trainval_metr.py --data-path={path} --fuse=True
其中{path}
为数据集存放路径。
可以看到训练的结果如下:
...
Iteration/Epoch: 600:199 loss: 0.21488506
Iteration/Epoch: 700:199 loss: 0.21441595
Iteration/Epoch: 800:199 loss: 0.21243602
Time 13.162885904312134 Epoch loss 0.21053028
eval MSE: 0.2060675
MetrLa的MSE: 0.206