{ "cells": [ { "cell_type": "markdown", "id": "ecc50b48", "metadata": {}, "source": [ "# FuXi: 基于级联架构的全球中期天气预报\n", "\n", "[](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth_r0.2/zh_cn/medium-range/mindspore_fuxi.ipynb) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth_r0.2/zh_cn/medium-range/mindspore_fuxi.py) [](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindearth_r0.2/docs/source_zh_cn/medium-range/fuxi.ipynb)" ] }, { "cell_type": "markdown", "id": "0abd5fd1", "metadata": {}, "source": [ "## 概述\n", "\n", "FuXi模型是由复旦大学的研究人员开发的一个基于数据驱动的全球天气预报模型。它提供了关键全球天气指标的中期预报,分辨率为0.25°。相当于赤道附近约25公里 x 25公里的空间分辨率和大小为720 x 1440像素的全球网格。与以前的基于MachineLearning的天气预报模型相比,采用级联架构的FuXi模型在[EC中期预报评估](https://charts.ecmwf.int/products/plwww_3m_fc_aimodels_wp_mean?area=Northern%20Extra-tropics¶meter=Geopotential%20500hPa&score=Root%20mean%20square%20error)中取得了优异的结果。\n", "\n", "本教程介绍了FuXi的研究背景和技术路径,并展示了如何通过MindSpore Earth训练和快速推理模型。更多信息参见[文章](https://www.nature.com/articles/s41612-023-00512-1)。本教程中使用分辨率为0.25°的[ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/)数据集,详细介绍案例的运行流程。" ] }, { "cell_type": "markdown", "id": "39dacb90", "metadata": {}, "source": [ "## FuXi\n", "\n", "基本的伏羲模型体系结构由三个主要组件组成,如图所示:Cube Embedding、U-Transformer和全连接层。输入数据结合了上层空气和地表变量,并创建了一个维度为69×720×1440的数据立方体,以一个时间步作为一个step。\n", "\n", "高维输入数据通过联合时空Cube Embedding进行维度缩减,转换为C×180×360。Cube Embedding的主要目的是减少输入数据的时空维度,减少冗余信息。随后,U-Transformer处理嵌入数据,并使用简单的全连接层进行预测,输出首先被重塑为69×720×1440。\n", "\n", "\n", "\n", "- Cube Embedding\n", "\n", " 为了减少输入数据的空间和时间维度,并加快训练过程,应用了Cube Embedding方法。\n", "\n", " 具体地,空时立方体嵌入采用了一个三维(3D)卷积层,卷积核和步幅分别为2×4×4(相当于$\\frac{T}{2}×\\frac{H}{2}×\\frac{W}{2}$),输出通道数为C。在空时立方体嵌入之后,采用了层归一化(LayerNorm)来提高训练的稳定性。最终得到的数据立方体的维度是C×180×360。\n", "\n", "- U-Transformer\n", "\n", " U-Transformer还包括U-Net模型的下采样和上采样块。下采样块在图中称为Down Block,将数据维度减少为C×90×180,从而最小化自注意力计算的计算和内存需求。Down Block由一个步长为2的3×3 2D卷积层和一个残差块组成,该残差块有两个3×3卷积层,后面跟随一个组归一化(GN)层和一个Sigmoid加权激活函数(SiLU)。SiLU加权激活函数通过将Sigmoid函数与其输入相乘来计算$σ(x)×x$。\n", "\n", " 上采样块在图中称为Up Block,它与Down Block使用相同的残差块,同时还包括一个2D反卷积,内核为2,步长为2。Up Block将数据大小缩放回$C×180×360$。此外,在馈送到Up Block之前,还包括一个跳跃连接,将Down Block的输出与Transformer Block的输出连接起来。\n", "\n", " 中间结构是由18个重复的Swin Transformer块构建而成,通过使用残差后归一化代替前归一化,缩放余弦注意力代替原始点积自注意力,Swin Transformer解决了诸如训练不稳定等训练和应用大规模的Swin Transformer模型会出现几个问题。" ] }, { "cell_type": "markdown", "id": "da40cb1d", "metadata": {}, "source": [ "## 技术路径\n", "\n", "MindSpore Earth求解该问题的具体流程如下:\n", "\n", "1. 创建数据集\n", "2. 模型构建\n", "3. 损失函数\n", "4. 模型训练\n", "5. 模型评估与可视化" ] }, { "cell_type": "code", "execution_count": 2, "id": "52aa45e6", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from mindspore import set_seed\n", "from mindspore import context\n", "from mindspore import load_checkpoint, load_param_into_net" ] }, { "cell_type": "markdown", "id": "062e7adc", "metadata": {}, "source": [ "下述`src`可以在[fuxi/src](https://gitee.com/mindspore/mindscience/tree/r0.6/MindEarth/applications/medium-range/fuxi/src)下载。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "1e8373e6", "metadata": {}, "outputs": [], "source": [ "from mindearth.utils import load_yaml_config, plt_global_field_data, make_dir\n", "from mindearth.data import Dataset, Era5Data\n", "\n", "from src import init_model, get_logger\n", "from src import MAELossForMultiLabel, FuXiTrainer, CustomWithLossCell, InferenceModule" ] }, { "cell_type": "code", "execution_count": 4, "id": "258d0acd", "metadata": {}, "outputs": [], "source": [ "set_seed(0)\n", "np.random.seed(0)\n", "random.seed(0)" ] }, { "cell_type": "markdown", "id": "18b670e4", "metadata": {}, "source": [ "可以在[配置文件](https://gitee.com/mindspore/mindscience/raw/r0.6/MindEarth/applications/medium-range/fuxi/configs/FuXi.yaml)中配置模型、数据和优化器等参数。" ] }, { "cell_type": "code", "execution_count": 5, "id": "9ba71baf", "metadata": {}, "outputs": [], "source": [ "config = load_yaml_config(\"configs/FuXi.yaml\")\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=0)" ] }, { "cell_type": "markdown", "id": "4fa254aa", "metadata": {}, "source": [ "## 创建数据集\n", "\n", "下载[ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/)数据集到`./dataset`目录。\n", "\n", "`./dataset`中的目录结构如下所示:\n", "\n", "```markdown\n", ".\n", "├── statistic\n", "│ ├── mean.npy\n", "│ ├── mean_s.npy\n", "│ ├── std.npy\n", "│ ├── std_s.npy\n", "│ └── climate_0.25.npy\n", "├── train\n", "│ └── 2015\n", "├── train_static\n", "│ └── 2015\n", "├── train_surface\n", "│ └── 2015\n", "├── train_surface_static\n", "│ └── 2015\n", "├── valid\n", "│ └── 2016\n", "├── valid_static\n", "│ └── 2016\n", "├── valid_surface\n", "│ └── 2016\n", "├── valid_surface_static\n", "│ └── 2016\n", "```" ] }, { "cell_type": "markdown", "id": "3f004a89", "metadata": {}, "source": [ "## 模型构建\n", "\n", "模型初始化主要包括Swin Transformer Block数目以及训练参数。" ] }, { "cell_type": "code", "execution_count": 6, "id": "cee87d14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2024-01-29 12:34:41,485 - utils.py[line:91] - INFO: {'name': 'FuXi', 'depths': 18, 'in_channels': 96, 'out_channels': 192}\n", "2024-01-29 12:34:41,487 - utils.py[line:91] - INFO: {'name': 'era5', 'root_dir': '/data4/cq/ERA5_0_25/', 'feature_dims': 69, 'pressure_level_num': 13, 'level_feature_size': 5, 'surface_feature_size': 4, 'h_size': 720, 'w_size': 1440, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 6, 'valid_interval': 6, 'test_interval': 6, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2015, 2015], 'valid_period': [2016, 2016], 'test_period': [2017, 2017], 'num_workers': 1, 'grid_resolution': 0.25}\n", "2024-01-29 12:34:41,488 - utils.py[line:91] - INFO: {'name': 'adam', 'initial_lr': 0.00025, 'finetune_lr': 1e-05, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.0, 'loss_weight': 0.25, 'gamma': 0.5, 'epochs': 100}\n", "2024-01-29 12:34:41,489 - utils.py[line:91] - INFO: {'summary_dir': './summary', 'eval_interval': 10, 'save_checkpoint_steps': 10, 'keep_checkpoint_max': 50, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '/data3/cq'}\n", "2024-01-29 12:34:41,490 - utils.py[line:91] - INFO: {'name': 'oop', 'distribute': False, 'amp_level': 'O2', 'load_ckpt': True}\n" ] } ], "source": [ "make_dir(os.path.join(config['summary'][\"summary_dir\"], \"image\"))\n", "logger_obj = get_logger(config)\n", "fuxi_model = init_model(config, \"train\")" ] }, { "cell_type": "markdown", "id": "e818c80f", "metadata": {}, "source": [ "## 损失函数\n", "\n", "FuXi在模型训练中使用自定义平均绝对误差。" ] }, { "cell_type": "code", "execution_count": 10, "id": "67156e88", "metadata": {}, "outputs": [], "source": [ "data_params = config.get('data')\n", "optimizer_params = config.get('optimizer')\n", "loss_fn = MAELossForMultiLabel(data_params=data_params, optimizer_params=optimizer_params)\n", "loss_cell = CustomWithLossCell(backbone=fuxi_model, loss_fn=loss_fn)\n" ] }, { "cell_type": "markdown", "id": "19a03c13", "metadata": {}, "source": [ "## 模型训练\n", "\n", "在本教程中,我们继承了`Trainer`并重写了`get_solver`成员函数来构建自定义损失函数,并重写了`get_callback`成员函数来在训练过程中对测试数据集执行推理。\n", "\n", "MindSpore Earth提供训练和推理接口,使用2.0.0及之后的MindSpore训练网络。" ] }, { "cell_type": "code", "execution_count": 15, "id": "51397466", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-12-29 08:46:01,280 - pretrain.py[line:215] - INFO: steps_per_epoch: 67\n", "epoch: 1 step: 67, loss is 0.37644267\n", "Train epoch time: 222879.994 ms, per step time: 3326.567 ms\n", "epoch: 2 step: 67, loss is 0.26096737\n", "Train epoch time: 216913.275 ms, per step time: 3237.512 ms\n", "epoch: 3 step: 67, loss is 0.2587443\n", "Train epoch time: 217870.765 ms, per step time: 3251.802 ms\n", "epoch: 4 step: 67, loss is 0.2280185\n", "Train epoch time: 218111.564 ms, per step time: 3255.396 ms\n", "epoch: 5 step: 67, loss is 0.20605856\n", "Train epoch time: 216881.674 ms, per step time: 3237.040 ms\n", "epoch: 6 step: 67, loss is 0.20178188\n", "Train epoch time: 218800.354 ms, per step time: 3265.677 ms\n", "epoch: 7 step: 67, loss is 0.21064804\n", "Train epoch time: 217554.571 ms, per step time: 3247.083 ms\n", "epoch: 8 step: 67, loss is 0.20392722\n", "Train epoch time: 217324.330 ms, per step time: 3243.647 ms\n", "epoch: 9 step: 67, loss is 0.19890495\n", "Train epoch time: 218374.032 ms, per step time: 3259.314 ms\n", "epoch: 10 step: 67, loss is 0.2064792\n", "Train epoch time: 217399.318 ms, per step time: 3244.766 ms\n", "2023-12-29 09:22:23,927 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n", "2023-12-29 09:24:51,246 - forecast.py[line:241] - INFO: test dataset size: 1\n", "2023-12-29 09:24:51,248 - forecast.py[line:191] - INFO: t = 6 hour: \n", "2023-12-29 09:24:51,250 - forecast.py[line:202] - INFO: RMSE of Z500: 313.254855370194, T2m: 2.911020155335285, T850: 1.6009748653510902, U10: 1.8822629694594444\n", "2023-12-29 09:24:51,251 - forecast.py[line:203] - INFO: ACC of Z500: 0.9950579247892839, T2m: 0.98743929296225, T850: 0.9930489273077082, U10: 0.9441216196638477\n", "2023-12-29 09:24:51,252 - forecast.py[line:191] - INFO: t = 72 hour: \n", "2023-12-29 09:24:51,253 - forecast.py[line:202] - INFO: RMSE of Z500: 1176.8557892319443, T2m: 7.344694139181644, T850: 6.165706260104667, U10: 5.953978905254709\n", "2023-12-29 09:24:51,254 - forecast.py[line:203] - INFO: ACC of Z500: 0.9271318752961824, T2m: 0.9236962494086007, T850: 0.9098796075852417, U10: 0.5003382663349598\n", "2023-12-29 09:24:51,255 - forecast.py[line:191] - INFO: t = 120 hour: \n", "2023-12-29 09:24:51,256 - forecast.py[line:202] - INFO: RMSE of Z500: 1732.662048442734, T2m: 9.891472332990181, T850: 8.233521390723434, U10: 7.434774900830313\n", "2023-12-29 09:24:51,256 - forecast.py[line:203] - INFO: ACC of Z500: 0.8421506711992445, T2m: 0.8468635778030965, T850: 0.8467625693884427, U10: 0.3787509969898105\n", "2023-12-29 09:24:51,257 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n", "......\n", "epoch: 91 step: 67, loss is 0.13158562\n", "Train epoch time: 191498.866 ms, per step time: 2858.192 ms\n", "epoch: 92 step: 67, loss is 0.12776905\n", "Train epoch time: 218376.797 ms, per step time: 3259.355 ms\n", "epoch: 93 step: 67, loss is 0.12682373\n", "Train epoch time: 217263.432 ms, per step time: 3242.738 ms\n", "epoch: 94 step: 67, loss is 0.12594032\n", "Train epoch time: 217970.325 ms, per step time: 3253.288 ms\n", "epoch: 95 step: 67, loss is 0.12149178\n", "Train epoch time: 217401.066 ms, per step time: 3244.792 ms\n", "epoch: 96 step: 67, loss is 0.12223453\n", "Train epoch time: 218616.899 ms, per step time: 3265.344 ms\n", "epoch: 97 step: 67, loss is 0.12046164\n", "Train epoch time: 218616.899 ms, per step time: 3263.949 ms\n", "epoch: 98 step: 67, loss is 0.1172382\n", "Train epoch time: 216666.521 ms, per step time: 3233.829 ms\n", "epoch: 99 step: 67, loss is 0.11799482\n", "Train epoch time: 218090.233 ms, per step time: 3255.078 ms\n", "epoch: 100 step: 67, loss is 0.11112012\n", "Train epoch time: 218108.888 ms, per step time: 3255.357 ms\n", "2023-12-29 10:00:44,043 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n", "2023-12-29 10:02:59,291 - forecast.py[line:241] - INFO: test dataset size: 1\n", "2023-12-29 10:02:59,293 - forecast.py[line:191] - INFO: t = 6 hour: \n", "2023-12-29 10:02:59,294 - forecast.py[line:202] - INFO: RMSE of Z500: 159.26790471459077, T2m: 1.7593914514223792, T850: 1.2225771108909576, U10: 1.3952338408157166\n", "2023-12-29 10:02:59,295 - forecast.py[line:203] - INFO: ACC of Z500: 0.996888905697735, T2m: 0.9882202464019967, T850: 0.994542681351491, U10: 0.9697411543132562\n", "2023-12-29 10:02:59,297 - forecast.py[line:191] - INFO: t = 72 hour: \n", "2023-12-29 10:02:59,298 - forecast.py[line:202] - INFO: RMSE of Z500: 937.2960233810791, T2m: 5.177728653933931, T850: 4.831667457069809, U10: 5.30111109022694\n", "2023-12-29 10:02:59,299 - forecast.py[line:203] - INFO: ACC of Z500: 0.9542952919181137, T2m: 0.9557775651851869, T850: 0.9371537322317006, U10: 0.5895038993694246\n", "2023-12-29 10:02:59,300 - forecast.py[line:191] - INFO: t = 120 hour: \n", "2023-12-29 10:02:59,301 - forecast.py[line:202] - INFO: RMSE of Z500: 1200.9140481697198, T2m: 6.913749261896835, T850: 6.530332262562704, U10: 6.3855645042672835\n", "2023-12-29 10:02:59,303 - forecast.py[line:203] - INFO: ACC of Z500: 0.9257611031529911, T2m: 0.9197160039098073, T850: 0.8867113860499101, U10: 0.47483364671406136\n", "2023-12-29 10:02:59,304 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n" ] } ], "source": [ "trainer = FuXiTrainer(config, fuxi_model, loss_cell, logger_obj)\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "c58c2cb7", "metadata": {}, "source": [ "## 模型评估和可视化" ] }, { "cell_type": "markdown", "id": "fdcb23c9", "metadata": {}, "source": [ "完成训练后,我们使用第100个ckpt进行推理。下述展示了预测值、地表和它们之间的误差可视化。" ] }, { "cell_type": "code", "execution_count": 7, "id": "0e582671", "metadata": {}, "outputs": [], "source": [ "params = load_checkpoint('./FuXi_depths_18_in_channels_96_out_channels_192_recompute_True_adam_oop/ckpt/step_1/FuXi-100_67.ckpt')\n", "load_param_into_net(fuxi_model, params)\n", "inference_module = InferenceModule(fuxi_model, config, logger_obj)\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "6c11a8b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [] } ], "source": [ "data_params = config.get(\"data\")\n", "test_dataset_generator = Era5Data(data_params=data_params, run_mode='test')\n", "test_dataset = Dataset(test_dataset_generator, distribute=False,\n", " num_workers=data_params.get('num_workers'), shuffle=False)\n", "test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))\n", "data = next(test_dataset.create_dict_iterator())\n", "inputs = data['inputs']\n", "labels = data['labels']" ] }, { "cell_type": "code", "execution_count": 9, "id": "7b5622f0", "metadata": {}, "outputs": [], "source": [ "labels = labels[..., 0, :, :]\n", "labels = labels.transpose(0, 2, 1)\n", "labels = labels.reshape(labels.shape[0], labels.shape[1], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n", "\n", "pred = inference_module.forecast(inputs)\n", "pred = pred[0].transpose(1, 0)\n", "pred = pred.reshape(pred.shape[0], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n", "pred = np.expand_dims(pred, axis=0)" ] }, { "cell_type": "code", "execution_count": 11, "id": "62a63323", "metadata": {}, "outputs": [], "source": [ "def plt_key_info_comparison(pred, label, root_dir):\n", " \"\"\" Visualize the comparison of forecast results \"\"\"\n", " std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n", " mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n", " std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n", " mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n", "\n", " plt.figure(num='e_imshow', figsize=(100, 50))\n", "\n", " plt.subplot(4, 3, 1)\n", " plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth') # Z500\n", " plt.subplot(4, 3, 2)\n", " plt_global_field_data(pred, 'Z500', std, mean, 'Pred') # Z500\n", " plt.subplot(4, 3, 3)\n", " plt_global_field_data(label - pred, 'Z500', std, mean, 'Error', is_error=True) # Z500\n", "\n", " plt.subplot(4, 3, 4)\n", " plt_global_field_data(label, 'T850', std, mean, 'Ground Truth') # T850\n", " plt.subplot(4, 3, 5)\n", " plt_global_field_data(pred, 'T850', std, mean, 'Pred') # T850\n", " plt.subplot(4, 3, 6)\n", " plt_global_field_data(label - pred, 'T850', std, mean, 'Error', is_error=True) # T850\n", "\n", " plt.subplot(4, 3, 7)\n", " plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True) # U10\n", " plt.subplot(4, 3, 8)\n", " plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True) # U10\n", " plt.subplot(4, 3, 9)\n", " plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True, is_error=True) # U10\n", "\n", " plt.subplot(4, 3, 10)\n", " plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True) # T2M\n", " plt.subplot(4, 3, 11)\n", " plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True) # T2M\n", " plt.subplot(4, 3, 12)\n", " plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True, is_error=True) # T2M\n", "\n", " plt.savefig(f'key_info_comparison1.png', bbox_inches='tight')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "d63c4fed-97f2-4399-a7f0-b17df5ae63f9", "metadata": {}, "outputs": [], "source": [ "plt_key_info_comparison(pred, labels, data_params.get('root_dir'))" ] }, { "cell_type": "markdown", "id": "950256f9-4cf6-4398-bdf3-40cc95077307", "metadata": {}, "source": [ "" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }