文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

PR

小问题,全程线上修改...

一键搞定!

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

基于GraphCast中期降水模块

下载Notebook下载样例代码查看源文件

概述

本模块基于GraphCast预训练模型,在下游承接可训练的GraphCast backbone进行微调,最终实现中期降水的预报。模型框架可见下图

graphcast_tp

技术路径

中期降水模型具体流程如下:

  1. 创建数据集

  2. 模型构建

  3. 损失函数

  4. 模型训练

  5. 模型评估与可视化

训练和测试所用数据集可以在: dataset下载

[36]:
import random
import numpy as np
import matplotlib.pyplot as plt

from mindspore import set_seed
from mindspore import context, ops

下述src可以在graphcast/src下载

[8]:
from mindearth.utils import load_yaml_config
from mindearth.data import Dataset

from src import get_coe, get_logger, init_tp_model
from src import LossNet, GraphCastTrainerTp, CustomWithLossCell, InferenceModuleTp
from src import Era5DataTp
[9]:
set_seed(0)
np.random.seed(0)
random.seed(0)

模型涉及的参数、优化器、数据配置见configs。执行降水代码时,GraphCastTp.yaml文件中的tp需要设置为True

[16]:
config = load_yaml_config("./GraphCastTp.yaml")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=5)

创建数据集

dataset路径下,下载正则化参数、训练数据集到./dataset目录。 修改./configs/GraphCastTp.yaml配置文件中的root_dir以及tp_dir参数,这两个参数分别设置了数据集和降水标签的路径。 ./dataset中的目录结构如下所示:

├── statistic
│   ├── mean.npy
│   ├── mean_s.npy
│   ├── std.npy
│   └── std_s.npy
│   └── climate_0.5_tp.npy
├── train
│   └── 2018
├── train_static
│   └── 2018
├── train_surface
│   └── 2018
├── train_surface_static
│   └── 2018
├── valid
│   └── 2021
├── valid_static
│   └── 2021
├── valid_surface
│   └── 2021
├── valid_surface_static
│   └── 2021

模型构建

若是训练代码,设置run_mode='train',还需要一个训练好的GraphCast模型ckpt,可以在这里ckpt下载,ckpt的路径在./GraphCastTp.yaml中的backbone_ckpt_path配置。

[17]:
model = init_tp_model(config, run_mode='train')
logger = get_logger(config)
2023-12-05 06:21:28,126 - utils.py[line:165] - INFO: {'name': 'GraphCastTp', 'latent_dims': 512, 'processing_steps': 10, 'recompute': True, 'vm_in_channels': 3, 'em_in_channels': 4, 'eg2m_in_channels': 4, 'em2g_in_channels': 4}
2023-12-05 06:21:28,126 - utils.py[line:165] - INFO: {'name': 'GraphCastTp', 'latent_dims': 512, 'processing_steps': 10, 'recompute': True, 'vm_in_channels': 3, 'em_in_channels': 4, 'eg2m_in_channels': 4, 'em2g_in_channels': 4}
2023-12-05 06:21:28,127 - utils.py[line:165] - INFO: {'name': 'era5', 'root_dir': './dataset_tp', 'feature_dims': 69, 'pressure_level_num': 13, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 1, 'valid_interval': 1, 'test_interval': 1, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2018, 2018], 'valid_period': [2021, 2021], 'test_period': [2022, 2022], 'patch': False, 'rollout_steps': 1, 'num_workers': 1, 'mesh_level': 5, 'grid_resolution': 0.5, 'tp': True, 'tp_dir': './dataset_tp/tp_log_data', 'h_size': 360, 'w_size': 720}
2023-12-05 06:21:28,127 - utils.py[line:165] - INFO: {'name': 'era5', 'root_dir': './dataset_tp', 'feature_dims': 69, 'pressure_level_num': 13, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 1, 'valid_interval': 1, 'test_interval': 1, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2018, 2018], 'valid_period': [2021, 2021], 'test_period': [2022, 2022], 'patch': False, 'rollout_steps': 1, 'num_workers': 1, 'mesh_level': 5, 'grid_resolution': 0.5, 'tp': True, 'tp_dir': './dataset_tp/tp_log_data', 'h_size': 360, 'w_size': 720}
2023-12-05 06:21:28,129 - utils.py[line:165] - INFO: {'name': 'adamw', 'initial_lr': 0.000125, 'finetune_lr': 3e-07, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.1, 'gamma': 0.5, 'epochs': 100}
2023-12-05 06:21:28,129 - utils.py[line:165] - INFO: {'name': 'adamw', 'initial_lr': 0.000125, 'finetune_lr': 3e-07, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.1, 'gamma': 0.5, 'epochs': 100}
2023-12-05 06:21:28,131 - utils.py[line:165] - INFO: {'summary_dir': 'GraphCastTp_latent_dims_512_processing_steps_10_recompute_True_vm_in_channels_3_em_in_channels_4_eg2m_in_channels_4_em2g_in_channels_4_adamw_oop', 'eval_interval': 10, 'save_checkpoint_steps': 5, 'keep_checkpoint_max': 10, 'save_rmse_acc': False, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '', 'backbone_ckpt_path': './dataset_tp/ckpt/GraphCast-device0-1_2008.ckpt'}
2023-12-05 06:21:28,131 - utils.py[line:165] - INFO: {'summary_dir': 'GraphCastTp_latent_dims_512_processing_steps_10_recompute_True_vm_in_channels_3_em_in_channels_4_eg2m_in_channels_4_em2g_in_channels_4_adamw_oop', 'eval_interval': 10, 'save_checkpoint_steps': 5, 'keep_checkpoint_max': 10, 'save_rmse_acc': False, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '', 'backbone_ckpt_path': './dataset_tp/ckpt/GraphCast-device0-1_2008.ckpt'}
2023-12-05 06:21:28,132 - utils.py[line:165] - INFO: {'name': 'oop', 'distribute': False, 'mixed_precision': True, 'amp_level': 'O2', 'load_ckpt': False}
2023-12-05 06:21:28,132 - utils.py[line:165] - INFO: {'name': 'oop', 'distribute': False, 'mixed_precision': True, 'amp_level': 'O2', 'load_ckpt': False}

