GraphCast-Based Medium-Range Precipitation Module


Overview

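This module starts from a pretrained GraphCast model and attaches a trainable GraphCast backbone downstream. The backbone is fine-tuned to produce medium-range precipitation forecasts. The model framework is shown in the figure below: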

[Figure: framework of the GraphCast precipitation model (graphcast_tp)]

Technical approach

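The medium-range precipitation workflow is as follows:

  1. Create the dataset

  2. Build the model

  3. Define the loss function

  4. Train the model

  5. Evaluate and visualize the model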

The datasets used for training and testing can be downloaded here: dataset

[36]:
import random
import numpy as np
import matplotlib.pyplot as plt

from mindspore import set_seed
from mindspore import context, ops

The src package imported below can be downloaded from graphcast/src.

[8]:
from mindearth.utils import load_yaml_config
from mindearth.data import Dataset

from src import get_coe, get_logger, init_tp_model
from src import LossNet, GraphCastTrainerTp, CustomWithLossCell, InferenceModuleTp
from src import Era5DataTp
[9]:
set_seed(0)
np.random.seed(0)
random.seed(0)

The model parameters, optimizer settings, and data configuration are defined in configs. To run the precipitation code, tp in GraphCastTp.yaml must be set to True.

[16]:
config = load_yaml_config("./GraphCastTp.yaml")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=5)
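Forgetting to enable tp is an easy mistake, so a quick sanity check after loading the configuration can help. The helper below, check_tp_config, is hypothetical and not part of src or MindEarth. It assumes load_yaml_config returns a nested dict whose data section holds the tp, root_dir, and tp_dir keys that appear in the configuration log printed during model construction.

```python
def check_tp_config(config):
    """Hypothetical sanity check for the precipitation settings of a loaded config.

    Assumes `config` is the nested dict returned by load_yaml_config, whose
    'data' section holds the 'tp', 'root_dir' and 'tp_dir' keys.
    """
    data = config.get('data', {})
    # The precipitation branch is only used when data.tp is True
    if data.get('tp') is not True:
        raise ValueError("set tp: True in the data section of GraphCastTp.yaml")
    # Both the ERA5 root directory and the precipitation label directory must be set
    missing = [key for key in ('root_dir', 'tp_dir') if not data.get(key)]
    if missing:
        raise ValueError(f"missing data paths in config: {missing}")
    return data['root_dir'], data['tp_dir']


# Usage with a dict mirroring the values printed in the configuration log
example = {'data': {'tp': True, 'root_dir': './dataset_tp',
                    'tp_dir': './dataset_tp/tp_log_data'}}
print(check_tp_config(example))  # ('./dataset_tp', './dataset_tp/tp_log_data')
```

In the notebook itself, this would be called as check_tp_config(config) right after load_yaml_config.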

Creating the dataset

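From the dataset link, download the normalization statistics and the training dataset into the ./dataset directory. Then set root_dir and tp_dir in the ./configs/GraphCastTp.yaml configuration file. These two parameters give the paths of the dataset and of the precipitation labels, respectively. The directory structure of ./dataset is as follows: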

├── statistic
│   ├── mean.npy
│   ├── mean_s.npy
│   ├── std.npy
│   ├── std_s.npy
│   └── climate_0.5_tp.npy
├── train
│   └── 2018
├── train_static
│   └── 2018
├── train_surface
│   └── 2018
├── train_surface_static
│   └── 2018
├── valid
│   └── 2021
├── valid_static
│   └── 2021
├── valid_surface
│   └── 2021
└── valid_surface_static
    └── 2021
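Before training, it can save time to confirm that the downloaded files match the expected layout. The sketch below is a hypothetical helper, not part of src. It checks only the top-level entries of the directory tree above and ignores the year subfolders and the precipitation label directory (tp_dir).

```python
from pathlib import Path

# Expected entries, transcribed from the directory tree above (year folders omitted)
EXPECTED_FILES = ['statistic/mean.npy', 'statistic/mean_s.npy', 'statistic/std.npy',
                  'statistic/std_s.npy', 'statistic/climate_0.5_tp.npy']
EXPECTED_DIRS = ['train', 'train_static', 'train_surface', 'train_surface_static',
                 'valid', 'valid_static', 'valid_surface', 'valid_surface_static']


def missing_dataset_entries(root):
    """Return the expected files and directories that are absent under `root`."""
    root = Path(root)
    missing = [f for f in EXPECTED_FILES if not (root / f).is_file()]
    missing += [d for d in EXPECTED_DIRS if not (root / d).is_dir()]
    return missing


# An empty list means every expected entry is present
print(missing_dataset_entries('./dataset'))
```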

Model construction

For training, set run_mode='train'. Training also requires a pretrained GraphCast checkpoint, which can be downloaded here: ckpt. The checkpoint path is configured by backbone_ckpt_path in ./GraphCastTp.yaml.

[17]:
model = init_tp_model(config, run_mode='train')
logger = get_logger(config)
2023-12-05 06:21:28,126 - utils.py[line:165] - INFO: {'name': 'GraphCastTp', 'latent_dims': 512, 'processing_steps': 10, 'recompute': True, 'vm_in_channels': 3, 'em_in_channels': 4, 'eg2m_in_channels': 4, 'em2g_in_channels': 4}
2023-12-05 06:21:28,127 - utils.py[line:165] - INFO: {'name': 'era5', 'root_dir': './dataset_tp', 'feature_dims': 69, 'pressure_level_num': 13, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 1, 'valid_interval': 1, 'test_interval': 1, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2018, 2018], 'valid_period': [2021, 2021], 'test_period': [2022, 2022], 'patch': False, 'rollout_steps': 1, 'num_workers': 1, 'mesh_level': 5, 'grid_resolution': 0.5, 'tp': True, 'tp_dir': './dataset_tp/tp_log_data', 'h_size': 360, 'w_size': 720}
2023-12-05 06:21:28,129 - utils.py[line:165] - INFO: {'name': 'adamw', 'initial_lr': 0.000125, 'finetune_lr': 3e-07, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.1, 'gamma': 0.5, 'epochs': 100}
2023-12-05 06:21:28,131 - utils.py[line:165] - INFO: {'summary_dir': 'GraphCastTp_latent_dims_512_processing_steps_10_recompute_True_vm_in_channels_3_em_in_channels_4_eg2m_in_channels_4_em2g_in_channels_4_adamw_oop', 'eval_interval': 10, 'save_checkpoint_steps': 5, 'keep_checkpoint_max': 10, 'save_rmse_acc': False, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '', 'backbone_ckpt_path': './dataset_tp/ckpt/GraphCast-device0-1_2008.ckpt'}
2023-12-05 06:21:28,132 - utils.py[line:165] - INFO: {'name': 'oop', 'distribute': False, 'mixed_precision': True, 'amp_level': 'O2', 'load_ckpt': False}
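The data configuration above fixes the sizes of the two node sets the model works on. The latitude-longitude grid has h_size × w_size = 360 × 720 = 259200 nodes, which matches the second dimension of the input tensor printed in the evaluation section. The mesh is governed by mesh_level. The sketch below assumes, without confirmation from the src code, that mesh_level is the number of refinements of a regular icosahedron, as in the GraphCast multi-mesh. It only counts the finest-level mesh; the GraphCast multi-mesh additionally keeps edges from all coarser levels.

```python
def icosahedral_mesh_size(level):
    """Node, edge and face counts of an icosahedron refined `level` times.

    Each refinement splits every triangle into 4, so faces = 20 * 4**level and
    edges = 30 * 4**level; Euler's formula V - E + F = 2 then gives
    nodes = 10 * 4**level + 2.
    """
    faces = 20 * 4 ** level
    edges = 30 * 4 ** level
    nodes = edges - faces + 2
    return nodes, edges, faces


# Values taken from the data configuration in the log above
mesh_level, h_size, w_size = 5, 360, 720
print('grid nodes:', h_size * w_size)                           # grid nodes: 259200
print('mesh nodes/edges/faces:', icosahedral_mesh_size(mesh_level))  # (10242, 30720, 20480)
```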

Loss function

The model is trained with an Lp loss that measures the relative error between prediction and label.
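The body of LossNet is not shown in this tutorial, so the following is only a generic NumPy sketch of a relative Lp error. It is not the LossNet implementation, which also receives the coefficient arrays returned by get_coe. For each sample it divides the Lp norm of the error by the Lp norm of the label, so the loss is insensitive to the overall magnitude of the field.

```python
import numpy as np


def relative_lp_loss(pred, label, p=2, eps=1e-8):
    """Mean relative Lp error over the batch.

    For each sample, computes ||pred - label||_p / ||label||_p over all
    non-batch dimensions, then averages over the batch. `eps` guards against
    division by zero for an all-zero label (e.g. a completely dry field).
    """
    batch = pred.shape[0]
    diff = (pred - label).reshape(batch, -1)
    ref = label.reshape(batch, -1)
    diff_norm = np.linalg.norm(diff, ord=p, axis=1)
    ref_norm = np.linalg.norm(ref, ord=p, axis=1)
    return float(np.mean(diff_norm / (ref_norm + eps)))


label = np.array([[3.0, 4.0]])
pred = np.array([[3.0, 5.0]])
print(relative_lp_loss(pred, label))  # ≈ 0.2, i.e. |(0, 1)|_2 / |(3, 4)|_2 = 1 / 5
```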

[18]:
sj_std, wj, ai = get_coe(config)
data_params = config.get('data')
loss_fn = LossNet(ai, wj, sj_std, data_params.get('feature_dims'), data_params['tp'])
loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)
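The three arrays returned by get_coe are named after the notation of the GraphCast training loss. In that loss, s_j is a per-variable scale, w_j is a per-variable loss weight, and a_i is the area of latitude-longitude grid cell i, normalized to unit mean over the grid. The sketch below assumes ai plays that role here, since the body of get_coe is not shown. It also assumes that the 360 grid rows sit at cell-centre latitudes; latitude_area_weights is a hypothetical name.

```python
import numpy as np


def latitude_area_weights(h_size=360):
    """Cell-area weights proportional to cos(latitude), normalized to mean 1.

    Hypothetically assumes an equiangular grid whose h_size rows sit at
    cell-centre latitudes from +90 - d/2 down to -90 + d/2, with d = 180 / h_size.
    """
    d = 180.0 / h_size
    lat = np.linspace(90.0 - d / 2, -90.0 + d / 2, h_size)
    w = np.cos(np.deg2rad(lat))
    return w / w.mean()


w = latitude_area_weights(360)
print(w.shape, round(float(w.mean()), 6))  # (360,) 1.0
```

With this weighting, errors near the poles, where grid cells are small, contribute less than errors near the equator.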
[1]:
trainer = GraphCastTrainerTp(config, model, loss_cell, logger)
trainer.train()

Model evaluation and visualization

After training, we use the 20th checkpoint for inference.

[33]:
inference_module = InferenceModuleTp(model, config, logger)
test_dataset_generator = Era5DataTp(data_params=data_params, run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=data_params.get('num_workers'), shuffle=False)
test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))
data = next(test_dataset.create_dict_iterator())
inputs = data['inputs']
labels = data['labels']
print(inputs.shape)
(1, 259200, 69)
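In the printed shape, the second dimension is 259200 = 360 × 720, i.e. h_size × w_size from the data configuration. The last dimension is feature_dims = 69. To inspect a single input field as a 2-D map, the flattened node axis has to be unflattened. The sketch below assumes the nodes are stored row-major (latitude rows first, then longitude columns); that layout is an assumption, not something documented here. It uses a zero array as a stand-in for inputs.asnumpy().

```python
import numpy as np

# Stand-in for `inputs.asnumpy()`: batch 1, 360 * 720 grid nodes, 69 features
h_size, w_size, feature_dims = 360, 720, 69
flat = np.zeros((1, h_size * w_size, feature_dims), dtype=np.float32)

# Assuming row-major node order, recover the (lat, lon) grid for every feature
grid = flat.reshape(1, h_size, w_size, feature_dims)
print(grid.shape)  # (1, 360, 720, 69)

# Under the same assumption, node index 1000 sits at grid row 1, column 280
row, col = np.unravel_index(1000, (h_size, w_size))
print(row, col)  # 1 280
```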
[34]:
def unlog_trans(x, eps=1e-5):
    """Inverse transformation of log(TP / epsilon + 1)"""
    return eps * (ops.exp(x) - 1)

pred = inference_module.forecast(inputs)
labels = unlog_trans(labels).asnumpy()
pred = unlog_trans(pred).asnumpy()
print(labels.shape, pred.shape)
(1, 20, 360, 720) (1, 20, 360, 720)
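According to its docstring, unlog_trans inverts the transform log(TP / ε + 1). The NumPy sketch below writes out both directions and checks the round trip. log_trans and unlog_trans_np are hypothetical helpers for illustration. Using np.log1p and np.expm1 instead of log and exp is a choice made here for accuracy near zero. The transform compresses the heavy tail of precipitation: TP = ε maps to log 2 ≈ 0.69, while TP = 1000ε maps only to log 1001 ≈ 6.9.

```python
import numpy as np

EPS = 1e-5  # same default epsilon as unlog_trans


def log_trans(tp, eps=EPS):
    """Forward transform log(TP / eps + 1), computed with log1p for accuracy near 0."""
    return np.log1p(np.asarray(tp) / eps)


def unlog_trans_np(x, eps=EPS):
    """NumPy counterpart of unlog_trans: eps * (exp(x) - 1), computed with expm1."""
    return eps * np.expm1(np.asarray(x))


tp = np.array([0.0, 1e-5, 1e-3, 2e-2])  # illustrative precipitation values only
x = log_trans(tp)
print(np.allclose(unlog_trans_np(x), tp))  # True
```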
[43]:
def plt_comparison(pred, label, root_dir='./'):
    """Plot ground truth (left) and prediction (right) side by side, then save and show the figure."""
    for idx, (img, title) in enumerate([(label, 'Truth'), (pred, 'pred')], start=1):
        plt.subplot(1, 2, idx)
        plt.imshow(img, cmap='jet')
        plt.title(title)
        # Relabel pixel indices (720 columns x 360 rows) as longitude/latitude in degrees
        plt.xticks(np.arange(0, 721, 180), np.arange(-180, 181, 90))
        plt.xlabel('longitude')
        plt.yticks(np.arange(0, 361, 180), np.arange(-90, 91, 90))
        plt.ylabel('latitude')
    plt.savefig(f"{root_dir}/tp_comparison.png", dpi=150)
    plt.show()
[44]:
def trans_colorbar(data):
    """Piecewise-linearly map precipitation (mm) onto a 0-100 display scale.

    0-1 mm -> 0-50, 1-10 mm -> 50-75, 10-100 mm -> 75-100. Values at or below
    0 mm map to 0 and values above 100 mm map to 100 (the previous mask-product
    version sent both of these cases to 1).
    """
    ori = [0., 1., 10., 100.]
    new_v = [0., 50., 75., 100.]
    return np.interp(data, ori, new_v)

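# Lead-time index 1 of the first sample; "* 1000" presumably converts precipitation from m to mm
plot_pred = trans_colorbar(pred[0, 1] * 1000)
plot_labels = trans_colorbar(labels[0, 1] * 1000)
plt_comparison(plot_pred, plot_labels)  # Left: ground-truth precipitation; right: model prediction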
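trans_colorbar stretches the low end of the precipitation range. Assuming the × 1000 converts metres to millimetres, 0 to 1 mm occupies half of the 0-100 display scale, 1 to 10 mm the next quarter, and 10 to 100 mm the last quarter. Because the plotted values are no longer in mm, a colorbar would need relabeled ticks. The sketch below reproduces the same breakpoints with np.interp and builds matching tick labels; to_display_scale is a hypothetical name.

```python
import numpy as np

# Breakpoints used by trans_colorbar: precipitation in mm -> position on a 0-100 scale
ORI = [0.0, 1.0, 10.0, 100.0]
NEW_V = [0.0, 50.0, 75.0, 100.0]


def to_display_scale(mm):
    """Piecewise-linear map from mm to the 0-100 display scale (clamped at both ends)."""
    return np.interp(mm, ORI, NEW_V)


print(to_display_scale(np.array([0.5, 1.0, 5.5, 55.0])).tolist())  # [25.0, 50.0, 62.5, 87.5]

# Colorbar ticks: place them at NEW_V and label them with the original mm values
tick_positions = NEW_V
tick_labels = [f"{v:g} mm" for v in ORI]
print(tick_labels)  # ['0 mm', '1 mm', '10 mm', '100 mm']
```

In Matplotlib, these could be applied to a colorbar object cbar = plt.colorbar() with cbar.set_ticks(tick_positions) followed by cbar.set_ticklabels(tick_labels).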
[Figure: ground-truth (left) and predicted (right) precipitation maps]