{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# FourCastNet: 基于傅里叶神经算子的全球中期天气预报\n",
    "\n",
    "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth/zh_cn/medium-range/mindspore_FourCastNet.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth/zh_cn/medium-range/mindspore_FourCastNet.py) [![ViewSource](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindearth/docs/source_en/medium-range/FourCastNet.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 概述\n",
    "\n",
    "FourCastNet(傅里叶预测神经网络)是一个数据驱动的全球天气预报模型,由NVIDIA、劳伦斯伯克利国家实验室、密歇根大学安阿伯和赖斯大学的研究人员开发。它提供了关键全球天气指标的中期预报,分辨率为0.25°。相当于赤道附近约30公里x30公里的空间分辨率和大小为720 x 1440像素的全球网格。与传统的NWP模型相比,该模型的预报速度提高了45000倍,2秒内生成一周的天气预报,预报精度与最先进的数值天气预报模型ECMWF综合预报系统(IFS)相当。这是第一个可以直接与IFS系统进行比较的AI天气预报模型。\n",
    "\n",
    "本教程介绍了FourCastNet的研究背景和技术路径,并展示了如何通过MindFlow训练和快速推断模型。更多详细信息见[文章](https://arxiv.org/abs/2202.11214)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 技术路径\n",
    "\n",
    "MindEarth求解该问题的具体流程如下:\n",
    "\n",
    "1. 创建数据集\n",
    "2. 模型构建\n",
    "3. 损失函数\n",
    "4. 模型训练\n",
    "5. 模型验证和可视化"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## FourCastNet\n",
    "\n",
    "为了实现高分辨率预测,FourCastNet使用AFNO模型。该模型网络体系结构是为高分辨率输入而设计的,以ViT为骨干网,并结合了李宗义等人提出的傅里叶神经算子(FNO)。该模型学习函数空间之间的映射,从而求解一系列非线性偏微分方程。\n",
    "\n",
    "Vision Transforme(ViT)体系结构及其变体在过去几年中已成为计算机视觉中最先进的技术,在许多任务中表现出卓越的性能。这种性能主要归因于网络中的多头自注意机制,它使网络中每一层特征之间的全局建模。然而,模型在训练和推理期间的计算复杂度随着令牌(或patches)数量的增加而二次增加,模型计算复杂度随着输入分辨率的增加而爆炸性增加。\n",
    "\n",
    "AFNO模型的独创性在于,它将空间混合操作转换为傅里叶变换,混合不同令牌的信息,将特征从空域转换为频域,并对频域特征应用全局可学习滤波器。空间混合复杂度有效地降低到O(NlogN),其中N是token的数量。\n",
    "\n",
    "FourCastNet网络架构如下图所示。\n",
    "\n",
    "![AFNO model](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/docs/mindearth/docs/source_zh_cn/medium-range/images/AFNO.png)\n",
    "\n",
    "模型训练包括三个步骤:\n",
    "\n",
    "1.预训练:如上图(a)所示,在预训练步骤中,使用训练数据集以监督的方式训练AFNO模型,以学习从X(k)到X(k + 1)的映射。\n",
    "\n",
    "2.微调:如上图(b)所示,模型首先从X(k)预测X(k + 1),然后使用X(k + 1)作为输入预测X(k + 2)。然后,通过从X(k + 1)和X(k + 2)的预测值计算损失函数值,利用两个损失函数值的总和优化模型。\n",
    "\n",
    "3.降水预报:如上文(c)所示,降水预报由主干模型后面的单独模型拼接而成。该方法将降水预测任务与基本气象要素解耦。另一方面,训练好的降水模型也可以与其他预测模型(传统的NWP等)结合使用。\n",
    "\n",
    "本教程主要实现模型预训练部分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from mindspore import context\n",
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "\n",
    "from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data\n",
    "from mindearth.module import Trainer\n",
    "from mindearth.data import Dataset, Era5Data\n",
    "from mindearth import RelativeRMSELoss\n",
    "from mindearth.cell import AFNONet"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " `src` 文件可以从[FourCastNet/src](https://gitee.com/mindspore/mindscience/tree/r0.5/MindEarth/applications/medium-range/fourcastnet/src)下载。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from src.callback import EvaluateCallBack, InferenceModule\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "model、data和optimizer的参数可以通过加载yaml文件获取([FourCastNet.yaml](https://gitee.com/mindspore/mindscience/blob/r0.5/MindEarth/applications/medium-range/fourcastnet/FourCastNet.yaml))。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "config = load_yaml_config('FourCastNet.yaml')\n",
    "config['model']['data_sink'] = True  # 是否使用data sink特性\n",
    "\n",
    "config['train']['distribute'] = False  # 是否执行分布式任务\n",
    "config['train']['amp_level'] = 'O2'  # 设置混合精度等级\n",
    "\n",
    "config['data']['num_workers'] = 1  # 设置并行计算的进程数量\n",
    "config['data']['grid_resolution'] = 1.4  # 设置气象分辨率参数\n",
    "\n",
    "config['optimizer']['epochs'] = 100  # 设置epoch数量\n",
    "config['optimizer']['finetune_epochs'] = 1  # 设置微调epoch数量\n",
    "config['optimizer']['warmup_epochs'] = 1  # 设置预热epoch的数量\n",
    "config['optimizer']['initial_lr'] = 0.0005  # 设置初始化学习率\n",
    "\n",
    "config['summary'][\"valid_frequency\"] = 10  # 设置验证的频率\n",
    "config['summary'][\"summary_dir\"] = './summary'  # 设置模型checkpoint的存储路径\n",
    "\n",
    "logger = create_logger(path=\"results.log\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建数据集\n",
    "\n",
    "在[dataset](https://download.mindspore.cn/mindscience/mindearth/dataset/WeatherBench_1.4_69/)路径下,下载正则化参数、训练数据集验证数据集到 `./dataset`目录。\n",
    "\n",
    "修改[FourCastNet.yaml](https://gitee.com/mindspore/mindscience/blob/r0.5/MindEarth/applications/medium-range/fourcastnet/FourCastNet.yaml)配置文件中的`root_dir`参数,该参数设置了数据集的路径。\n",
    "\n",
    "`./dataset`中的目录结构如下所示:\n",
    "\n",
    "``` markdown\n",
    ".\n",
    "├── statistic\n",
    "│   ├── mean.npy\n",
    "│   ├── mean_s.npy\n",
    "│   ├── std.npy\n",
    "│   └── std_s.npy\n",
    "├── train\n",
    "│   └── 2015\n",
    "├── train_static\n",
    "│   └── 2015\n",
    "├── train_surface\n",
    "│   └── 2015\n",
    "├── train_surface_static\n",
    "│   └── 2015\n",
    "├── valid\n",
    "│   └── 2016\n",
    "├── valid_static\n",
    "│   └── 2016\n",
    "├── valid_surface\n",
    "│   └── 2016\n",
    "├── valid_surface_static\n",
    "│   └── 2016\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 模型构建"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "加载相关的数据参数和模型参数,并完成AFNONet模型构建。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "data_params = config['data']\n",
    "model_params = config['model']\n",
    "\n",
    "model = AFNONet(image_size=(data_params['h_size'], data_params['w_size']),\n",
    "                in_channels=data_params[\"feature_dims\"],\n",
    "                out_channels=data_params[\"feature_dims\"],\n",
    "                patch_size=data_params[\"patch_size\"],\n",
    "                encoder_depths=model_params[\"encoder_depths\"],\n",
    "                encoder_embed_dim=model_params[\"encoder_embed_dim\"],\n",
    "                mlp_ratio=model_params[\"mlp_ratio\"],\n",
    "                dropout_rate=model_params[\"dropout_rate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 损失函数\n",
    "\n",
    "FourCastNet使用相对均方根误差进行模型训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "loss_fn = RelativeRMSELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 模型训练\n",
    "\n",
    "在本教程中,我们继承了Trainer并重写了get_callback成员函数,以便我们可以在训练过程中对测试数据集执行推理。\n",
    "\n",
    "对于MindSpore版本>= 1.8.1,我们可以使用函数式编程来训练神经网络。MindSpore Earth为模型训练提供了训练接口。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-09-07 02:26:03,143 - pretrain.py[line:211] - INFO: steps_per_epoch: 404\n"
     ]
    }
   ],
   "source": [
    "class FCNTrainer(Trainer):\n",
    "    def __init__(self, config, model, loss_fn, logger):\n",
    "        super(FCNTrainer, self).__init__(config, model, loss_fn, logger)\n",
    "        self.pred_cb = self.get_callback()\n",
    "\n",
    "    def get_callback(self):\n",
    "        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)\n",
    "        return pred_cb\n",
    "\n",
    "trainer = FCNTrainer(config, model, loss_fn, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 404, loss is 0.5348429\n",
      "Train epoch time: 136480.515 ms, per step time: 337.823 ms\n",
      "epoch: 2 step: 404, loss is 0.35937342\n",
      "Train epoch time: 60902.627 ms, per step time: 150.749 ms\n",
      "...\n",
      "epoch: 98 step: 404, loss is 0.15447393\n",
      "Train epoch time: 61055.706 ms, per step time: 151.128 ms\n",
      "epoch: 99 step: 404, loss is 0.15696357\n",
      "Train epoch time: 60850.156 ms, per step time: 150.619 ms\n",
      "epoch: 100 step: 404, loss is 0.15654306\n",
      "Train epoch time: 60944.369 ms, per step time: 150.852 ms\n",
      "2023-09-07 04:27:02,837 - forecast.py[line:209] - INFO: ================================Start Evaluation================================\n",
      "2023-09-07 04:28:25,277 - forecast.py[line:177] - INFO: t = 6 hour: \n",
      "2023-09-07 04:28:25,277 - forecast.py[line:188] - INFO:  RMSE of Z500: 154.07894852240838, T2m: 2.0995438696856965, T850: 1.3081689948838815, U10: 1.527248748050362\n",
      "2023-09-07 04:28:25,278 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9989880649296732, T2m: 0.9930711917863625, T850: 0.9954355203713009, U10: 0.9615764420500764\n",
      "2023-09-07 04:28:25,279 - forecast.py[line:177] - INFO: t = 72 hour: \n",
      "2023-09-07 04:28:25,279 - forecast.py[line:188] - INFO:  RMSE of Z500: 885.3778200063341, T2m: 4.586325958437852, T850: 4.2593739999338736, U10: 4.75655467109408\n",
      "2023-09-07 04:28:25,280 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9598951919101183, T2m: 0.9658168304842388, T850: 0.9501612262744354, U10: 0.6175327930007481\n",
      "2023-09-07 04:28:25,281 - forecast.py[line:177] - INFO: t = 120 hour: \n",
      "2023-09-07 04:28:25,281 - forecast.py[line:188] - INFO:  RMSE of Z500: 1291.3199606908572, T2m: 6.734047767054735, T850: 5.6420206614200294, U10: 5.637643311177468\n",
      "2023-09-07 04:28:25,282 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9150022892106006, T2m: 0.9294266102808937, T850: 0.9148957221265037, U10: 0.47971871343985495\n",
      "2023-09-07 04:28:25,283 - forecast.py[line:237] - INFO: ================================End Evaluation================================\n"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 模型推理及可视化\n",
    "\n",
    "完成训练后,我们使用第100个ckpt进行推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "pred_time_index = 0\n",
    "\n",
    "params = load_checkpoint('./summary/ckpt/step_1/FourCastNet_1-100_404.ckpt')\n",
    "load_param_into_net(model, params)\n",
    "inference_module = InferenceModule(model, config, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plt_data(pred, label, root_dir, index=0):\n",
    "    \"\"\" Visualize the forecast results \"\"\"\n",
    "    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n",
    "    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n",
    "    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n",
    "    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n",
    "    pred, label = pred[index].asnumpy(), label.asnumpy()[..., index, :, :]\n",
    "    plt.figure(num='e_imshow', figsize=(100, 50), dpi=100)\n",
    "\n",
    "    plt.subplot(4, 3, 1)\n",
    "    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500\n",
    "    plt.subplot(4, 3, 2)\n",
    "    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500\n",
    "    plt.subplot(4, 3, 3)\n",
    "    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error')  # Z500\n",
    "\n",
    "    plt.subplot(4, 3, 4)\n",
    "    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850\n",
    "    plt.subplot(4, 3, 5)\n",
    "    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850\n",
    "    plt.subplot(4, 3, 6)\n",
    "    plt_global_field_data(label - pred, 'T850', std, mean, 'Error')  # T850\n",
    "\n",
    "    plt.subplot(4, 3, 7)\n",
    "    plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 8)\n",
    "    plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 9)\n",
    "    plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True)  # U10\n",
    "\n",
    "    plt.subplot(4, 3, 10)\n",
    "    plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 11)\n",
    "    plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 12)\n",
    "    plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True)  # T2M\n",
    "\n",
    "    plt.savefig(f'pred_result.png', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset_generator = Era5Data(data_params=config[\"data\"], run_mode='test')\n",
    "test_dataset = Dataset(test_dataset_generator, distribute=False,\n",
    "                       num_workers=config[\"data\"]['num_workers'], shuffle=False)\n",
    "test_dataset = test_dataset.create_dataset(config[\"data\"]['batch_size'])\n",
    "data = next(test_dataset.create_dict_iterator())\n",
    "inputs = data['inputs']\n",
    "labels = data['labels']\n",
    "pred = inference_module.forecast(inputs)\n",
    "plt_data(pred, labels, config['data']['root_dir'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下述展示了第100个ckpt的真实值、预测值和他们之间的误差可视化。\n",
    "\n",
    "![plot result](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/docs/mindearth/docs/source_zh_cn/medium-range/images/fno_result.png)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "16478c1492173c9a4f4847b8186328de7a4ca317afeafcd41bba7d71ba067560"
  },
  "kernelspec": {
   "display_name": "Python 3.9.16 64-bit ('lbk_ms10': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}