mindspore_gl.nn.ASTGCN

class mindspore_gl.nn.ASTGCN(n_blocks: int, in_channels: int, k: int, n_chev_filters: int, n_time_filters: int, time_conv_strides: int, num_for_predict: int, len_input: int, n_vertices: int, normalization: Optional[str] = None, bias: bool = True)[源代码]

基于Attention的时空图卷积网络。来自于论文 Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting

参数:
  • n_blocks (int) - ASTGCN块数。

  • in_channels (int) - 输入节点特征大小。

  • k (int) - Chebyshev polynomials的阶。

  • n_chev_filters (int) - Chebyshev过滤器的数量。

  • n_time_filters (int) - 时间过滤器的数量。

  • time_conv_strides (int) - 时间卷积期间的时间步长。

  • num_for_predict (int) - 未来要进行的预测数。

  • len_input (int) - 输入序列的长度。

  • n_vertices (int) - 图中的顶点数。

  • normalization (str, 可选) - 图Laplacian的归一化方案。默认值:None。

  • bias (bool, 可选) - layer是否学习加性偏差。默认值:True。

输入:
  • x (Tensor) - 输入节点T个时间段的特征。Shape为 \((B,N,F_{in},T_{in})\) 其中 \(N\) 是节点数。

  • g (Graph) - 输入图。

输出:
  • Tensor,输出节点特征,shape为 \((B,N,T_{out})\)

异常:
  • TypeError - 如果 n_blocksin_channelskn_chev_filtersn_time_filterstime_conv_stridesnum_for_predictlen_inputn_vertices 不是正整数。

  • ValueError - 如果 normalization 不是 ‘sym’ 。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore_gl.graph import norm
>>> from mindspore_gl.nn import ASTGCN
>>> from mindspore_gl import GraphField
>>> node_count = 5
>>> num_for_predict = 4
>>> len_input = 4
>>> n_time_strides = 1
>>> node_features = 2
>>> nb_block = 2
>>> k = 3
>>> n_chev_filters = 8
>>> n_time_filters = 8
>>> batch_size = 2
>>> normalization = "sym"
>>> edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
[1, 4, 2, 3, 2, 3, 4, 3, 4, 4]])
>>> model = ASTGCN(nb_block, node_features, k, n_chev_filters, n_time_filters,
n_time_strides, num_for_predict, len_input, node_count, normalization)
>>> edge_index_norm, edge_weight_norm = norm(Tensor(edge_index, dtype=ms.int32), node_count)
>>> graph = GraphField(edge_index_norm[1], edge_index_norm[0], node_count, len(edge_index_norm[0]))
>>> x_seq = Tensor(np.ones([batch_size, node_count, node_features, len_input]), dtype=ms.float32)
>>> output = model(x_seq, edge_weight_norm, *graph.get_graph())
>>> print(output.shape)
(2, 5, 4)