{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# FourCastNet: Medium-range Global Weather Forecasting Based on FNO\n", "\n", "[](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth_r0.2/en/medium-range/mindspore_FourCastNet.ipynb) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindearth_r0.2/en/medium-range/mindspore_FourCastNet.py) [](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindearth_r0.2/docs/source_en/medium-range/FourCastNet.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Overview\n", "\n", "FourCastNet (Fourier ForeCasting Neural Network) is a data-driven global weather forecast model developed by researchers from NVIDIA, Lawrence Berkeley National Laboratory, University of Michigan Ann Arbor, and Rice University. It provides medium-term forecasts of key global weather indicators with a resolution of 0.25°. Equivalent to a spatial resolution of approximately 30 km x 30 km near the equator and a global grid of 720 x 1440 pixels in size. Compared with the traditional NWP model, this model improves the prediction speed by 45000 times, generates a week's weather forecast within 2 seconds, and achieves the prediction accuracy comparable to that of the most advanced numerical weather forecast model, ECMWF Integrated Forecast System (IFS). This is the first AI weather forecast model that can be directly compared to the IFS system.\n", "\n", "This tutorial introduces the research background and technical path of FourCastNet, and shows how to train and fast infer the model through MindFlow. More information can be found in [paper](https://arxiv.org/abs/2202.11214)." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Technology Path\n", "\n", "MindEarth solves the problem as follows:\n", "\n", "1. Training Data Construction.\n", "2. Model Construction.\n", "3. Loss function.\n", "4. Model Training.\n", "5. Model Evaluation and Visualization.\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## FourCastNet\n", "\n", "In order to achieve high resolution prediction, FourCastNet uses AFNO model. The model network architecture is designed for high-resolution input, uses ViT as the backbone network, and incorporates Fourier Neural Operator (FNO) proposed by Zongyi Li et al. The model learns the mapping between function spaces so that series of nonlinear partial differential equations are solved.\n", "\n", "The Vision Transformer (ViT) architecture and its variants have become the most advanced technology in computer vision over the past few years, exhibiting outstanding performance on many tasks. This performance is mainly attributed to the multi-head self-attention mechanism in the network, which makes the global modeling between each layer of features in the network. However, computation complexity of a model during training and inference increases quadratic as a quantity of tokens (or patches) increases, and model computation complexity increases explosively as input resolution increases.\n", "\n", "The ingenuity of the AFNO model is that it converts the Spatial Mixing operation to the Fourier transform to mix the information of different tokens, transforms the features from the spatial domain to the frequency domain, and applies a globally learnable filter to the frequency domain features. The spatial mixing complexity is effectively reduced to O(NlogN), where N is the number of tokens.\n", "\n", "The following figure shows the FourCastNet network architecture.\n", "\n", "\n", "\n", "Model training consists of three steps:\n", "\n", "1. Pre-training: As shown in Figure (a) above, in the pre-training step, the AFNO model is trained in a supervised manner using the training dataset to learn the mapping from X(k) to X(k + 1).\n", "\n", "2. Fine tuning: As shown in Figure (b) above, the model first predicts X(k + 1) from X(k) and then uses X(k + 1) as input to predict X(k + 2). Then, the model is optimized using the sum of the two loss function values by calculating the loss function values from the predicted values of X(k + 1) and X(k + 2).\n", "\n", "3. Precipitation forecast: As shown in (c) above, the precipitation forecast is spliced by a separate model behind the backbone model. This method decouples the prediction task of precipitation from the basic meteorological factors. On the other hand, the trained precipitation model can also be used in combination with other prediction models (traditional NWP, etc.).\n", "\n", "This tutorial mainly implements the model pre-training part." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from mindspore import context\n", "from mindspore import load_checkpoint, load_param_into_net\n", "\n", "from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data\n", "from mindearth.module import Trainer\n", "from mindearth.data import Dataset, Era5Data\n", "from mindearth import RelativeRMSELoss\n", "from mindearth.cell import AFNONet" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following `src` can be downloaded in [FourCastNet/src](https://gitee.com/mindspore/mindscience/tree/r0.6/MindEarth/applications/medium-range/fourcastnet/src)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from src.callback import EvaluateCallBack, InferenceModule\n", "\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get parameters of model, data and optimizer from [FourCastNet.yaml](https://gitee.com/mindspore/mindscience/blob/r0.6/MindEarth/applications/medium-range/fourcastnet/FourCastNet.yaml)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "config = load_yaml_config('FourCastNet.yaml')\n", "config['model']['data_sink'] = True # set the data sink feature\n", "\n", "config['train']['distribute'] = False # set the distribute feature\n", "config['train']['amp_level'] = 'O2' # set the level for mixed precision training\n", "\n", "config['data']['num_workers'] = 1 # set the number of parallel workers\n", "config['data']['grid_resolution'] = 1.4 # set the resolution for dataset\n", "\n", "config['optimizer']['epochs'] = 100 # set the training epochs\n", "config['optimizer']['finetune_epochs'] = 1 # set the the finetune epochs\n", "config['optimizer']['warmup_epochs'] = 1 # set the warmup epochs\n", "config['optimizer']['initial_lr'] = 0.0005 # set the initial learning rate\n", "\n", "config['summary'][\"valid_frequency\"] = 10 # set the frequency of validation\n", "config['summary'][\"summary_dir\"] = './summary' # set the directory of model's checkpoint\n", "\n", "logger = create_logger(path=\"results.log\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Data Construction\n", "\n", "Download the statistic, training and validation dataset from [dataset](https://download.mindspore.cn/mindscience/mindearth/dataset/WeatherBench_1.4_69/) to `./dataset`.\n", "\n", "Modify the parameter of `root_dir` in the [FourCastNet.yaml](https://gitee.com/mindspore/mindscience/blob/r0.6/MindEarth/applications/medium-range/fourcastnet/FourCastNet.yaml), which sets the directory for dataset.\n", "\n", "The `./dataset` is hosted with the following directory structure:\n", "\n", "```markdown\n", ".\n", "├── statistic\n", "│ ├── mean.npy\n", "│ ├── mean_s.npy\n", "│ ├── std.npy\n", "│ └── std_s.npy\n", "├── train\n", "│ └── 2015\n", "├── train_static\n", "│ └── 2015\n", "├── train_surface\n", "│ └── 2015\n", "├── train_surface_static\n", "│ └── 2015\n", "├── valid\n", "│ └── 2016\n", "├── valid_static\n", "│ └── 2016\n", "├── valid_surface\n", "│ └── 2016\n", "├── valid_surface_static\n", "│ └── 2016\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Construction" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Load the data parameters and model parameters to the AFNONet model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "data_params = config['data']\n", "model_params = config['model']\n", "\n", "model = AFNONet(image_size=(data_params['h_size'], data_params['w_size']),\n", " in_channels=data_params[\"feature_dims\"],\n", " out_channels=data_params[\"feature_dims\"],\n", " patch_size=data_params[\"patch_size\"],\n", " encoder_depths=model_params[\"encoder_depths\"],\n", " encoder_embed_dim=model_params[\"encoder_embed_dim\"],\n", " mlp_ratio=model_params[\"mlp_ratio\"],\n", " dropout_rate=model_params[\"dropout_rate\"])" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Loss Function\n", "\n", "FourCastNet uses relative root mean squared error for model training." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "loss_fn = RelativeRMSELoss()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Training\n", "\n", "In this tutorial, we inherite the Trainer and override the get_callback member function so that we can perform inference on the test dataset during the training process.\n", "\n", "With MindSpore version >= 1.8.1, we can use the functional programming for training neural networks. MindSpore Earth provides a training interface for model training." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-09-07 02:26:03,143 - pretrain.py[line:211] - INFO: steps_per_epoch: 404\n" ] } ], "source": [ "class FCNTrainer(Trainer):\n", " def __init__(self, config, model, loss_fn, logger):\n", " super(FCNTrainer, self).__init__(config, model, loss_fn, logger)\n", " self.pred_cb = self.get_callback()\n", "\n", " def get_callback(self):\n", " pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)\n", " return pred_cb\n", "\n", "trainer = FCNTrainer(config, model, loss_fn, logger)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 404, loss is 0.5348429\n", "Train epoch time: 136480.515 ms, per step time: 337.823 ms\n", "epoch: 2 step: 404, loss is 0.35937342\n", "Train epoch time: 60902.627 ms, per step time: 150.749 ms\n", "epoch: 3 step: 404, loss is 0.33921248\n", "Train epoch time: 60737.844 ms, per step time: 150.341 ms\n", "...\n", "epoch: 98 step: 404, loss is 0.15447393\n", "Train epoch time: 61055.706 ms, per step time: 151.128 ms\n", "epoch: 99 step: 404, loss is 0.15696357\n", "Train epoch time: 60850.156 ms, per step time: 150.619 ms\n", "epoch: 100 step: 404, loss is 0.15654306\n", "Train epoch time: 60944.369 ms, per step time: 150.852 ms\n", "2023-09-07 04:27:02,837 - forecast.py[line:209] - INFO: ================================Start Evaluation================================\n", "2023-09-07 04:28:25,277 - forecast.py[line:177] - INFO: t = 6 hour: \n", "2023-09-07 04:28:25,277 - forecast.py[line:188] - INFO: RMSE of Z500: 154.07894852240838, T2m: 2.0995438696856965, T850: 1.3081689948838815, U10: 1.527248748050362\n", "2023-09-07 04:28:25,278 - forecast.py[line:189] - INFO: ACC of Z500: 0.9989880649296732, T2m: 0.9930711917863625, T850: 0.9954355203713009, U10: 0.9615764420500764\n", "2023-09-07 04:28:25,279 - forecast.py[line:177] - INFO: t = 72 hour: \n", "2023-09-07 04:28:25,279 - forecast.py[line:188] - INFO: RMSE of Z500: 885.3778200063341, T2m: 4.586325958437852, T850: 4.2593739999338736, U10: 4.75655467109408\n", "2023-09-07 04:28:25,280 - forecast.py[line:189] - INFO: ACC of Z500: 0.9598951919101183, T2m: 0.9658168304842388, T850: 0.9501612262744354, U10: 0.6175327930007481\n", "2023-09-07 04:28:25,281 - forecast.py[line:177] - INFO: t = 120 hour: \n", "2023-09-07 04:28:25,281 - forecast.py[line:188] - INFO: RMSE of Z500: 1291.3199606908572, T2m: 6.734047767054735, T850: 5.6420206614200294, U10: 5.637643311177468\n", "2023-09-07 04:28:25,282 - forecast.py[line:189] - INFO: ACC of Z500: 0.9150022892106006, T2m: 0.9294266102808937, T850: 0.9148957221265037, U10: 0.47971871343985495\n", "2023-09-07 04:28:25,283 - forecast.py[line:237] - INFO: ================================End Evaluation================================\n" ] } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Evaluation and Visualization\n", "\n", "After training, we use the 100th checkpoint for inference." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "pred_time_index = 0\n", "\n", "params = load_checkpoint('./summary/ckpt/step_1/FourCastNet_1-100_404.ckpt')\n", "load_param_into_net(model, params)\n", "inference_module = InferenceModule(model, config, logger)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def plt_data(pred, label, root_dir, index=0):\n", " \"\"\" Visualize the forecast results \"\"\"\n", " std = np.load(os.path.join(root_dir, 'statistic/std.npy'))\n", " mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))\n", " std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))\n", " mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))\n", " pred, label = pred[index].asnumpy(), label.asnumpy()[..., index, :, :]\n", " plt.figure(num='e_imshow', figsize=(100, 50), dpi=100)\n", "\n", " plt.subplot(4, 3, 1)\n", " plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth') # Z500\n", " plt.subplot(4, 3, 2)\n", " plt_global_field_data(pred, 'Z500', std, mean, 'Pred') # Z500\n", " plt.subplot(4, 3, 3)\n", " plt_global_field_data(label - pred, 'Z500', std, mean, 'Error') # Z500\n", "\n", " plt.subplot(4, 3, 4)\n", " plt_global_field_data(label, 'T850', std, mean, 'Ground Truth') # T850\n", " plt.subplot(4, 3, 5)\n", " plt_global_field_data(pred, 'T850', std, mean, 'Pred') # T850\n", " plt.subplot(4, 3, 6)\n", " plt_global_field_data(label - pred, 'T850', std, mean, 'Error') # T850\n", "\n", " plt.subplot(4, 3, 7)\n", " plt_global_field_data(label, 'U10', std_s, mean_s,\n", " 'Ground Truth', is_surface=True) # U10\n", " plt.subplot(4, 3, 8)\n", " plt_global_field_data(pred, 'U10', std_s, mean_s,\n", " 'Pred', is_surface=True) # U10\n", " plt.subplot(4, 3, 9)\n", " plt_global_field_data(label - pred, 'U10', std_s,\n", " mean_s, 'Error', is_surface=True) # U10\n", "\n", " plt.subplot(4, 3, 10)\n", " plt_global_field_data(label, 'T2M', std_s, mean_s,\n", " 'Ground Truth', is_surface=True) # T2M\n", " plt.subplot(4, 3, 11)\n", " plt_global_field_data(pred, 'T2M', std_s, mean_s,\n", " 'Pred', is_surface=True) # T2M\n", " plt.subplot(4, 3, 12)\n", " plt_global_field_data(label - pred, 'T2M', std_s,\n", " mean_s, 'Error', is_surface=True) # T2M\n", "\n", " plt.savefig(f'pred_result.png', bbox_inches='tight')\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "\n", "test_dataset_generator = Era5Data(data_params=config[\"data\"], run_mode='test')\n", "test_dataset = Dataset(test_dataset_generator, distribute=False,\n", " num_workers=config[\"data\"]['num_workers'], shuffle=False)\n", "test_dataset = test_dataset.create_dataset(config[\"data\"]['batch_size'])\n", "data = next(test_dataset.create_dict_iterator())\n", "inputs = data['inputs']\n", "labels = data['labels']\n", "pred = inference_module.forecast(inputs)\n", "plt_data(pred, labels, config['data']['root_dir'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualization of predictions by the 100th checkpoint, ground truth and their error is shown below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] } ], "metadata": { "interpreter": { "hash": "16478c1492173c9a4f4847b8186328de7a4ca317afeafcd41bba7d71ba067560" }, "kernelspec": { "display_name": "Python 3.9.16 64-bit ('lbk_ms10': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }