mindearth.cell.GraphCastNet
- class mindearth.cell.GraphCastNet(vg_in_channels, vg_out_channels, vm_in_channels, em_in_channels, eg2m_in_channels, em2g_in_channels, latent_dims, processing_steps, g2m_src_idx, g2m_dst_idx, m2m_src_idx, m2m_dst_idx, m2g_src_idx, m2g_dst_idx, mesh_node_feats, mesh_edge_feats, g2m_edge_feats, m2g_edge_feats, per_variable_level_mean, per_variable_level_std, recompute=False)
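GraphCast is an autoregressive model built on graph neural networks and a novel high-resolution, multi-scale mesh representation. For more details, see the paper GraphCast: Learning skillful medium-range global weather forecasting.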
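Because the model is autoregressive, a multi-step forecast is produced by feeding each prediction back in as the next input. The NumPy sketch illustrates only that loop. `rollout` and `toy_step` are hypothetical names, and `toy_step` is a trivial stand-in for a trained network, not part of MindEarth. Using a real GraphCastNet instance as `step_fn` would additionally require matching its documented input and output shapes.

```python
import numpy as np


def rollout(step_fn, initial_state, num_steps):
    """Autoregressively roll a one-step model forward.

    step_fn: hypothetical callable mapping the (num_grid_nodes, num_features)
    state at time t to the predicted state at time t + 1.
    Returns an array of shape (num_steps, num_grid_nodes, num_features).
    """
    state = initial_state
    predictions = []
    for _ in range(num_steps):
        state = step_fn(state)  # each prediction becomes the next input
        predictions.append(state)
    return np.stack(predictions)


def toy_step(x):
    """Toy stand-in for the network: damp every value by 10% per step."""
    return 0.9 * x


x0 = np.ones((4, 3), dtype=np.float32)  # 4 grid nodes, 3 features
traj = rollout(toy_step, x0, num_steps=3)
print(traj.shape)  # (3, 4, 3)
```

Keeping the loop outside the network mirrors the one-step signature documented here: the class maps one state to the next, and longer lead times come from repeated calls.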
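- Parameters:
vg_in_channels (int) - Input feature size of the grid nodes.
vg_out_channels (int) - Output feature size of the grid nodes.
vm_in_channels (int) - Input feature size of the mesh nodes.
em_in_channels (int) - Input feature size of the mesh edges.
eg2m_in_channels (int) - Input feature size of the grid-to-mesh edges.
em2g_in_channels (int) - Input feature size of the mesh-to-grid edges.
latent_dims (int) - Dimension of the hidden (latent) layers.
processing_steps (int) - Number of processing steps.
g2m_src_idx (Tensor) - Source node indices of the grid-to-mesh edges.
g2m_dst_idx (Tensor) - Destination node indices of the grid-to-mesh edges.
m2m_src_idx (Tensor) - Source node indices of the mesh-to-mesh edges.
m2m_dst_idx (Tensor) - Destination node indices of the mesh-to-mesh edges.
m2g_src_idx (Tensor) - Source node indices of the mesh-to-grid edges.
m2g_dst_idx (Tensor) - Destination node indices of the mesh-to-grid edges.
mesh_node_feats (Tensor) - Features of the mesh nodes.
mesh_edge_feats (Tensor) - Features of the mesh edges.
g2m_edge_feats (Tensor) - Features of the grid-to-mesh edges.
m2g_edge_feats (Tensor) - Features of the mesh-to-grid edges.
per_variable_level_mean (Tensor) - Mean of each variable at each level.
per_variable_level_std (Tensor) - Standard deviation of each variable at each level.
recompute (bool, optional) - Whether to enable recomputation. Default: False.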
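The six `*_src_idx` / `*_dst_idx` tensors describe three edge sets in coordinate form. Reading the parameter names, edge k is assumed to run from node `src_idx[k]` to node `dst_idx[k]`. The NumPy sketch shows how such index pairs are typically used in message passing: gather sender features by source index, then sum them into receivers by destination index. It is illustrative only. Sum aggregation is an assumption, and the actual network computes messages with learned layers that also consume the edge features (such as `g2m_edge_feats`), which this sketch omits.

```python
import numpy as np


def scatter_sum(messages, dst_idx, num_dst_nodes):
    """Sum edge messages into their destination nodes.

    messages: (num_edges, dim) array, one row per edge.
    dst_idx:  (num_edges,) integer array of destination node indices.
    """
    out = np.zeros((num_dst_nodes, messages.shape[1]), dtype=messages.dtype)
    np.add.at(out, dst_idx, messages)  # unbuffered add handles repeated indices
    return out


# Tiny grid-to-mesh graph: 4 grid nodes, 2 mesh nodes, 5 edges.
grid_feats = np.arange(8, dtype=np.float32).reshape(4, 2)
g2m_src_idx = np.array([0, 1, 2, 3, 3])  # grid node each edge starts from
g2m_dst_idx = np.array([0, 0, 1, 1, 0])  # mesh node each edge points to

messages = grid_feats[g2m_src_idx]  # gather sender features: (num_edges, 2)
mesh_agg = scatter_sum(messages, g2m_dst_idx, num_dst_nodes=2)
print(mesh_agg)  # [[ 8. 11.]
                 #  [10. 12.]]
```

`np.add.at` is used instead of `out[dst_idx] += messages` because the latter silently keeps only one contribution when a destination index repeats, which is the normal case for graph edges.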
- Inputs:
input (Tensor) - Tensor of shape \((batch\_size, height\_size * width\_size, feature\_size)\).
- Outputs:
Tensor, the output of the GraphCast network.
output (Tensor) - Tensor of shape \((height\_size * width\_size, feature\_size)\).
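Both shapes flatten the latitude-longitude grid into one row per grid node. A minimal NumPy sketch of moving between a (height, width) layout and that flattened layout follows. The sizes are toy values. The row-major (latitude-major) node ordering is an assumption and must match how the grid node indices in `g2m_src_idx` and `m2g_dst_idx` were built. Since the documented output has no batch axis, the sketch also assumes `batch_size` is 1.

```python
import numpy as np

batch_size, height_size, width_size, feature_size = 1, 4, 8, 3

# Hypothetical lat-lon field laid out as (batch, height, width, features).
field = np.random.rand(batch_size, height_size, width_size, feature_size).astype(np.float32)

# Flatten the two spatial axes into a single grid-node axis for the network input.
net_input = field.reshape(batch_size, height_size * width_size, feature_size)

# Stand-in for the network output, shape (height_size * width_size, feature_size).
net_output = net_input[0]

# Restore the spatial layout of the prediction.
pred_grid = net_output.reshape(height_size, width_size, feature_size)
print(net_input.shape, pred_grid.shape)  # (1, 32, 3) (4, 8, 3)
```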
- Supported Platforms:
Ascend
GPU
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindearth.cell.graphcast.graphcastnet import GraphCastNet
>>>
>>> mesh_node_num = 2562
>>> grid_node_num = 32768
>>> mesh_edge_num = 20460
>>> g2m_edge_num = 50184
>>> m2g_edge_num = 98304
>>> vm_in_channels = 3
>>> em_in_channels = 4
>>> eg2m_in_channels = 4
>>> em2g_in_channels = 4
>>> feature_num = 69
>>> g2m_src_idx = Tensor(np.random.randint(0, grid_node_num, size=[g2m_edge_num]), ms.int32)
>>> g2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[g2m_edge_num]), ms.int32)
>>> m2m_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2m_dst_idx = Tensor(np.random.randint(0, mesh_node_num, size=[mesh_edge_num]), ms.int32)
>>> m2g_src_idx = Tensor(np.random.randint(0, mesh_node_num, size=[m2g_edge_num]), ms.int32)
>>> m2g_dst_idx = Tensor(np.random.randint(0, grid_node_num, size=[m2g_edge_num]), ms.int32)
>>> mesh_node_feats = Tensor(np.random.rand(mesh_node_num, vm_in_channels).astype(np.float32), ms.float32)
>>> mesh_edge_feats = Tensor(np.random.rand(mesh_edge_num, em_in_channels).astype(np.float32), ms.float32)
>>> g2m_edge_feats = Tensor(np.random.rand(g2m_edge_num, eg2m_in_channels).astype(np.float32), ms.float32)
>>> m2g_edge_feats = Tensor(np.random.rand(m2g_edge_num, em2g_in_channels).astype(np.float32), ms.float32)
>>> per_variable_level_mean = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> per_variable_level_std = Tensor(np.random.rand(feature_num,).astype(np.float32), ms.float32)
>>> grid_node_feats = Tensor(np.random.rand(grid_node_num, feature_num).astype(np.float32), ms.float32)
>>> graphcast_model = GraphCastNet(vg_in_channels=feature_num,
...                                vg_out_channels=feature_num,
...                                vm_in_channels=vm_in_channels,
...                                em_in_channels=em_in_channels,
...                                eg2m_in_channels=eg2m_in_channels,
...                                em2g_in_channels=em2g_in_channels,
...                                latent_dims=512,
...                                processing_steps=4,
...                                g2m_src_idx=g2m_src_idx,
...                                g2m_dst_idx=g2m_dst_idx,
...                                m2m_src_idx=m2m_src_idx,
...                                m2m_dst_idx=m2m_dst_idx,
...                                m2g_src_idx=m2g_src_idx,
...                                m2g_dst_idx=m2g_dst_idx,
...                                mesh_node_feats=mesh_node_feats,
...                                mesh_edge_feats=mesh_edge_feats,
...                                g2m_edge_feats=g2m_edge_feats,
...                                m2g_edge_feats=m2g_edge_feats,
...                                per_variable_level_mean=per_variable_level_mean,
...                                per_variable_level_std=per_variable_level_std)
>>> out = graphcast_model(grid_node_feats)
>>> print(out.shape)
(32768, 69)
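The example fills the graph tensors with random data, so the only constraints it respects are that each index lies in the right node range and that each edge-feature array has one row per edge (for instance, `g2m_edge_feats` has `g2m_edge_num` rows). When building real graphs, a pre-flight check of those constraints can catch mistakes such as swapped node counts before constructing the network. `check_edge_set` below is a hypothetical helper, not part of MindEarth, and it assumes the same src/dst convention as the example.

```python
import numpy as np


def check_edge_set(name, src_idx, dst_idx, edge_feats, num_src_nodes, num_dst_nodes):
    """Return a list of human-readable problems (empty if the edge set looks consistent)."""
    problems = []
    if src_idx.shape != dst_idx.shape:
        problems.append(f"{name}: src and dst index arrays differ in length")
    if edge_feats.shape[0] != src_idx.shape[0]:
        problems.append(f"{name}: {edge_feats.shape[0]} feature rows for {src_idx.shape[0]} edges")
    if src_idx.size and (src_idx.min() < 0 or src_idx.max() >= num_src_nodes):
        problems.append(f"{name}: source index out of range [0, {num_src_nodes})")
    if dst_idx.size and (dst_idx.min() < 0 or dst_idx.max() >= num_dst_nodes):
        problems.append(f"{name}: destination index out of range [0, {num_dst_nodes})")
    return problems


# Deterministic toy grid-to-mesh edge set, much smaller than in the example.
grid_node_num, mesh_node_num, g2m_edge_num = 32, 12, 40
g2m_src = np.arange(g2m_edge_num) % grid_node_num  # grid indices 0..31
g2m_dst = np.arange(g2m_edge_num) % mesh_node_num  # mesh indices 0..11
g2m_feats = np.random.rand(g2m_edge_num, 4).astype(np.float32)

print(check_edge_set("g2m", g2m_src, g2m_dst, g2m_feats, grid_node_num, mesh_node_num))  # []
# Swapping the grid and mesh node counts is flagged as an out-of-range source index.
print(check_edge_set("g2m", g2m_src, g2m_dst, g2m_feats, mesh_node_num, grid_node_num))
```

The helper works on NumPy arrays, so it can run on the arrays before they are wrapped in `Tensor`.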