{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ecc50b48",
   "metadata": {},
   "source": [
    "# FuXi: Medium-range Global Weather Forecasting Based on Cascade Architecture\n",
    "\n",
    "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_fuxi.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_fuxi.py) [![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindearth/docs/source_en/medium-range/fuxi.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0abd5fd1",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "FuXi is a data-driven global weather forecast model developed by researchers from Fudan University. It provides medium-range forecasts of key global weather variables at a resolution of 0.25°, which corresponds to a spatial resolution of approximately 25 km x 25 km near the equator and a global grid of 720 x 1440 pixels. Compared with previous ML-based weather forecast models, the FuXi model with its cascade architecture achieves excellent results in [ECMWF](https://charts.ecmwf.int/products/plwww_3m_fc_aimodels_wp_mean?area=Northern%20Extra-tropics&parameter=Geopotential%20500hPa&score=Root%20mean%20square%20error) evaluations.\n",
    "\n",
    "This tutorial introduces the research background and technical path of FuXi, and shows how to train the model and run fast inference with MindSpore Earth. More information can be found in the [paper](https://www.nature.com/articles/s41612-023-00512-1). The [ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/) dataset, with a resolution of 0.25°, is used to walk through the workflow in detail."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39dacb90",
   "metadata": {},
   "source": [
    "## FuXi\n",
    "\n",
    "The basic FuXi model architecture consists of three main components, as shown in the figure: Cube Embedding, U-Transformer, and a fully connected layer. The input data combines upper-air and surface variables into a data cube with dimensions 69×720×1440 for each time step.\n",
    "\n",
    "The high-dimensional input data is reduced in dimension by the space-time cube embedding and converted to C×180×360. The main purpose of cube embedding is to reduce the spatial and temporal dimensions of the input data and remove redundant information. The U-Transformer then processes the embedded data, a simple fully connected layer makes the prediction, and the output is reshaped to 69×720×1440.\n",
    "\n",
    "![FuXi_model](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindearth/docs/source_en/medium-range/images/FuXi_model.png)\n",
    "\n",
    "- Cube Embedding\n",
    "\n",
    "  In order to reduce the spatial and temporal dimensions of the input data and speed up the training process, the cube embedding method is applied.\n",
    "\n",
    "  Specifically, the space-time cube embedding uses a three-dimensional (3D) convolutional layer whose kernel size and stride are both 2x4x4 (equivalent to $\\frac{T}{2}×\\frac{H}{4}×\\frac{W}{4}$) and whose number of output channels is C. After the space-time cube embedding, layer normalization (LayerNorm) is applied to improve training stability. The dimension of the resulting data cube is C×180×360.\n",
    "\n",
    "- U-Transformer\n",
    "\n",
    "  The U-Transformer includes the downsampling and upsampling blocks of the U-Net model. The downsampling block, referred to as the Down Block in the figure, reduces the data dimension to C×90×180, thereby reducing the computational and memory requirements of self-attention. The Down Block consists of a 3×3 2D convolutional layer with a stride of 2 and a residual block with two 3×3 convolution layers, followed by a group normalization (GN) layer and a sigmoid-weighted linear unit (SiLU) activation. SiLU is computed as $σ(x)×x$, that is, the sigmoid function multiplied by its input.\n",
    "\n",
    "  The upsampling block, referred to as the Up Block in the figure, uses the same residual block as the Down Block and also includes a 2D transposed convolution with a kernel size of 2 and a stride of 2. The Up Block scales the data size back to C×180×360. In addition, a skip connection joins the output of the Down Block with the output of the Transformer blocks before the result is fed to the Up Block.\n",
    "\n",
    "  The intermediate structure consists of 18 repeated Swin Transformer V2 blocks. By using residual post-normalization instead of pre-normalization and scaled cosine attention instead of the original dot-product self-attention, Swin Transformer V2 addresses several problems that arise when training and applying large-scale Swin Transformer models, such as training instability.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da40cb1d",
   "metadata": {},
   "source": [
    "## Technology Path\n",
    "\n",
    "MindSpore Earth addresses the problem in the following steps:\n",
    "\n",
    "1. Data Construction.\n",
    "2. Model Construction.\n",
    "3. Loss Function.\n",
    "4. Model Training.\n",
    "5. Model Evaluation and Visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "52aa45e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from mindspore import set_seed\n",
    "from mindspore import context\n",
    "from mindspore import load_checkpoint, load_param_into_net"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062e7adc",
   "metadata": {},
   "source": [
    "The `src` package imported below can be downloaded from [fuxi/src](https://gitee.com/mindspore/mindscience/tree/master/MindEarth/applications/medium-range/fuxi/src)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e8373e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindearth.utils import load_yaml_config, plt_global_field_data, make_dir\n",
    "from mindearth.data import Dataset, Era5Data\n",
    "\n",
    "from src import init_model, get_logger\n",
    "from src import MAELossForMultiLabel, FuXiTrainer, CustomWithLossCell, InferenceModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "258d0acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(0)\n",
    "np.random.seed(0)\n",
    "random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b670e4",
   "metadata": {},
   "source": [
    "The model, data, and optimizer parameters are read from the [config](https://gitee.com/mindspore/mindscience/raw/master/MindEarth/applications/medium-range/fuxi/configs/FuXi.yaml) file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9ba71baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = load_yaml_config(\"configs/FuXi.yaml\")\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fa254aa",
   "metadata": {},
   "source": [
    "## Data Construction\n",
    "\n",
    "Download the statistics, training, and validation datasets from [ERA5_0_25_tiny400](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/ERA5_0_25_tiny400/) to `./dataset`.\n",
    "\n",
    "The `./dataset` directory is organized as follows:\n",
    "\n",
    "```markdown\n",
    ".\n",
    "├── statistic\n",
    "│   ├── mean.npy\n",
    "│   ├── mean_s.npy\n",
    "│   ├── std.npy\n",
    "│   ├── std_s.npy\n",
    "│   └── climate_0.25.npy\n",
    "├── train\n",
    "│   └── 2015\n",
    "├── train_static\n",
    "│   └── 2015\n",
    "├── train_surface\n",
    "│   └── 2015\n",
    "├── train_surface_static\n",
    "│   └── 2015\n",
    "├── valid\n",
    "│   └── 2016\n",
    "├── valid_static\n",
    "│   └── 2016\n",
    "├── valid_surface\n",
    "│   └── 2016\n",
    "└── valid_surface_static\n",
    "    └── 2016\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25b0568c",
   "metadata": {},
   "source": [
    "## Model Construction\n",
    "\n",
    "The model is initialized from the configuration, which specifies the number of Swin Transformer blocks and the training parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cee87d14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-01-29 12:34:41,485 - utils.py[line:91] - INFO: {'name': 'FuXi', 'depths': 18, 'in_channels': 96, 'out_channels': 192}\n",
      "2024-01-29 12:34:41,487 - utils.py[line:91] - INFO: {'name': 'era5', 'root_dir': '/data4/cq/ERA5_0_25/', 'feature_dims': 69, 'pressure_level_num': 13, 'level_feature_size': 5, 'surface_feature_size': 4, 'h_size': 720, 'w_size': 1440, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 6, 'valid_interval': 6, 'test_interval': 6, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2015, 2015], 'valid_period': [2016, 2016], 'test_period': [2017, 2017], 'num_workers': 1, 'grid_resolution': 0.25}\n",
      "2024-01-29 12:34:41,488 - utils.py[line:91] - INFO: {'name': 'adam', 'initial_lr': 0.00025, 'finetune_lr': 1e-05, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.0, 'loss_weight': 0.25, 'gamma': 0.5, 'epochs': 100}\n",
      "2024-01-29 12:34:41,489 - utils.py[line:91] - INFO: {'summary_dir': './summary', 'eval_interval': 10, 'save_checkpoint_steps': 10, 'keep_checkpoint_max': 50, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '/data3/cq'}\n",
      "2024-01-29 12:34:41,490 - utils.py[line:91] - INFO: {'name': 'oop', 'distribute': False, 'amp_level': 'O2', 'load_ckpt': True}\n"
     ]
    }
   ],
   "source": [
    "make_dir(os.path.join(config['summary'][\"summary_dir\"], \"image\"))\n",
    "logger_obj = get_logger(config)\n",
    "fuxi_model = init_model(config, \"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e818c80f",
   "metadata": {},
   "source": [
    "## Loss Function\n",
    "\n",
    "FuXi uses a custom mean absolute error (MAE) loss for model training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "67156e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_params = config.get('data')\n",
    "optimizer_params = config.get('optimizer')\n",
    "loss_fn = MAELossForMultiLabel(data_params=data_params, optimizer_params=optimizer_params)\n",
    "loss_cell = CustomWithLossCell(backbone=fuxi_model, loss_fn=loss_fn)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19a03c13",
   "metadata": {},
   "source": [
    "## Model Training\n",
    "\n",
    "In this tutorial, we inherit from the Trainer class, override the `get_solver` member function to build the custom loss function, and override the `get_callback` member function to perform inference on the test dataset during training.\n",
    "\n",
    "MindSpore Earth provides training and inference interfaces for models and requires `MindSpore` version >= 2.0.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "51397466",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-12-29 08:46:01,280 - pretrain.py[line:215] - INFO: steps_per_epoch: 67\n",
      "epoch: 1 step: 67, loss is 0.37644267\n",
      "Train epoch time: 222879.994 ms, per step time: 3326.567 ms\n",
      "epoch: 2 step: 67, loss is 0.26096737\n",
      "Train epoch time: 216913.275 ms, per step time: 3237.512 ms\n",
      "epoch: 3 step: 67, loss is 0.2587443\n",
      "Train epoch time: 217870.765 ms, per step time: 3251.802 ms\n",
      "epoch: 4 step: 67, loss is 0.2280185\n",
      "Train epoch time: 218111.564 ms, per step time: 3255.396 ms\n",
      "epoch: 5 step: 67, loss is 0.20605856\n",
      "Train epoch time: 216881.674 ms, per step time: 3237.040 ms\n",
      "epoch: 6 step: 67, loss is 0.20178188\n",
      "Train epoch time: 218800.354 ms, per step time: 3265.677 ms\n",
      "epoch: 7 step: 67, loss is 0.21064804\n",
      "Train epoch time: 217554.571 ms, per step time: 3247.083 ms\n",
      "epoch: 8 step: 67, loss is 0.20392722\n",
      "Train epoch time: 217324.330 ms, per step time: 3243.647 ms\n",
      "epoch: 9 step: 67, loss is 0.19890495\n",
      "Train epoch time: 218374.032 ms, per step time: 3259.314 ms\n",
      "epoch: 10 step: 67, loss is 0.2064792\n",
      "Train epoch time: 217399.318 ms, per step time: 3244.766 ms\n",
      "2023-12-29 09:22:23,927 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n",
      "2023-12-29 09:24:51,246 - forecast.py[line:241] - INFO: test dataset size: 1\n",
      "2023-12-29 09:24:51,248 - forecast.py[line:191] - INFO: t = 6 hour: \n",
      "2023-12-29 09:24:51,250 - forecast.py[line:202] - INFO:  RMSE of Z500: 313.254855370194, T2m: 2.911020155335285, T850: 1.6009748653510902, U10: 1.8822629694594444\n",
      "2023-12-29 09:24:51,251 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.9950579247892839, T2m: 0.98743929296225, T850: 0.9930489273077082, U10: 0.9441216196638477\n",
      "2023-12-29 09:24:51,252 - forecast.py[line:191] - INFO: t = 72 hour: \n",
      "2023-12-29 09:24:51,253 - forecast.py[line:202] - INFO:  RMSE of Z500: 1176.8557892319443, T2m: 7.344694139181644, T850: 6.165706260104667, U10: 5.953978905254709\n",
      "2023-12-29 09:24:51,254 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.9271318752961824, T2m: 0.9236962494086007, T850: 0.9098796075852417, U10: 0.5003382663349598\n",
      "2023-12-29 09:24:51,255 - forecast.py[line:191] - INFO: t = 120 hour: \n",
      "2023-12-29 09:24:51,256 - forecast.py[line:202] - INFO:  RMSE of Z500: 1732.662048442734, T2m: 9.891472332990181, T850: 8.233521390723434, U10: 7.434774900830313\n",
      "2023-12-29 09:24:51,256 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.8421506711992445, T2m: 0.8468635778030965, T850: 0.8467625693884427, U10: 0.3787509969898105\n",
      "2023-12-29 09:24:51,257 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n",
      "......\n",
      "epoch: 91 step: 67, loss is 0.13158562\n",
      "Train epoch time: 191498.866 ms, per step time: 2858.192 ms\n",
      "epoch: 92 step: 67, loss is 0.12776905\n",
      "Train epoch time: 218376.797 ms, per step time: 3259.355 ms\n",
      "epoch: 93 step: 67, loss is 0.12682373\n",
      "Train epoch time: 217263.432 ms, per step time: 3242.738 ms\n",
      "epoch: 94 step: 67, loss is 0.12594032\n",
      "Train epoch time: 217970.325 ms, per step time: 3253.288 ms\n",
      "epoch: 95 step: 67, loss is 0.12149178\n",
      "Train epoch time: 217401.066 ms, per step time: 3244.792 ms\n",
      "epoch: 96 step: 67, loss is 0.12223453\n",
      "Train epoch time: 218616.899 ms, per step time: 3265.344 ms\n",
      "epoch: 97 step: 67, loss is 0.12046164\n",
      "Train epoch time: 218616.899 ms, per step time: 3263.949 ms\n",
      "epoch: 98 step: 67, loss is 0.1172382\n",
      "Train epoch time: 216666.521 ms, per step time: 3233.829 ms\n",
      "epoch: 99 step: 67, loss is 0.11799482\n",
      "Train epoch time: 218090.233 ms, per step time: 3255.078 ms\n",
      "epoch: 100 step: 67, loss is 0.11112012\n",
      "Train epoch time: 218108.888 ms, per step time: 3255.357 ms\n",
      "2023-12-29 10:00:44,043 - forecast.py[line:223] - INFO: ================================Start Evaluation================================\n",
      "2023-12-29 10:02:59,291 - forecast.py[line:241] - INFO: test dataset size: 1\n",
      "2023-12-29 10:02:59,293 - forecast.py[line:191] - INFO: t = 6 hour: \n",
      "2023-12-29 10:02:59,294 - forecast.py[line:202] - INFO:  RMSE of Z500: 159.26790471459077, T2m: 1.7593914514223792, T850: 1.2225771108909576, U10: 1.3952338408157166\n",
      "2023-12-29 10:02:59,295 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.996888905697735, T2m: 0.9882202464019967, T850: 0.994542681351491, U10: 0.9697411543132562\n",
      "2023-12-29 10:02:59,297 - forecast.py[line:191] - INFO: t = 72 hour: \n",
      "2023-12-29 10:02:59,298 - forecast.py[line:202] - INFO:  RMSE of Z500: 937.2960233810791, T2m: 5.177728653933931, T850: 4.831667457069809, U10: 5.30111109022694\n",
      "2023-12-29 10:02:59,299 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.9542952919181137, T2m: 0.9557775651851869, T850: 0.9371537322317006, U10: 0.5895038993694246\n",
      "2023-12-29 10:02:59,300 - forecast.py[line:191] - INFO: t = 120 hour: \n",
      "2023-12-29 10:02:59,301 - forecast.py[line:202] - INFO:  RMSE of Z500: 1200.9140481697198, T2m: 6.913749261896835, T850: 6.530332262562704, U10: 6.3855645042672835\n",
      "2023-12-29 10:02:59,303 - forecast.py[line:203] - INFO:  ACC  of Z500: 0.9257611031529911, T2m: 0.9197160039098073, T850: 0.8867113860499101, U10: 0.47483364671406136\n",
      "2023-12-29 10:02:59,304 - forecast.py[line:256] - INFO: ================================End Evaluation================================\n"
     ]
    }
   ],
   "source": [
    "trainer = FuXiTrainer(config, fuxi_model, loss_cell, logger_obj)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c58c2cb7",
   "metadata": {},
   "source": [
    "## Model Evaluation and Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdcb23c9",
   "metadata": {},
   "source": [
    "After training, we use the checkpoint from the 100th epoch for inference. The predictions, the ground truth, and their errors are visualized below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0e582671",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = load_checkpoint('./FuXi_depths_18_in_channels_96_out_channels_192_recompute_True_adam_oop/ckpt/step_1/FuXi-100_67.ckpt')\n",
    "load_param_into_net(fuxi_model, params)\n",
    "inference_module = InferenceModule(fuxi_model, config, logger_obj)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6c11a8b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": []
    }
   ],
   "source": [
    "data_params = config.get(\"data\")\n",
    "test_dataset_generator = Era5Data(data_params=data_params, run_mode='test')\n",
    "test_dataset = Dataset(test_dataset_generator, distribute=False,\n",
    "                       num_workers=data_params.get('num_workers'), shuffle=False)\n",
    "test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))\n",
    "data = next(test_dataset.create_dict_iterator())\n",
    "inputs = data['inputs']\n",
    "labels = data['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7b5622f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = labels[..., 0, :, :]\n",
    "labels = labels.transpose(0, 2, 1)\n",
    "labels = labels.reshape(labels.shape[0], labels.shape[1], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n",
    "\n",
    "pred = inference_module.forecast(inputs)\n",
    "pred = pred[0].transpose(1, 0)\n",
    "pred = pred.reshape(pred.shape[0], data_params.get(\"h_size\"), data_params.get(\"w_size\")).asnumpy()\n",
    "pred = np.expand_dims(pred, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "62a63323",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plt_key_info_comparison(pred, label, root_dir):\n",
    "    \"\"\" Visualize the comparison of forecast results \"\"\"\n",
    "    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n",
    "    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n",
    "    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n",
    "    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n",
    "\n",
    "    plt.figure(num='e_imshow', figsize=(100, 50))\n",
    "\n",
    "    plt.subplot(4, 3, 1)\n",
    "    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500\n",
    "    plt.subplot(4, 3, 2)\n",
    "    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500\n",
    "    plt.subplot(4, 3, 3)\n",
    "    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error', is_error=True)  # Z500\n",
    "\n",
    "    plt.subplot(4, 3, 4)\n",
    "    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850\n",
    "    plt.subplot(4, 3, 5)\n",
    "    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850\n",
    "    plt.subplot(4, 3, 6)\n",
    "    plt_global_field_data(label - pred, 'T850', std, mean, 'Error', is_error=True)  # T850\n",
    "\n",
    "    plt.subplot(4, 3, 7)\n",
    "    plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 8)\n",
    "    plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 9)\n",
    "    plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True, is_error=True)  # U10\n",
    "\n",
    "    plt.subplot(4, 3, 10)\n",
    "    plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 11)\n",
    "    plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 12)\n",
    "    plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True, is_error=True)  # T2M\n",
    "\n",
    "    plt.savefig(f'key_info_comparison1.png', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd9062ea-6705-4968-9b81-3181cb5167e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_key_info_comparison(pred, labels, data_params.get('root_dir'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be39650e-73f1-40a6-9e49-196d4cfee973",
   "metadata": {},
   "source": [
    "![](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindearth/docs/source_en/medium-range/images/fuxi_result.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}