{ "cells": [ { "cell_type": "markdown", "id": "ecc50b48", "metadata": {}, "source": [ "# FuXi: Medium-range Global Weather Forecasting Based on Cascade Architecture\n", "\n", "[](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_fuxi.ipynb) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_fuxi.py) [](https://gitee.com/mindspore/docs/blob/master/docs/mindearth/docs/source_en/medium-range/fuxi.ipynb)" ] }, { "cell_type": "markdown", "id": "0abd5fd1", "metadata": {}, "source": [ "## Overview\n", "\n", "FuXi is a data-driven global weather forecast model developed by researchers from Fudan University. It provides medium-term forecasts of key global weather indicators with a resolution of 0.25°. Equivalent to a spatial resolution of approximately 25 km x 25 km near the equator and a global grid of 720 x 1440 pixels in size. Compared with the previous ML-based weather forecast model, the FuXi model using cascade architecture achieved excellent results in [ECMWF](https://charts.ecmwf.int/products/plwww_3m_fc_aimodels_wp_mean?area=Northern%20Extra-tropics¶meter=Geopotential%20500hPa&score=Root%20mean%20square%20error).\n", "\n", "This tutorial introduces the research background and technical path of FuXi, and shows how to train and fast infer the model through MindSpore Earth. More information can be found in [paper](https://www.nature.com/articles/s41612-023-00512-1). The [ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/) with a resolution of 0.25° is used to provide a detailed introduction to the operational process of this tutorial." ] }, { "cell_type": "markdown", "id": "39dacb90", "metadata": {}, "source": [ "## FuXi\n", "\n", "The basic Fuxi model architecture consists of three main components, as shown in the figure: Cube Embedding, U-Transformer, and fully connected layer. The input data combined the upper air and surface variables and created a data cube with dimensions 69×720×1440, with a time step as a step.\n", "\n", "The high-dimensional input data is reduced by combining space-time cube embedding and converted into C×180×360. The main purpose of cube embedding is to reduce the spatial-temporal dimension of input data and reduce redundant information. U-Transformer then processes the embedded data and makes predictions using a simple fully connected layer, and the output is first reshaped to 69×720×1440.\n", "\n", "\n", "\n", "- Cube Embedding\n", "\n", " In order to reduce the spatial and temporal dimensions of the input data and speed up the training process, the cube embedding method is applied.\n", "\n", " Specifically, a space-time cube embedding uses a three-dimensional (3D) convolutional layer, a convolution kernel and a stride length are respectively 2x4x4 (equivalent to $\\frac{T}{2}×\\frac{H}{2}×\\frac{W}{2}$), and a quantity of output channels is C. After the space-time cube embedding, Layer Norm is used to improve the stability of the training. The dimension of the resulting data cube is C×180×360.\n", "\n", "- U-Transformer\n", "\n", " The U-Transformer also includes the downsampling and upsampling blocks of the U-Net model. The downsampling block, referred to as the Down Block in the figure, reduces the data dimension to C×90×180, thereby minimizing the computational and memory requirements of self-attention computing. 
{ "cell_type": "markdown", "id": "da40cb1d", "metadata": {}, "source": [ "## Technology Path\n", "\n", "MindSpore Earth solves the problem as follows:\n", "\n", "1. Data Construction.\n", "2. Model Construction.\n", "3. Loss Function.\n", "4. Model Training.\n", "5. Model Evaluation and Visualization." ] }, { "cell_type": "code", "execution_count": 2, "id": "52aa45e6", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from mindspore import set_seed\n", "from mindspore import context\n", "from mindspore import load_checkpoint, load_param_into_net" ] }, { "cell_type": "markdown", "id": "062e7adc", "metadata": {}, "source": [ "The following `src` package can be downloaded from [fuxi/src](https://gitee.com/mindspore/mindscience/tree/master/MindEarth/applications/medium-range/fuxi/src)." ] }, { "cell_type": "code", "execution_count": 3, "id": "1e8373e6", "metadata": {}, "outputs": [], "source": [ "from mindearth.utils import load_yaml_config, plt_global_field_data, make_dir\n", "from mindearth.data import Dataset, Era5Data\n", "\n", "from src import init_model, get_logger\n", "from src import MAELossForMultiLabel, FuXiTrainer, CustomWithLossCell, InferenceModule" ] }, { "cell_type": "code", "execution_count": 4, "id": "258d0acd", "metadata": {}, "outputs": [], "source": [ "set_seed(0)\n", "np.random.seed(0)\n", "random.seed(0)" ] }, { "cell_type": "markdown", "id": "18b670e4", "metadata": {}, "source": [ "You can get the parameters for the model, data, and optimizer from [config](https://gitee.com/mindspore/mindscience/raw/master/MindEarth/applications/medium-range/fuxi/configs/FuXi.yaml)."
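, "\n", "\n", "As a rough sketch of how these configuration groups are read (the `model` key is an assumption inferred from the logged output under Model Construction; `data`, `optimizer`, and `summary` appear verbatim in later cells):\n", "\n", "```python\n", "config = load_yaml_config(\"configs/FuXi.yaml\")\n", "model_params = config.get(\"model\")          # e.g. name, depths, in_channels, out_channels\n", "data_params = config.get(\"data\")            # grid sizes, batch size, periods, intervals\n", "optimizer_params = config.get(\"optimizer\")  # learning rates, epochs, loss_weight\n", "summary_params = config.get(\"summary\")      # summary_dir, eval_interval, ckpt_path\n", "```"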
] }, { "cell_type": "code", "execution_count": 5, "id": "9ba71baf", "metadata": {}, "outputs": [], "source": [ "config = load_yaml_config(\"configs/FuXi.yaml\")\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=0)" ] }, { "cell_type": "markdown", "id": "4fa254aa", "metadata": {}, "source": [ "## Data Construction\n", "\n", "Download the statistic, training and validation dataset from [ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/) to `./dataset`.\n", "\n", "The `./dataset` is hosted with the following directory structure:\n", "\n", "```markdown\n", ".\n", "├── statistic\n", "│ ├── mean.npy\n", "│ ├── mean_s.npy\n", "│ ├── std.npy\n", "│ ├── std_s.npy\n", "│ └── climate_0.25.npy\n", "├── train\n", "│ └── 2015\n", "├── train_static\n", "│ └── 2015\n", "├── train_surface\n", "│ └── 2015\n", "├── train_surface_static\n", "│ └── 2015\n", "├── valid\n", "│ └── 2016\n", "├── valid_static\n", "│ └── 2016\n", "├── valid_surface\n", "│ └── 2016\n", "├── valid_surface_static\n", "│ └── 2016\n", "```" ] }, { "cell_type": "markdown", "id": "25b0568c", "metadata": {}, "source": [ "## Model Construction\n", "\n", "Model initialization includes the number of Swin Transformer blocks and training parameters." ] }, { "cell_type": "code", "execution_count": 6, "id": "cee87d14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2024-01-29 12:34:41,485 - utils.py[line:91] - INFO: {'name': 'FuXi', 'depths': 18, 'in_channels': 96, 'out_channels': 192}\n", "2024-01-29 12:34:41,487 - utils.py[line:91] - INFO: {'name': 'era5', 'root_dir': '/data4/cq/ERA5_0_25/', 'feature_dims': 69, 'pressure_level_num': 13, 'level_feature_size': 5, 'surface_feature_size': 4, 'h_size': 720, 'w_size': 1440, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 6, 'valid_interval': 6, 'test_interval': 6, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2015, 2015], 'valid_period': [2016, 2016], 'test_period': [2017, 2017], 'num_workers': 1, 'grid_resolution': 0.25}\n", "2024-01-29 12:34:41,488 - utils.py[line:91] - INFO: {'name': 'adam', 'initial_lr': 0.00025, 'finetune_lr': 1e-05, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.0, 'loss_weight': 0.25, 'gamma': 0.5, 'epochs': 100}\n", "2024-01-29 12:34:41,489 - utils.py[line:91] - INFO: {'summary_dir': './summary', 'eval_interval': 10, 'save_checkpoint_steps': 10, 'keep_checkpoint_max': 50, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '/data3/cq'}\n", "2024-01-29 12:34:41,490 - utils.py[line:91] - INFO: {'name': 'oop', 'distribute': False, 'amp_level': 'O2', 'load_ckpt': True}\n" ] } ], "source": [ "make_dir(os.path.join(config['summary'][\"summary_dir\"], \"image\"))\n", "logger_obj = get_logger(config)\n", "fuxi_model = init_model(config, \"train\")" ] }, { "cell_type": "markdown", "id": "e818c80f", "metadata": {}, "source": [ "## Loss Function\n", "\n", "FuXi uses custom mean absolute error for model training." 
] }, { "cell_type": "code", "execution_count": 10, "id": "67156e88", "metadata": {}, "outputs": [], "source": [ "data_params = config.get('data')\n", "optimizer_params = config.get('optimizer')\n", "loss_fn = MAELossForMultiLabel(data_params=data_params, optimizer_params=optimizer_params)\n", "loss_cell = CustomWithLossCell(backbone=fuxi_model, loss_fn=loss_fn)\n" ] }, { "cell_type": "markdown", "id": "19a03c13", "metadata": {}, "source": [ "## Model Training\n", "\n", "In this tutorial, we inherite the Trainer and override the `get_solver` member function to build custom loss function, and override `get_callback` member function to perform inference on the test dataset during the training process.\n", "\n", "MindSpore Earth provides training and inference interface for model training with `MindSpore` version >= 2.0.0." ] }, { "cell_type": "code", "execution_count": 15, "id": "51397466", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-12-29 08:46:01,280 - pretrain.py[line:215] - INFO: steps_per_epoch: 67\n", "epoch: 1 step: 67, loss is 0.37644267\n", "Train epoch time: 222879.994 ms, per step time: 3326.567 ms\n", "epoch: 2 step: 67, loss is 0.26096737\n", "Train epoch time: 216913.275 ms, per step time: 3237.512 ms\n", "epoch: 3 step: 67, loss is 0.2587443\n", "Train epoch time: 217870.765 ms, per step time: 3251.802 ms\n", "epoch: 4 step: 67, loss is 0.2280185\n", "Train epoch time: 218111.564 ms, per step time: 3255.396 ms\n", "epoch: 5 step: 67, loss is 0.20605856\n", "Train epoch time: 216881.674 ms, per step time: 3237.040 ms\n", "epoch: 6 step: 67, loss is 0.20178188\n", "Train epoch time: 218800.354 ms, per step time: 3265.677 ms\n", "epoch: 7 step: 67, loss is 0.21064804\n", "Train epoch time: 217554.571 ms, per step time: 3247.083 ms\n", "epoch: 8 step: 67, loss is 0.20392722\n", "Train epoch time: 217324.330 ms, per step time: 3243.647 ms\n", "epoch: 9 step: 67, loss is 0.19890495\n", "Train epoch time: 218374.032 ms, per step time: 3259.314 ms\n", "epoch: 10 step: 67, loss is 0.2064792\n", "Train epoch time: 217399.318 ms, per step time: 3244.766 ms\n", "2023-12-29 09:22:23,927 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n", "2023-12-29 09:24:51,246 - forecast.py[line:241] - INFO: test dataset size: 1\n", "2023-12-29 09:24:51,248 - forecast.py[line:191] - INFO: t = 6 hour: \n", "2023-12-29 09:24:51,250 - forecast.py[line:202] - INFO: RMSE of Z500: 313.254855370194, T2m: 2.911020155335285, T850: 1.6009748653510902, U10: 1.8822629694594444\n", "2023-12-29 09:24:51,251 - forecast.py[line:203] - INFO: ACC of Z500: 0.9950579247892839, T2m: 0.98743929296225, T850: 0.9930489273077082, U10: 0.9441216196638477\n", "2023-12-29 09:24:51,252 - forecast.py[line:191] - INFO: t = 72 hour: \n", "2023-12-29 09:24:51,253 - forecast.py[line:202] - INFO: RMSE of Z500: 1176.8557892319443, T2m: 7.344694139181644, T850: 6.165706260104667, U10: 5.953978905254709\n", "2023-12-29 09:24:51,254 - forecast.py[line:203] - INFO: ACC of Z500: 0.9271318752961824, T2m: 0.9236962494086007, T850: 0.9098796075852417, U10: 0.5003382663349598\n", "2023-12-29 09:24:51,255 - forecast.py[line:191] - INFO: t = 120 hour: \n", "2023-12-29 09:24:51,256 - forecast.py[line:202] - INFO: RMSE of Z500: 1732.662048442734, T2m: 9.891472332990181, T850: 8.233521390723434, U10: 7.434774900830313\n", "2023-12-29 09:24:51,256 - forecast.py[line:203] - INFO: ACC of Z500: 0.8421506711992445, T2m: 0.8468635778030965, T850: 
0.8467625693884427, U10: 0.3787509969898105\n", "2023-12-29 09:24:51,257 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n", "......\n", "epoch: 91 step: 67, loss is 0.13158562\n", "Train epoch time: 191498.866 ms, per step time: 2858.192 ms\n", "epoch: 92 step: 67, loss is 0.12776905\n", "Train epoch time: 218376.797 ms, per step time: 3259.355 ms\n", "epoch: 93 step: 67, loss is 0.12682373\n", "Train epoch time: 217263.432 ms, per step time: 3242.738 ms\n", "epoch: 94 step: 67, loss is 0.12594032\n", "Train epoch time: 217970.325 ms, per step time: 3253.288 ms\n", "epoch: 95 step: 67, loss is 0.12149178\n", "Train epoch time: 217401.066 ms, per step time: 3244.792 ms\n", "epoch: 96 step: 67, loss is 0.12223453\n", "Train epoch time: 218616.899 ms, per step time: 3265.344 ms\n", "epoch: 97 step: 67, loss is 0.12046164\n", "Train epoch time: 218616.899 ms, per step time: 3263.949 ms\n", "epoch: 98 step: 67, loss is 0.1172382\n", "Train epoch time: 216666.521 ms, per step time: 3233.829 ms\n", "epoch: 99 step: 67, loss is 0.11799482\n", "Train epoch time: 218090.233 ms, per step time: 3255.078 ms\n", "epoch: 100 step: 67, loss is 0.11112012\n", "Train epoch time: 218108.888 ms, per step time: 3255.357 ms\n", "2023-12-29 10:00:44,043 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n", "2023-12-29 10:02:59,291 - forecast.py[line:241] - INFO: test dataset size: 1\n", "2023-12-29 10:02:59,293 - forecast.py[line:191] - INFO: t = 6 hour: \n", "2023-12-29 10:02:59,294 - forecast.py[line:202] - INFO: RMSE of Z500: 159.26790471459077, T2m: 1.7593914514223792, T850: 1.2225771108909576, U10: 1.3952338408157166\n", "2023-12-29 10:02:59,295 - forecast.py[line:203] - INFO: ACC of Z500: 0.996888905697735, T2m: 0.9882202464019967, T850: 0.994542681351491, U10: 0.9697411543132562\n", "2023-12-29 10:02:59,297 - forecast.py[line:191] - INFO: t = 72 hour: \n", "2023-12-29 10:02:59,298 - forecast.py[line:202] - INFO: RMSE of Z500: 937.2960233810791, T2m: 5.177728653933931, T850: 4.831667457069809, U10: 5.30111109022694\n", "2023-12-29 10:02:59,299 - forecast.py[line:203] - INFO: ACC of Z500: 0.9542952919181137, T2m: 0.9557775651851869, T850: 0.9371537322317006, U10: 0.5895038993694246\n", "2023-12-29 10:02:59,300 - forecast.py[line:191] - INFO: t = 120 hour: \n", "2023-12-29 10:02:59,301 - forecast.py[line:202] - INFO: RMSE of Z500: 1200.9140481697198, T2m: 6.913749261896835, T850: 6.530332262562704, U10: 6.3855645042672835\n", "2023-12-29 10:02:59,303 - forecast.py[line:203] - INFO: ACC of Z500: 0.9257611031529911, T2m: 0.9197160039098073, T850: 0.8867113860499101, U10: 0.47483364671406136\n", "2023-12-29 10:02:59,304 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n" ] } ], "source": [ "trainer = FuXiTrainer(config, fuxi_model, loss_cell, logger_obj)\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "c58c2cb7", "metadata": {}, "source": [ "## Model Evaluation and Visualization" ] }, { "cell_type": "markdown", "id": "fdcb23c9", "metadata": {}, "source": [ "After training, we use the 100th checkpoint for inference. The visualization of predictions, ground truth and their error is shown below." 
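, "\n", "\n", "For reference, the RMSE and ACC values reported during evaluation are, up to implementation details such as latitude weighting, the conventional deterministic scores\n", "\n", "$$\mathrm{RMSE}=\sqrt{\frac{1}{H\times W}\sum_{i,j}\left(\hat{X}_{i,j}-X_{i,j}\right)^{2}},\qquad\mathrm{ACC}=\frac{\sum_{i,j}\hat{X}'_{i,j}X'_{i,j}}{\sqrt{\sum_{i,j}\hat{X}'^{2}_{i,j}\sum_{i,j}X'^{2}_{i,j}}}$$\n", "\n", "where $\hat{X}'$ and $X'$ denote anomalies of the prediction and the ERA5 target with respect to climatology (cf. `climate_0.25.npy` under `statistic`)."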
] }, { "cell_type": "code", "execution_count": 7, "id": "0e582671", "metadata": {}, "outputs": [], "source": [ "params = load_checkpoint('./FuXi_depths_18_in_channels_96_out_channels_192_recompute_True_adam_oop/ckpt/step_1/FuXi-100_67.ckpt')\n", "load_param_into_net(fuxi_model, params)\n", "inference_module = InferenceModule(fuxi_model, config, logger_obj)\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "6c11a8b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [] } ], "source": [ "data_params = config.get(\"data\")\n", "test_dataset_generator = Era5Data(data_params=data_params, run_mode='test')\n", "test_dataset = Dataset(test_dataset_generator, distribute=False,\n", " num_workers=data_params.get('num_workers'), shuffle=False)\n", "test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))\n", "data = next(test_dataset.create_dict_iterator())\n", "inputs = data['inputs']\n", "labels = data['labels']" ] }, { "cell_type": "code", "execution_count": 9, "id": "7b5622f0", "metadata": {}, "outputs": [], "source": [ "labels = labels[..., 0, :, :]\n", "labels = labels.transpose(0, 2, 1)\n", "labels = labels.reshape(labels.shape[0], labels.shape[1], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n", "\n", "pred = inference_module.forecast(inputs)\n", "pred = pred[0].transpose(1, 0)\n", "pred = pred.reshape(pred.shape[0], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n", "pred = np.expand_dims(pred, axis=0)" ] }, { "cell_type": "code", "execution_count": 11, "id": "62a63323", "metadata": {}, "outputs": [], "source": [ "def plt_key_info_comparison(pred, label, root_dir):\n", " \"\"\" Visualize the comparison of forecast results \"\"\"\n", " std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n", " mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n", " std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n", " mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n", "\n", " plt.figure(num='e_imshow', figsize=(100, 50))\n", "\n", " plt.subplot(4, 3, 1)\n", " plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth') # Z500\n", " plt.subplot(4, 3, 2)\n", " plt_global_field_data(pred, 'Z500', std, mean, 'Pred') # Z500\n", " plt.subplot(4, 3, 3)\n", " plt_global_field_data(label - pred, 'Z500', std, mean, 'Error', is_error=True) # Z500\n", "\n", " plt.subplot(4, 3, 4)\n", " plt_global_field_data(label, 'T850', std, mean, 'Ground Truth') # T850\n", " plt.subplot(4, 3, 5)\n", " plt_global_field_data(pred, 'T850', std, mean, 'Pred') # T850\n", " plt.subplot(4, 3, 6)\n", " plt_global_field_data(label - pred, 'T850', std, mean, 'Error', is_error=True) # T850\n", "\n", " plt.subplot(4, 3, 7)\n", " plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True) # U10\n", " plt.subplot(4, 3, 8)\n", " plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True) # U10\n", " plt.subplot(4, 3, 9)\n", " plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True, is_error=True) # U10\n", "\n", " plt.subplot(4, 3, 10)\n", " plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True) # T2M\n", " plt.subplot(4, 3, 11)\n", " plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True) # T2M\n", " plt.subplot(4, 3, 12)\n", " plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True, is_error=True) # T2M\n", "\n", " plt.savefig(f'key_info_comparison1.png', 
bbox_inches='tight')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd9062ea-6705-4968-9b81-3181cb5167e2", "metadata": {}, "outputs": [], "source": [ "plt_key_info_comparison(pred, labels, data_params.get('root_dir'))" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }