mindspore_gl.nn.ASTGCN
- class mindspore_gl.nn.ASTGCN(n_blocks: int, in_channels: int, k: int, n_chev_filters: int, n_time_filters: int, time_conv_strides: int, num_for_predict: int, len_input: int, n_vertices: int, normalization: Optional[str] = None, bias: bool = True)[源代码]
基于Attention的时空图卷积网络。来自于论文 Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting 。
- 参数:
n_blocks (int) - ASTGCN块数。
in_channels (int) - 输入节点特征大小。
k (int) - Chebyshev polynomials的阶。
n_chev_filters (int) - Chebyshev过滤器的数量。
n_time_filters (int) - 时间过滤器的数量。
time_conv_strides (int) - 时间卷积期间的时间步长。
num_for_predict (int) - 未来要进行的预测数。
len_input (int) - 输入序列的长度。
n_vertices (int) - 图中的顶点数。
normalization (str, 可选) - 图Laplacian的归一化方案。默认值:None。
bias (bool, 可选) - layer是否学习加性偏差。默认值:True。
- 输入:
x (Tensor) - 输入节点T个时间段的特征。Shape为 \((B,N,F_{in},T_{in})\) 其中 \(N\) 是节点数。
g (Graph) - 输入图。
- 输出:
Tensor,输出节点特征,shape为 \((B,N,T_{out})\) 。
- 异常:
TypeError - 如果 n_blocks 、 in_channels 、 k 、 n_chev_filters 、 n_time_filters 、 time_conv_strides 、num_for_predict 、 len_input 或 n_vertices 不是正整数。
ValueError - 如果 normalization 不是 ‘sym’ 。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore_gl.graph import norm >>> from mindspore_gl.nn import ASTGCN >>> from mindspore_gl import GraphField >>> node_count = 5 >>> num_for_predict = 4 >>> len_input = 4 >>> n_time_strides = 1 >>> node_features = 2 >>> nb_block = 2 >>> k = 3 >>> n_chev_filters = 8 >>> n_time_filters = 8 >>> batch_size = 2 >>> normalization = "sym" >>> edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3], [1, 4, 2, 3, 2, 3, 4, 3, 4, 4]]) >>> model = ASTGCN(nb_block, node_features, k, n_chev_filters, n_time_filters, n_time_strides, num_for_predict, len_input, node_count, normalization) >>> edge_index_norm, edge_weight_norm = norm(Tensor(edge_index, dtype=ms.int32), node_count) >>> graph = GraphField(edge_index_norm[1], edge_index_norm[0], node_count, len(edge_index_norm[0])) >>> x_seq = Tensor(np.ones([batch_size, node_count, node_features, len_input]), dtype=ms.float32) >>> output = model(x_seq, edge_weight_norm, *graph.get_graph()) >>> print(output.shape) (2, 5, 4)