mindearth.cell.GraphCastNet
- class mindearth.cell.GraphCastNet(vg_in_channels, vg_out_channels, vm_in_channels, em_in_channels, eg2m_in_channels, em2g_in_channels, latent_dims, processing_steps, g2m_src_idx, g2m_dst_idx, m2m_src_idx, m2m_dst_idx, m2g_src_idx, m2g_dst_idx, mesh_node_feats, mesh_edge_feats, g2m_edge_feats, m2g_edge_feats, per_variable_level_mean, per_variable_level_std, recompute=False)
GraphCast is an autoregressive model built on graph neural networks and a novel high-resolution multi-scale mesh representation. For details, see the paper "GraphCast: Learning skillful medium-range global weather forecasting".
- Parameters
vg_in_channels (int) – The input feature dimension of each grid node.
vg_out_channels (int) – The output feature dimension of each grid node.
vm_in_channels (int) – The feature dimension of each mesh node.
em_in_channels (int) – The feature dimension of each mesh edge.
eg2m_in_channels (int) – The feature dimension of each grid-to-mesh edge.
em2g_in_channels (int) – The feature dimension of each mesh-to-grid edge.
latent_dims (int) – The dimension of the hidden layers.
processing_steps (int) – The number of processing steps.
g2m_src_idx (Tensor) – The source node index of grid to mesh edges.
g2m_dst_idx (Tensor) – The destination node index of grid to mesh edges.
m2m_src_idx (Tensor) – The source node index of mesh to mesh edges.
m2m_dst_idx (Tensor) – The destination node index of mesh to mesh edges.
m2g_src_idx (Tensor) – The source node index of mesh to grid edges.
m2g_dst_idx (Tensor) – The destination node index of mesh to grid edges.
mesh_node_feats (Tensor) – The features of mesh nodes.
mesh_edge_feats (Tensor) – The features of mesh edges.
g2m_edge_feats (Tensor) – The features of grid to mesh edges.
m2g_edge_feats (Tensor) – The features of mesh to grid edges.
per_variable_level_mean (Tensor) – The mean of the per-variable-level inverse variance of time differences.
per_variable_level_std (Tensor) – The standard deviation of the per-variable-level inverse variance of time differences.
recompute (bool, optional) – Whether to enable recomputation. Default: False.
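The parameter list does not spell out how a pair of index tensors encodes edges. A minimal sketch, assuming the common convention that entry i of a src/dst pair names the two endpoints of edge i, which is consistent with the equal lengths used in the example on this page. The toy graph, the helper `aggregate_to_destinations`, and the choice of sum aggregation are all hypothetical and are not taken from the GraphCastNet implementation.

```python
# Hypothetical toy graph: 3 grid nodes feed 2 mesh nodes through 4 edges.
g2m_src_idx = [0, 1, 1, 2]  # grid node that each edge starts from
g2m_dst_idx = [0, 0, 1, 1]  # mesh node that each edge ends at

grid_values = [1.0, 2.0, 4.0]  # one scalar feature per grid node


def aggregate_to_destinations(src_idx, dst_idx, src_values, num_dst):
    """Sum the source value carried by every edge into its destination node."""
    if len(src_idx) != len(dst_idx):
        raise ValueError("src_idx and dst_idx must hold one entry per edge")
    totals = [0.0] * num_dst
    for src, dst in zip(src_idx, dst_idx):
        totals[dst] += src_values[src]
    return totals


mesh_values = aggregate_to_destinations(
    g2m_src_idx, g2m_dst_idx, grid_values, num_dst=2
)
print(mesh_values)  # [3.0, 6.0]
```

Under this reading, the source indices of the grid-to-mesh edges must lie in the range of grid nodes and the destination indices in the range of mesh nodes, and the matching edge feature tensor has one row per edge.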
- Inputs:
input (Tensor) - Tensor of shape \((batch\_size, height\_size * width\_size, feature\_size)\) .
- Outputs:
output (Tensor) - Tensor of shape \((height\_size * width\_size, feature\_size)\) .
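The shapes above merge the two spatial axes into a single node axis of length height_size * width_size. A minimal sketch of that flattening on plain nested lists, assuming row-major order (the source does not state the ordering); `flatten_grid` is a hypothetical helper, not part of mindearth.

```python
def flatten_grid(grid):
    """Flatten a [height][width][feature] nested list to [height * width][feature]."""
    # Row-major order is assumed: all cells of row 0 first, then row 1, and so on.
    return [cell for row in grid for cell in row]


height_size, width_size, feature_size = 2, 3, 4

# Each cell holds feature_size copies of its own row-major position.
grid = [
    [[float(h * width_size + w)] * feature_size for w in range(width_size)]
    for h in range(height_size)
]

flat = flatten_grid(grid)
print(len(flat), len(flat[0]))  # 6 4
```

As a point of reference, the 32768 grid nodes used in the example on this page equal 128 * 256, although the page does not state the grid dimensions.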
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, Tensor
>>> from mindearth.cell.graphcast.graphcastnet import GraphCastNet
>>>
>>> # Node and edge counts
>>> mesh_node_num = 2562
>>> grid_node_num = 32768
>>> mesh_edge_num = 20460
>>> g2m_edge_num = 50184
>>> m2g_edge_num = 98304
>>> # Feature dimensions
>>> vm_in_channels = 3
>>> em_in_channels = 4
>>> eg2m_in_channels = 4
>>> em2g_in_channels = 4
>>> feature_num = 69
>>> # Random edge endpoints, one source and one destination index per edge
>>> g2m_src_idx = Tensor(np.random.randint(0, grid_node_num, size=[g2m_edge_num]), ms.int32)
>>> g2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[g2m_edge_num]), ms.int32)
>>> m2m_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2g_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[m2g_edge_num]), ms.int32)
>>> m2g_dst_idx = Tensor(np.random.randint(0, grid_node_num, size=[m2g_edge_num]), ms.int32)
>>> # Random node and edge features
>>> mesh_node_feats = Tensor(np.random.rand(mesh_node_num, vm_in_channels).astype(np.float32), ms.float32)
>>> mesh_edge_feats = Tensor(np.random.rand(mesh_edge_num, em_in_channels).astype(np.float32), ms.float32)
>>> g2m_edge_feats = Tensor(np.random.rand(g2m_edge_num, eg2m_in_channels).astype(np.float32), ms.float32)
>>> m2g_edge_feats = Tensor(np.random.rand(m2g_edge_num, em2g_in_channels).astype(np.float32), ms.float32)
>>> per_variable_level_mean = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> per_variable_level_std = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> grid_node_feats = Tensor(np.random.rand(grid_node_num, feature_num).astype(np.float32), ms.float32)
>>> graphcast_model = GraphCastNet(vg_in_channels=feature_num,
...                                vg_out_channels=feature_num,
...                                vm_in_channels=vm_in_channels,
...                                em_in_channels=em_in_channels,
...                                eg2m_in_channels=eg2m_in_channels,
...                                em2g_in_channels=em2g_in_channels,
...                                latent_dims=512,
...                                processing_steps=4,
...                                g2m_src_idx=g2m_src_idx,
...                                g2m_dst_idx=g2m_dst_idx,
...                                m2m_src_idx=m2m_src_idx,
...                                m2m_dst_idx=m2m_dst_idx,
...                                m2g_src_idx=m2g_src_idx,
...                                m2g_dst_idx=m2g_dst_idx,
...                                mesh_node_feats=mesh_node_feats,
...                                mesh_edge_feats=mesh_edge_feats,
...                                g2m_edge_feats=g2m_edge_feats,
...                                m2g_edge_feats=m2g_edge_feats,
...                                per_variable_level_mean=per_variable_level_mean,
...                                per_variable_level_std=per_variable_level_std)
>>> out = graphcast_model(Tensor(grid_node_feats, ms.float32))
>>> print(out.shape)
(32768, 69)