损失函数

LP Loss, 考虑相对误差损失。

[18]:
sj_std, wj, ai = get_coe(config)
data_params = config.get('data')
loss_fn = LossNet(ai, wj, sj_std, data_params.get('feature_dims'), data_params['tp'])
loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)
[ ]:
trainer = GraphCastTrainerTp(config, model, loss_cell, logger)
trainer.train()

模型评估和可视化

训练完成后我们使用第20个ckpt进行推理。

[33]:
inference_module = InferenceModuleTp(model, config, logger)
test_dataset_generator = Era5DataTp(data_params=data_params, run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=data_params.get('num_workers'), shuffle=False)
test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))
data = next(test_dataset.create_dict_iterator())
inputs = data['inputs']
labels = data['labels']
print(inputs.shape)
(1, 259200, 69)
[34]:
def unlog_trans(x, eps=1e-5):
    """Inverse transformation of log(TP / epsilon + 1)"""
    return eps * (ops.exp(x) - 1)

pred = inference_module.forecast(inputs)
labels = unlog_trans(labels).asnumpy()
pred = unlog_trans(pred).asnumpy()
print(labels.shape, pred.shape)
(1, 20, 360, 720) (1, 20, 360, 720)
[43]:
def plt_comparison(pred, label, root_dir='./'):
    plt.subplot(1, 2, 1)
    plt.imshow(label, cmap='jet')
    plt.title('Truth')
    plt.xticks(np.arange(0, 721, 180), np.arange(-180, 181, 90))
    plt.xlabel('longitude')
    plt.yticks(np.arange(0, 361, 180), np.arange(-90, 91, 90))
    plt.ylabel('latitude')
    plt.subplot(1, 2, 2)
    plt.imshow(pred, cmap='jet')
    plt.title('pred')
    plt.xticks(np.arange(0, 721, 180), np.arange(-180, 181, 90))
    plt.xlabel('longitude')
    plt.yticks(np.arange(0, 361, 180), np.arange(-90, 91, 90))
    plt.savefig(f"{root_dir}/tp_comparison.png", dpi=150)
    plt.show()
[44]:
def trans_colorbar(data):
    ori = [0., 1., 10., 100]
    new_v = [0., 50., 75, 100.]
    trans = []
    for i in range(1, len(ori)):
        x = np.where((data > ori[i-1]) & (data <= ori[i]), (data - ori[i-1]) * (new_v[i] - new_v[i-1]) / (ori[i] - ori[i-1]) + new_v[i-1], 1.)
        trans.append(x)
    res = 1.
    for x in trans:
        res *= x
    return res

plot_pred = trans_colorbar(pred[0, 1] * 1000)
plot_labels = trans_colorbar(labels[0, 1] * 1000)
plt_comparison(plot_pred, plot_labels) # 左图是真实的降水分布图,右图是模型预测的结果。
../_images/medium-range_graphcast_tp_18_0.png