{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ViT-KNO: 基于Koopman神经算子的全球中期天气预报\n",
    "\n",
    "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth/zh_cn/medium-range/mindspore_vit_kno.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth/zh_cn/medium-range/mindspore_vit_kno.py) [![ViewSource](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindearth/docs/source_zh_cn/medium-range/vit_kno.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "\n",
    "现代数据天气预报(Numerical Weather Prediction, NWP)可以追溯到1920年,其基于物理原理,整合了几代气象学者的成果经验,是各国气象部门所采用主流的天气预报方法。其中来自欧洲中期天气预报中心(ECMWF)的高分辨率综合系统模型(IFS)表现效果最佳。\n",
    "\n",
    "直到2022年英伟达研发了一种基于傅里叶神经网络的预测模型FourCastNet,它能以0.25°的分辨率生成全球关键天气指标的预测,这相当于赤道附近约30×30km的空间分辨率和720×1440像素的权重网格大小,与IFS系统一致。这项成果使得AI气象模型首次与传统的物理模型IFS进行直接比较。更多信息可参考:[\"FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators\"](https://arxiv.org/pdf/2202.11214.pdf)。\n",
    "\n",
    "但是基于傅里叶神经算子(Fourier Neural Operator, FNO)构建的预测模型FourCastNet在预测中长期天气时,变得不够准确和缺乏可解释性。ViT-KNO充分利用Vision Transformer结构和Koopman理论,学习Koopman Operator去预测非线性动力学系统,通过在线性结构中嵌入复杂的动力学去约束重建过程,ViT-KNO能够捕获复杂的非线性行为,同时保持模型轻量级和计算有效性。ViT-KNO有清晰的数学理论支撑,很好的克服了同类方法在数学和物理上可解释性和理论依据不足的问题。更多信息可参考:[\"KoopmanLab: machine learning for solving complex physics equations\"](https://arxiv.org/pdf/2301.01104.pdf)。\n",
    "\n",
    "## 技术路径\n",
    "\n",
    "MindSpore求解该问题的具体流程如下:\n",
    "\n",
    "1. 创建数据集\n",
    "2. 模型构建\n",
    "3. 损失函数\n",
    "4. 模型训练\n",
    "5. 模型验证和可视化\n",
    "\n",
    "## ViT-KNO\n",
    "\n",
    "ViT-KNO模型构架如下图所示,主要包含两个分支,上路分支负责结果预测,由Encoder模块,Koopman Layer模块,Decoder模块组成,其中Koopman Layer模块结构如虚线框所示,可重复堆叠;下路分支由Encoder模块,Decoder模块组成,负责输入信息的重构。\n",
    "\n",
    "![ViT-KNO](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/docs/mindearth/docs/source_zh_cn/medium-range/images/vit_kno.png \"Model\")\n",
    "\n",
    "模型的训练流程如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from mindspore import context, Model\n",
    "from mindspore import dtype as mstype\n",
    "from mindspore.train import load_checkpoint, load_param_into_net\n",
    "from mindspore.train import DynamicLossScaleManager\n",
    "\n",
    "from mindearth.cell import ViTKNO\n",
    "from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data\n",
    "from mindearth.data import Dataset, Era5Data\n",
    "from mindearth.module import Trainer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " `src` 文件可以从[ViT-KNO/src](https://gitee.com/mindspore/mindscience/tree/r0.5/MindEarth/applications/medium-range/koopman_vit/src)下载。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.callback import EvaluateCallBack, InferenceModule, Lploss, CustomWithLossCell\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "model、data和optimizer的参数可以通过加载yaml文件获取([vit_kno.yaml](https://gitee.com/mindspore/mindscience/blob/r0.5/MindEarth/applications/medium-range/koopman_vit/vit_kno.yaml))。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = load_yaml_config('vit_kno.yaml')\n",
    "config['model']['data_sink'] = True  # 是否使用data sink特性\n",
    "\n",
    "config['train']['distribute'] = False  # 是否执行分布式任务\n",
    "config['train']['amp_level'] = 'O2'  # 设置混合精度等级\n",
    "\n",
    "config['data']['num_workers'] = 1  # 设置并行计算的进程数量\n",
    "config['data']['grid_resolution'] = 1.4  # 设置气象分辨率参数\n",
    "\n",
    "config['optimizer']['epochs'] = 100  # 设置epoch数量\n",
    "config['optimizer']['finetune_epochs'] = 1  # 设置微调epoch数量\n",
    "config['optimizer']['warmup_epochs'] = 1  # 设置预热epoch的数量\n",
    "config['optimizer']['initial_lr'] = 0.0001  # 设置初始化学习率\n",
    "\n",
    "config['summary'][\"valid_frequency\"] = 10  # 设置验证的频率\n",
    "config['summary'][\"summary_dir\"] = './summary'  # 设置模型checkpoint的存储路径\n",
    "logger = create_logger(path=os.path.join(config['summary'][\"summary_dir\"], \"results.log\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建数据集\n",
    "\n",
    "在[dataset](https://download.mindspore.cn/mindscience/mindearth/dataset/WeatherBench_1.4_69/)路径下,下载正则化参数、训练数据集验证数据集到 `./dataset`目录。\n",
    "\n",
    "修改[vit_kno.yaml](https://gitee.com/mindspore/mindscience/blob/r0.5/MindEarth/applications/medium-range/koopman_vit/vit_kno.yaml)配置文件中的`root_dir`参数,该参数设置了数据集的路径。\n",
    "\n",
    "`./dataset`中的目录结构如下所示:\n",
    "\n",
    "``` markdown\n",
    ".\n",
    "├── statistic\n",
    "│   ├── mean.npy\n",
    "│   ├── mean_s.npy\n",
    "│   ├── std.npy\n",
    "│   └── std_s.npy\n",
    "├── train\n",
    "│   └── 2015\n",
    "├── train_static\n",
    "│   └── 2015\n",
    "├── train_surface\n",
    "│   └── 2015\n",
    "├── train_surface_static\n",
    "│   └── 2015\n",
    "├── valid\n",
    "│   └── 2016\n",
    "├── valid_static\n",
    "│   └── 2016\n",
    "├── valid_surface\n",
    "│   └── 2016\n",
    "├── valid_surface_static\n",
    "│   └── 2016\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型构建\n",
    "\n",
    "加载相关的数据参数和模型参数,并完成ViT-KNO模型构建。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_params = config[\"data\"]\n",
    "model_params = config[\"model\"]\n",
    "compute_type = mstype.float32\n",
    "\n",
    "model = ViTKNO(image_size=(data_params[\"h_size\"], data_params[\"w_size\"]),\n",
    "               in_channels=data_params[\"feature_dims\"],\n",
    "               out_channels=data_params[\"feature_dims\"],\n",
    "               patch_size=data_params[\"patch_size\"],\n",
    "               encoder_depths=model_params[\"encoder_depth\"],\n",
    "               encoder_embed_dims=model_params[\"encoder_embed_dim\"],\n",
    "               mlp_ratio=model_params[\"mlp_ratio\"],\n",
    "               dropout_rate=model_params[\"dropout_rate\"],\n",
    "               num_blocks=model_params[\"num_blocks\"],\n",
    "               high_freq=True,\n",
    "               encoder_network=model_params[\"encoder_network\"],\n",
    "               compute_dtype=compute_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 损失函数\n",
    "\n",
    "ViT-KNO使用多loss的训练方法,包括Prediction loss,Reconstruction loss,两者均基于均方误差(Mean Squared Error)。"
   ]
  },
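  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, a two-term objective of this kind looks as follows. This is a sketch only: the weight $\\lambda$ and all symbols are introduced here for illustration, and the actual norm and weighting are defined by `Lploss` and `CustomWithLossCell` in `src`.\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = \\mathcal{L}_{\\mathrm{pred}} + \\lambda\\, \\mathcal{L}_{\\mathrm{recon}}, \\qquad \\mathcal{L}_{\\mathrm{pred}} = \\frac{1}{N} \\sum_{i=1}^{N} (\\hat{y}_i - y_i)^2, \\qquad \\mathcal{L}_{\\mathrm{recon}} = \\frac{1}{N} \\sum_{i=1}^{N} (\\tilde{x}_i - x_i)^2\n",
    "$$\n",
    "\n",
    "Here $y$ is the forecast target and $\\hat{y}$ the output of the prediction branch, $x$ is the input and $\\tilde{x}$ the output of the reconstruction branch, and $N$ is the number of elements in each field."
   ]
  },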
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = Lploss()\n",
    "loss_net = CustomWithLossCell(model, loss_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练\n",
    "\n",
    "模型训练阶段继承了Trainer类,同时重写了get_dataset,get_callback,get_solver三个成员函数,以便于能在训练阶段执行测试验证;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
        "2023-09-07 02:22:28,644 - pretrain.py[line:211] - INFO: steps_per_epoch: 404\n"
    ]
    }
   ],
   "source": [
    "class ViTKNOEra5Data(Era5Data):\n",
    "    def _patch(self, x, img_size, patch_size, output_dims):\n",
    "        \"\"\" Partition the data into patches. \"\"\"\n",
    "        if self.run_mode == 'valid' or self.run_mode == 'test':\n",
    "            x = x.transpose(1, 0, 2, 3)\n",
    "        return x\n",
    "\n",
    "class ViTKNOTrainer(Trainer):\n",
    "    def __init__(self, config, model, loss_fn, logger):\n",
    "        super(ViTKNOTrainer, self).__init__(config, model, loss_fn, logger)\n",
    "        self.pred_cb = self.get_callback()\n",
    "\n",
    "    def get_dataset(self):\n",
    "        \"\"\"\n",
    "        Get train and valid dataset.\n",
    "\n",
    "        Returns:\n",
    "            Dataset, train dataset.\n",
    "            Dataset, valid dataset.\n",
    "        \"\"\"\n",
    "        train_dataset_generator = ViTKNOEra5Data(data_params=self.data_params, run_mode='train')\n",
    "        valid_dataset_generator = ViTKNOEra5Data(data_params=self.data_params, run_mode='valid')\n",
    "\n",
    "        train_dataset = Dataset(train_dataset_generator, distribute=self.train_params['distribute'],\n",
    "                                num_workers=self.data_params['num_workers'])\n",
    "        valid_dataset = Dataset(valid_dataset_generator, distribute=False, num_workers=self.data_params['num_workers'],\n",
    "                                shuffle=False)\n",
    "        train_dataset = train_dataset.create_dataset(self.data_params['batch_size'])\n",
    "        valid_dataset = valid_dataset.create_dataset(self.data_params['batch_size'])\n",
    "        return train_dataset, valid_dataset\n",
    "\n",
    "    def get_callback(self):\n",
    "        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)\n",
    "        return pred_cb\n",
    "\n",
    "    def get_solver(self):\n",
    "        loss_scale = DynamicLossScaleManager()\n",
    "        solver = Model(self.loss_fn,\n",
    "                       optimizer=self.optimizer,\n",
    "                       loss_scale_manager=loss_scale,\n",
    "                       amp_level=self.train_params['amp_level']\n",
    "                       )\n",
    "        return solver\n",
    "\n",
    "\n",
    "trainer = ViTKNOTrainer(config, model, loss_net, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 404, loss is 0.3572\n",
      "Train epoch time: 113870.065 ms, per step time: 281.857 ms\n",
      "epoch: 2 step: 404, loss is 0.2883\n",
      "Train epoch time: 38169.970 ms, per step time: 94.480 ms\n",
      "epoch: 3 step: 404, loss is 0.2776\n",
      "Train epoch time: 38192.446 ms, per step time: 94.536 ms\n",
      "...\n",
      "epoch: 98 step: 404, loss is 0.1279\n",
      "Train epoch time: 38254.867 ms, per step time: 94.690 ms\n",
      "epoch: 99 step: 404, loss is 0.1306\n",
      "Train epoch time: 38264.715 ms, per step time: 94.715 ms\n",
      "epoch: 100 step: 404, loss is 0.1301\n",
      "Train epoch time: 41886.174 ms, per step time: 103.679 ms\n",
      "2023-09-07 03:38:51,759 - forecast.py[line:209] - INFO: ================================Start Evaluation================================\n",
      "2023-09-07 03:39:57,551 - forecast.py[line:227] - INFO: test dataset size: 9\n",
      "2023-09-07 03:39:57,555 - forecast.py[line:177] - INFO: t = 6 hour: \n",
      "2023-09-07 03:39:57,555 - forecast.py[line:188] - INFO:  RMSE of Z500: 199.04419938873764, T2m: 2.44011585143782, T850: 1.45654734158296, U10: 1.636622237572019\n",
      "2023-09-07 03:39:57,556 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9898813962936401, T2m: 0.9677559733390808, T850: 0.9703396558761597, U10: 0.9609741568565369\n",
      "2023-09-07 03:39:57,557 - forecast.py[line:177] - INFO: t = 72 hour: \n",
      "2023-09-07 03:39:57,557 - forecast.py[line:188] - INFO:  RMSE of Z500: 925.158453845783, T2m: 4.638264378699863, T850: 4.385266743972255, U10: 4.761954010777025\n",
      "2023-09-07 03:39:57,558 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.7650538682937622, T2m: 0.8762193918228149, T850: 0.7014696598052979, U10: 0.6434637904167175\n",
      "2023-09-07 03:39:57,559 - forecast.py[line:177] - INFO: t = 120 hour: \n",
      "2023-09-07 03:39:57,559 - forecast.py[line:188] - INFO:  RMSE of Z500: 1105.3634480837272, T2m: 5.488261092294651, T850: 5.120214326468169, U10: 5.424460568523809\n",
      "2023-09-07 03:39:57,560 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.6540136337280273, T2m: 0.8196010589599609, T850: 0.5682352781295776, U10: 0.5316879749298096\n",
      "2023-09-07 03:39:57,561 - forecast.py[line:237] - INFO: ================================End Evaluation================================\n"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
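  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation log reports RMSE and ACC for Z500, T2m, T850, and U10. For reference, the latitude-weighted definitions commonly used in WeatherBench-style evaluation are sketched below. This assumes the evaluation code in `src` follows the same convention, which the log alone does not confirm. With forecast $f$, ground truth $t$, climatology $c$, and grid indices $j$ (latitude) and $k$ (longitude):\n",
    "\n",
    "$$\n",
    "L(j) = \\frac{\\cos(\\mathrm{lat}_j)}{\\frac{1}{N_{\\mathrm{lat}}} \\sum_{j'=1}^{N_{\\mathrm{lat}}} \\cos(\\mathrm{lat}_{j'})}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathrm{RMSE} = \\sqrt{\\frac{1}{N_{\\mathrm{lat}} N_{\\mathrm{lon}}} \\sum_{j,k} L(j)\\, (f_{jk} - t_{jk})^2}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathrm{ACC} = \\frac{\\sum_{j,k} L(j)\\, f'_{jk}\\, t'_{jk}}{\\sqrt{\\sum_{j,k} L(j)\\, f'^{\\,2}_{jk} \\; \\sum_{j,k} L(j)\\, t'^{\\,2}_{jk}}}, \\qquad f' = f - c, \\quad t' = t - c\n",
    "$$\n",
    "\n",
    "Lower RMSE and higher ACC indicate a better forecast. In the log above, RMSE rises and ACC falls for every variable as the lead time grows from 6 to 120 hours."
   ]
  },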
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型推理及可视化\n",
    "\n",
    "完成训练后,我们使用第100个ckpt进行推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = load_checkpoint('./summary/ckpt/step_1/koopman_vit_1-100_404.ckpt')\n",
    "load_param_into_net(model, params)\n",
    "inference_module = InferenceModule(model, config, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plt_data(pred, label, root_dir, index=0):\n",
    "    \"\"\" Visualize the forecast results \"\"\"\n",
    "    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n",
    "    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n",
    "    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n",
    "    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n",
    "\n",
    "    plt.figure(num='e_imshow', figsize=(100, 50), dpi=50)\n",
    "\n",
    "    plt.subplot(4, 3, 1)\n",
    "    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500\n",
    "    plt.subplot(4, 3, 2)\n",
    "    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500\n",
    "    plt.subplot(4, 3, 3)\n",
    "    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error')  # Z500\n",
    "\n",
    "    plt.subplot(4, 3, 4)\n",
    "    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850\n",
    "    plt.subplot(4, 3, 5)\n",
    "    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850\n",
    "    plt.subplot(4, 3, 6)\n",
    "    plt_global_field_data(label - pred, 'T850', std, mean, 'Error')  # T850\n",
    "\n",
    "    plt.subplot(4, 3, 7)\n",
    "    plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 8)\n",
    "    plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True)  # U10\n",
    "    plt.subplot(4, 3, 9)\n",
    "    plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True)  # U10\n",
    "\n",
    "    plt.subplot(4, 3, 10)\n",
    "    plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 11)\n",
    "    plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True)  # T2M\n",
    "    plt.subplot(4, 3, 12)\n",
    "    plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True)  # T2M\n",
    "\n",
    "    plt.savefig(f'pred_result.png', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset_generator = Era5Data(data_params=config[\"data\"], run_mode='test')\n",
    "test_dataset = Dataset(test_dataset_generator, distribute=False,\n",
    "                       num_workers=config[\"data\"]['num_workers'], shuffle=False)\n",
    "test_dataset = test_dataset.create_dataset(config[\"data\"]['batch_size'])\n",
    "data = next(test_dataset.create_dict_iterator())\n",
    "inputs = data['inputs']\n",
    "labels = data['labels']\n",
    "pred_time_index = 0\n",
    "pred = inference_module.forecast(inputs)\n",
    "pred = pred[pred_time_index].asnumpy()\n",
    "ground_truth = labels[..., pred_time_index, :, :].asnumpy()\n",
    "plt_data(pred, ground_truth, config['data']['root_dir'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下述展示了第100个ckpt的真实值、预测值和他们之间的误差可视化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![plot result](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/docs/mindearth/docs/source_zh_cn/medium-range/images/vit_kno_result.png)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "16478c1492173c9a4f4847b8186328de7a4ca317afeafcd41bba7d71ba067560"
  },
  "kernelspec": {
   "display_name": "Python 3.9.16 64-bit ('lbk_ms10': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}