{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ViT-KNO: Medium-range Global Weather Forecasting Based on Koopman\n", "\n", "[Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_vit_kno.ipynb) [Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindearth/en/medium-range/mindspore_vit_kno.py) [View Source on Gitee](https://gitee.com/mindspore/docs/blob/master/docs/mindearth/docs/source_en/medium-range/vit_kno.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "Modern numerical weather prediction (NWP) can be traced back to the 1920s. Based on physical principles and integrating the achievements and experience of several generations of meteorologists, NWP is the mainstream weather forecasting method adopted by meteorological departments in various countries. Among NWP models, the high-resolution Integrated Forecasting System (IFS) from the European Centre for Medium-Range Weather Forecasts (ECMWF) is widely regarded as the best performing.\n", "\n", "In 2022, NVIDIA developed FourCastNet, a prediction model based on Fourier neural operators, which can generate predictions of key global weather indicators at a resolution of 0.25°. This corresponds to a spatial resolution of about 30×30 km near the equator and a global grid of 720×1440 pixels, consistent with IFS. This result allowed the first direct comparison of AI weather models with the traditional physical model IFS. For more information, please refer to: [\"FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators\"](https://arxiv.org/pdf/2202.11214.pdf).\n", "\n", "However, FourCastNet, a prediction model based on the Fourier Neural Operator (FNO), falls short in both accuracy and interpretability when predicting medium-term and long-term weather. 
ViT-KNO combines the Vision Transformer structure with Koopman theory, learning a Koopman operator to predict nonlinear dynamical systems. By embedding complex dynamics into a linear structure that constrains the reconstruction process, ViT-KNO captures complex nonlinear behavior while remaining lightweight and computationally efficient. ViT-KNO is backed by clear mathematical theory, which addresses the limited mathematical and physical interpretability and the weak theoretical basis of similar methods. For more information, refer to: [\"KoopmanLab: machine learning for solving complex physics equations\"](https://arxiv.org/pdf/2301.01104.pdf).\n", "\n", "## Technology Path\n", "\n", "MindSpore solves the problem as follows:\n", "\n", "1. Training Data Construction.\n", "2. Model Construction.\n", "3. Loss Function.\n", "4. Model Training.\n", "5. Model Evaluation and Visualization.\n", "\n", "## ViT-KNO\n", "\n", "The following figure shows the ViT-KNO model architecture, which consists of two branches. The upstream branch is responsible for result prediction and consists of the encoder module, the Koopman Layer module, and the decoder module. The Koopman Layer module is shown in the dotted box and can be stacked repeatedly. 
The downstream branch consists of the encoder and decoder modules, which reconstruct the input information.\n", "\n", "\n", "\n", "The model training process is as follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from mindspore import context, Model, load_checkpoint, load_param_into_net\n", "from mindspore import dtype as mstype\n", "from mindspore.amp import DynamicLossScaleManager\n", "\n", "from mindearth.cell import ViTKNO\n", "from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data\n", "from mindearth.data import Dataset, Era5Data\n", "from mindearth.module import Trainer" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The `src` package imported below can be downloaded from [ViT-KNO/src](https://gitee.com/mindspore/mindscience/tree/master/MindEarth/applications/medium-range/koopman_vit/src)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from src.callback import EvaluateCallBack, InferenceModule, Lploss, CustomWithLossCell\n", "\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model, data, and optimizer parameters are read from [vit_kno.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindEarth/applications/medium-range/koopman_vit/configs/vit_kno_1.4.yaml)." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "config = load_yaml_config('./vit_kno.yaml')\n", "config['model']['data_sink'] = True # enable the data sink feature\n", "\n", "config['train']['distribute'] = False # disable distributed training\n", "config['train']['amp_level'] = 'O2' # set the level for mixed precision training\n", "\n", "config['data']['num_workers'] = 1 # set the number of parallel workers\n", "config['data']['grid_resolution'] = 1.4 # set the resolution of the dataset\n", "\n", "config['optimizer']['epochs'] = 100 # set the number of training epochs\n", "config['optimizer']['finetune_epochs'] = 1 # set the number of finetune epochs\n", "config['optimizer']['warmup_epochs'] = 1 # set the number of warmup epochs\n", "config['optimizer']['initial_lr'] = 0.0001 # set the initial learning rate\n", "\n", "config['summary'][\"valid_frequency\"] = 10 # set the frequency of validation\n", "config['summary'][\"summary_dir\"] = './summary' # set the directory for the model's checkpoints\n", "\n", "logger = create_logger(path=os.path.join(config['summary'][\"summary_dir\"], \"results.log\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Data Construction\n", "\n", "Download the statistics, training, and validation datasets from [dataset](https://download.mindspore.cn/mindscience/mindearth/dataset/WeatherBench_1.4_69/) to `./dataset`.\n", "\n", "Modify the `root_dir` parameter in [vit_kno.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindEarth/applications/medium-range/koopman_vit/configs/vit_kno_1.4.yaml), which sets the dataset directory.\n", "\n", "The `./dataset` directory has the following structure:\n", "\n", "```markdown\n", ".\n", "├── statistic\n", "│ ├── mean.npy\n", "│ ├── mean_s.npy\n", "│ ├── std.npy\n", "│ └── std_s.npy\n", "├── train\n", "│ └── 2015\n", "├── train_static\n", "│ └── 2015\n", "├── train_surface\n", "│ └── 2015\n", "├── train_surface_static\n", "│ └── 2015\n", "├── 
valid\n", "│ └── 2016\n", "├── valid_static\n", "│ └── 2016\n", "├── valid_surface\n", "│ └── 2016\n", "├── valid_surface_static\n", "│ └── 2016\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Construction\n", "\n", "Build the ViTKNO model from the data parameters and model parameters." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "data_params = config[\"data\"]\n", "model_params = config[\"model\"]\n", "compute_type = mstype.float32\n", "\n", "model = ViTKNO(image_size=(data_params[\"h_size\"], data_params[\"w_size\"]),\n", "               in_channels=data_params[\"feature_dims\"],\n", "               out_channels=data_params[\"feature_dims\"],\n", "               patch_size=data_params[\"patch_size\"],\n", "               encoder_depths=model_params[\"encoder_depth\"],\n", "               encoder_embed_dims=model_params[\"encoder_embed_dim\"],\n", "               mlp_ratio=model_params[\"mlp_ratio\"],\n", "               dropout_rate=model_params[\"dropout_rate\"],\n", "               num_blocks=model_params[\"num_blocks\"],\n", "               high_freq=True,\n", "               encoder_network=model_params[\"encoder_network\"],\n", "               compute_dtype=compute_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loss Function\n", "\n", "ViT-KNO is trained with two losses, a prediction loss and a reconstruction loss, both based on mean squared error." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "loss_fn = Lploss()\n", "loss_net = CustomWithLossCell(model, loss_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Training\n", "\n", "In this tutorial, we inherit from `Trainer` and override the `get_dataset`, `get_callback` and `get_solver` member functions so that we can perform inference on the test dataset during the training process." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-09-07 02:22:28,644 - pretrain.py[line:211] - INFO: steps_per_epoch: 404\n" ] } ], "source": [ "class ViTKNOEra5Data(Era5Data):\n", "    def _patch(self, x, img_size, patch_size, output_dims):\n", "        \"\"\" Partition the data into patches. \"\"\"\n", "        if self.run_mode == 'valid' or self.run_mode == 'test':\n", "            x = x.transpose(1, 0, 2, 3)\n", "        return x\n", "\n", "\n", "class ViTKNOTrainer(Trainer):\n", "    def __init__(self, config, model, loss_fn, logger):\n", "        super(ViTKNOTrainer, self).__init__(config, model, loss_fn, logger)\n", "        self.pred_cb = self.get_callback()\n", "\n", "    def get_dataset(self):\n", "        \"\"\"\n", "        Get train and valid dataset.\n", "\n", "        Returns:\n", "            Dataset, train dataset.\n", "            Dataset, valid dataset.\n", "        \"\"\"\n", "        train_dataset_generator = ViTKNOEra5Data(data_params=self.data_params, run_mode='train')\n", "        valid_dataset_generator = ViTKNOEra5Data(data_params=self.data_params, run_mode='valid')\n", "\n", "        train_dataset = Dataset(train_dataset_generator, distribute=self.train_params['distribute'],\n", "                                num_workers=self.data_params['num_workers'])\n", "        valid_dataset = Dataset(valid_dataset_generator, distribute=False, num_workers=self.data_params['num_workers'],\n", "                                shuffle=False)\n", "        train_dataset = train_dataset.create_dataset(self.data_params['batch_size'])\n", "        valid_dataset = valid_dataset.create_dataset(self.data_params['batch_size'])\n", "        return train_dataset, valid_dataset\n", "\n", "    def get_callback(self):\n", "        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)\n", "        return pred_cb\n", "\n", "    def get_solver(self):\n", "        loss_scale = DynamicLossScaleManager()\n", "        solver = Model(self.loss_fn,\n", "                       optimizer=self.optimizer,\n", "                       loss_scale_manager=loss_scale,\n", "                       amp_level=self.train_params['amp_level']\n", "                       )\n", "        return solver\n", "\n", 
"\n", "trainer = ViTKNOTrainer(config, model, loss_net, logger)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 404, loss is 0.3572\n", "Train epoch time: 113870.065 ms, per step time: 281.857 ms\n", "epoch: 2 step: 404, loss is 0.2883\n", "Train epoch time: 38169.970 ms, per step time: 94.480 ms\n", "epoch: 3 step: 404, loss is 0.2776\n", "Train epoch time: 38192.446 ms, per step time: 94.536 ms\n", "...\n", "epoch: 98 step: 404, loss is 0.1279\n", "Train epoch time: 38254.867 ms, per step time: 94.690 ms\n", "epoch: 99 step: 404, loss is 0.1306\n", "Train epoch time: 38264.715 ms, per step time: 94.715 ms\n", "epoch: 100 step: 404, loss is 0.1301\n", "Train epoch time: 41886.174 ms, per step time: 103.679 ms\n", "2023-09-07 03:38:51,759 - forecast.py[line:209] - INFO: ================================Start Evaluation================================\n", "2023-09-07 03:39:57,551 - forecast.py[line:227] - INFO: test dataset size: 9\n", "2023-09-07 03:39:57,555 - forecast.py[line:177] - INFO: t = 6 hour: \n", "2023-09-07 03:39:57,555 - forecast.py[line:188] - INFO: RMSE of Z500: 199.04419938873764, T2m: 2.44011585143782, T850: 1.45654734158296, U10: 1.636622237572019\n", "2023-09-07 03:39:57,556 - forecast.py[line:189] - INFO: ACC of Z500: 0.9898813962936401, T2m: 0.9677559733390808, T850: 0.9703396558761597, U10: 0.9609741568565369\n", "2023-09-07 03:39:57,557 - forecast.py[line:177] - INFO: t = 72 hour: \n", "2023-09-07 03:39:57,557 - forecast.py[line:188] - INFO: RMSE of Z500: 925.158453845783, T2m: 4.638264378699863, T850: 4.385266743972255, U10: 4.761954010777025\n", "2023-09-07 03:39:57,558 - forecast.py[line:189] - INFO: ACC of Z500: 0.7650538682937622, T2m: 0.8762193918228149, T850: 0.7014696598052979, U10: 0.6434637904167175\n", "2023-09-07 03:39:57,559 - forecast.py[line:177] - INFO: t = 120 hour: \n", "2023-09-07 03:39:57,559 - forecast.py[line:188] - 
INFO: RMSE of Z500: 1105.3634480837272, T2m: 5.488261092294651, T850: 5.120214326468169, U10: 5.424460568523809\n", "2023-09-07 03:39:57,560 - forecast.py[line:189] - INFO: ACC of Z500: 0.6540136337280273, T2m: 0.8196010589599609, T850: 0.5682352781295776, U10: 0.5316879749298096\n", "2023-09-07 03:39:57,561 - forecast.py[line:237] - INFO: ================================End Evaluation================================\n" ] } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Evaluation and Visualization\n", "\n", "After training, we use the 100th checkpoint for inference." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "params = load_checkpoint('./summary/ckpt/step_1/koopman_vit_2-100_404.ckpt')\n", "load_param_into_net(model, params)\n", "inference_module = InferenceModule(model, config, logger)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def plt_data(pred, label, root_dir, index=0):\n", " \"\"\" Visualize the forecast results \"\"\"\n", " std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n", " mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n", " std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n", " mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n", "\n", " plt.figure(num='e_imshow', figsize=(100, 50), dpi=50)\n", "\n", " plt.subplot(4, 3, 1)\n", " plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth') # Z500\n", " plt.subplot(4, 3, 2)\n", " plt_global_field_data(pred, 'Z500', std, mean, 'Pred') # Z500\n", " plt.subplot(4, 3, 3)\n", " plt_global_field_data(label - pred, 'Z500', std, mean, 'Error') # Z500\n", "\n", " plt.subplot(4, 3, 4)\n", " plt_global_field_data(label, 'T850', std, mean, 'Ground Truth') # T850\n", " plt.subplot(4, 3, 5)\n", " plt_global_field_data(pred, 'T850', std, mean, 'Pred') # T850\n", " plt.subplot(4, 3, 6)\n", " 
plt_global_field_data(label - pred, 'T850', std, mean, 'Error') # T850\n", "\n", " plt.subplot(4, 3, 7)\n", " plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True) # U10\n", " plt.subplot(4, 3, 8)\n", " plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True) # U10\n", " plt.subplot(4, 3, 9)\n", " plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True) # U10\n", "\n", " plt.subplot(4, 3, 10)\n", " plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True) # T2M\n", " plt.subplot(4, 3, 11)\n", " plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True) # T2M\n", " plt.subplot(4, 3, 12)\n", " plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True) # T2M\n", "\n", " plt.savefig(f'pred_result.png', bbox_inches='tight')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "test_dataset_generator = Era5Data(data_params=config[\"data\"], run_mode='test')\n", "test_dataset = Dataset(test_dataset_generator, distribute=False,\n", " num_workers=config[\"data\"]['num_workers'], shuffle=False)\n", "test_dataset = test_dataset.create_dataset(config[\"data\"]['batch_size'])\n", "data = next(test_dataset.create_dict_iterator())\n", "inputs = data['inputs']\n", "labels = data['labels']\n", "pred_time_index = 0\n", "pred = inference_module.forecast(inputs)\n", "pred = pred[pred_time_index].asnumpy()\n", "ground_truth = labels[..., pred_time_index, :, :].asnumpy()\n", "plt_data(pred, ground_truth, config['data']['root_dir'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualization of predictions by the 100th checkpoint, ground truth and their error is shown below.\n", "\n", "" ] } ], "metadata": { "interpreter": { "hash": "16478c1492173c9a4f4847b8186328de7a4ca317afeafcd41bba7d71ba067560" }, "kernelspec": { "display_name": "Python 3.9.16 64-bit ('lbk_ms10': 
conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }