mindspore_gl.nn.ASTGCN
- class mindspore_gl.nn.ASTGCN(n_blocks: int, in_channels: int, k: int, n_chev_filters: int, n_time_filters: int, time_conv_strides: int, num_for_predict: int, len_input: int, n_vertices: int, normalization: Optional[str] = 'sym', bias: bool = True)[source]
Attention Based Spatial-Temporal Graph Convolutional Networks. From the paper Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting.
- Parameters
n_blocks (int) – Number of ASTGCN blocks.
in_channels (int) – Input node feature size.
k (int) – Order of Chebyshev polynomials.
n_chev_filters (int) – Number of Chebyshev filters.
n_time_filters (int) – Number of time filters.
time_conv_strides (int) – Time strides during temporal convolution.
num_for_predict (int) – Number of predictions to make in the future.
len_input (int) – Length of the input sequence.
n_vertices (int) – Number of vertices in the graph.
normalization (str, optional) – The normalization scheme for the graph Laplacian. Default: 'sym'. \(\mathbf{L}\) is the normalized Laplacian matrix, \(\mathbf{D}\) is the degree matrix, \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{I}\) is the identity matrix. \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
bias (bool, optional) – Whether the layer will learn an additive bias. Default: True.
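The 'sym' normalization formula can be illustrated with a small NumPy sketch. This is not the library's implementation (the example on this page uses mindspore_gl.graph.norm for that step). The helper name sym_normalized_laplacian is hypothetical, and the sketch assumes the edge list describes an undirected, unweighted graph stored as a dense matrix.

```python
import numpy as np


def sym_normalized_laplacian(edge_index, n_vertices):
    """Return L = I - D^{-1/2} A D^{-1/2} as a dense (N, N) array.

    Hypothetical helper for illustration only; it is not part of mindspore_gl.
    """
    adj = np.zeros((n_vertices, n_vertices), dtype=np.float64)
    src, dst = edge_index
    adj[src, dst] = 1.0
    adj[dst, src] = 1.0  # assumption: every edge is undirected

    deg = adj.sum(axis=1)
    # Isolated vertices have degree 0; keep their scaling factor at 0
    # instead of dividing by zero.
    inv_sqrt_deg = np.zeros_like(deg)
    nonzero = deg > 0
    inv_sqrt_deg[nonzero] = deg[nonzero] ** -0.5

    # D^{-1/2} A D^{-1/2} via row and column scaling.
    norm_adj = inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    return np.eye(n_vertices) - norm_adj


# Same edge list as the example on this page: a complete graph on 5 vertices.
edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
                       [1, 4, 2, 3, 2, 3, 4, 3, 4, 4]])
laplacian = sym_normalized_laplacian(edge_index, 5)
```

For this graph every vertex has degree 4, so the diagonal entries of laplacian are 1 and every off-diagonal entry is -1/4.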
- Inputs:
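x (Tensor) - The input node features for T time periods. The shape is \((B, N, F_{in}, T_{in})\) where \(B\) is the batch size, \(N\) is the number of nodes, \(F_{in}\) is the input node feature size and \(T_{in}\) is the length of the input sequence.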
g (Graph) - The input graph.
- Outputs:
Tensor, output node features with shape of \((B, N, T_{out})\).
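To make the input and output layouts concrete, the following sketch slices a raw time series into arrays with the shapes described above. It is only an illustration: the helper make_windows, the (T_total, N, F) layout of the raw series, and the choice of feature 0 as the prediction target are all assumptions, not part of mindspore_gl.

```python
import numpy as np


def make_windows(series, len_input, num_for_predict, target_feature=0):
    """Slice a (T_total, N, F) series into inputs and targets.

    Hypothetical helper for illustration only.
    Returns x with shape (B, N, F, len_input) and
    y with shape (B, N, num_for_predict).
    """
    total_steps = series.shape[0]
    n_samples = total_steps - len_input - num_for_predict + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")

    xs, ys = [], []
    for start in range(n_samples):
        mid = start + len_input
        # (len_input, N, F) -> (N, F, len_input)
        xs.append(series[start:mid].transpose(1, 2, 0))
        # (num_for_predict, N) -> (N, num_for_predict)
        ys.append(series[mid:mid + num_for_predict, :, target_feature].T)
    return np.stack(xs).astype(np.float32), np.stack(ys).astype(np.float32)


# 12 time steps, 5 vertices, 2 features per vertex.
series = np.random.rand(12, 5, 2)
x, y = make_windows(series, len_input=4, num_for_predict=4)
```

Here x has the \((B, N, F_{in}, T_{in})\) layout expected as input, and y has the same \((B, N, T_{out})\) layout as the output tensor, so it could serve as a training target.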
- Raises
TypeError – If n_blocks, in_channels, k, n_chev_filters, n_time_filters, time_conv_strides, num_for_predict, len_input or n_vertices is not a positive int.
ValueError – If normalization is not 'sym'.
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore_gl.graph import norm
>>> from mindspore_gl.nn import ASTGCN
>>> from mindspore_gl import GraphField
>>> node_count = 5
>>> num_for_predict = 4
>>> len_input = 4
>>> n_time_strides = 1
>>> node_features = 2
>>> nb_block = 2
>>> k = 3
>>> n_chev_filters = 8
>>> n_time_filters = 8
>>> batch_size = 2
>>> normalization = "sym"
>>> edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
...                        [1, 4, 2, 3, 2, 3, 4, 3, 4, 4]])
>>> model = ASTGCN(nb_block, node_features, k, n_chev_filters, n_time_filters,
...                n_time_strides, num_for_predict, len_input, node_count, normalization)
>>> edge_index_norm, edge_weight_norm = norm(Tensor(edge_index, dtype=ms.int32), node_count)
>>> graph = GraphField(edge_index_norm[1], edge_index_norm[0], node_count, len(edge_index_norm[0]))
>>> x_seq = Tensor(np.ones([batch_size, node_count, node_features, len_input]), dtype=ms.float32)
>>> output = model(x_seq, edge_weight_norm, *graph.get_graph())
>>> print(output.shape)
(2, 5, 4)