FourCastNet: Medium-range Global Weather Forecasting Based on FNO


Overview

FourCastNet (Fourier ForeCasting Neural Network) is a data-driven global weather forecast model developed by researchers from NVIDIA, Lawrence Berkeley National Laboratory, University of Michigan Ann Arbor, and Rice University. It provides medium-range forecasts of key global weather indicators at a resolution of 0.25°, which corresponds to a spatial resolution of approximately 30 km x 30 km near the equator and a global grid of 720 x 1440 pixels. Compared with traditional NWP models, it speeds up prediction by a factor of 45,000, generates a week's weather forecast within 2 seconds, and achieves accuracy comparable to that of the most advanced numerical weather prediction model, the ECMWF Integrated Forecasting System (IFS). It is the first AI weather forecast model that can be directly compared with the IFS.
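The grid size and equatorial spacing quoted above follow directly from the angular resolution. The short sketch below reproduces that arithmetic. It assumes a mean Earth radius of 6371 km and a regular latitude-longitude grid; the helper names `grid_shape` and `equatorial_spacing_km` are illustrative and not part of MindEarth.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumption: mean Earth radius, not stated in the tutorial


def grid_shape(resolution_deg):
    """Number of (latitude, longitude) points on a regular global grid."""
    return round(180 / resolution_deg), round(360 / resolution_deg)


def equatorial_spacing_km(resolution_deg):
    """Arc length covered by one grid cell along the equator."""
    return 2 * math.pi * EARTH_RADIUS_KM * resolution_deg / 360


print(grid_shape(0.25))
print(round(equatorial_spacing_km(0.25), 1))
```

Under these assumptions a 0.25° grid has 720 x 1440 points and a spacing of a little under 28 km at the equator, consistent with the rounded figure of about 30 km.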

This tutorial introduces the research background and technical path of FourCastNet, and shows how to train the model and run fast inference with MindEarth. More information can be found in the paper.

Technology Path

MindEarth solves the problem as follows:

  1. Training Data Construction.

  2. Model Construction.

  3. Loss Function.

  4. Model Training.

  5. Model Evaluation and Visualization.

FourCastNet

To achieve high-resolution prediction, FourCastNet uses the AFNO (Adaptive Fourier Neural Operator) model. Its network architecture is designed for high-resolution input: it uses ViT as the backbone network and incorporates the Fourier Neural Operator (FNO) proposed by Zongyi Li et al. The model learns mappings between function spaces, which allows it to solve families of nonlinear partial differential equations.

The Vision Transformer (ViT) architecture and its variants have become the state of the art in computer vision over the past few years, performing strongly on many tasks. This performance is mainly attributed to the multi-head self-attention mechanism, which models global dependencies among the features at every layer of the network. However, the computational complexity of training and inference grows quadratically with the number of tokens (or patches), so the cost rises steeply as the input resolution increases.
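To make the quadratic growth concrete, the sketch below counts tokens and token-to-token interactions for two grid sizes. The patch size of 8 and the coarse 128 x 256 grid are assumptions chosen for illustration only; they are not read from FourCastNet.yaml.

```python
def num_tokens(height, width, patch_size):
    """Number of non-overlapping patches (tokens) cut from an image."""
    return (height // patch_size) * (width // patch_size)


def attention_pairs(n_tokens):
    """Token-to-token interactions evaluated by full self-attention."""
    return n_tokens * n_tokens


PATCH = 8  # assumption: illustrative patch size

coarse = num_tokens(128, 256, PATCH)   # hypothetical coarse grid
fine = num_tokens(720, 1440, PATCH)    # 0.25 degree grid from the overview

print(coarse, attention_pairs(coarse))
print(fine, attention_pairs(fine))
```

Moving from the coarse grid to the 0.25° grid multiplies the token count by roughly 30, but the number of attention interactions by roughly 1000.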

The ingenuity of the AFNO model is that it replaces the spatial mixing operation with a Fourier transform: features are transformed from the spatial domain to the frequency domain, and a globally learnable filter is applied to the frequency-domain features to mix information across tokens. This reduces the spatial mixing complexity to O(N log N), where N is the number of tokens.
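The idea of mixing tokens in the frequency domain can be sketched with NumPy as follows. This is a simplified illustration, not the MindEarth AFNONet implementation: it applies a plain elementwise complex filter, whereas the AFNO layer described in the paper uses a block-diagonal MLP over channels with soft-thresholding. The function name `fourier_token_mixing` and the array shapes are assumptions.

```python
import numpy as np


def fourier_token_mixing(tokens, filt):
    """Mix tokens globally by filtering in the frequency domain.

    tokens: real array of shape (h, w, c), one c-dimensional feature per patch.
    filt:   complex array of shape (h, w // 2 + 1, c), standing in for the
            learnable filter (simplified to an elementwise multiplication).
    """
    h, w, _ = tokens.shape
    freq = np.fft.rfft2(tokens, axes=(0, 1))            # spatial -> frequency
    freq = freq * filt                                   # filter every frequency
    return np.fft.irfft2(freq, s=(h, w), axes=(0, 1))   # frequency -> spatial


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16, 4))

# An all-ones filter leaves the tokens unchanged (up to floating-point error).
identity = np.ones((8, 16 // 2 + 1, 4), dtype=complex)
y = fourier_token_mixing(x, identity)
print(np.allclose(x, y))
```

Because each frequency coefficient depends on every token, a single multiplication in the frequency domain mixes information across the whole grid, and the FFT keeps the cost at O(N log N).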

The following figure shows the FourCastNet network architecture.

AFNO model

Model training consists of three steps:

  1. Pre-training: As shown in Figure (a) above, the AFNO model is trained in a supervised manner on the training dataset to learn the mapping from X(k) to X(k + 1).

  2. Fine-tuning: As shown in Figure (b) above, the model first predicts X(k + 1) from X(k) and then uses the predicted X(k + 1) as input to predict X(k + 2). The loss is computed for both predictions against their targets, and the model is optimized using the sum of the two loss values.

  3. Precipitation forecast: As shown in Figure (c) above, precipitation is forecast by a separate model attached after the backbone model. This decouples the precipitation prediction task from the basic meteorological variables. In addition, the trained precipitation model can be combined with other prediction models (traditional NWP, etc.).
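The difference between the pre-training and fine-tuning objectives can be sketched as below. This is a minimal illustration under stated assumptions: `toy_model` is a hypothetical stand-in for the network, and mean squared error is used as a placeholder loss rather than the loss used later in this tutorial.

```python
import numpy as np


def mse(pred, target):
    """Placeholder loss for illustration only."""
    return float(np.mean((pred - target) ** 2))


def pretrain_loss(model, x_k, x_k1, loss_fn=mse):
    """One-step objective: predict X(k + 1) from X(k)."""
    return loss_fn(model(x_k), x_k1)


def finetune_loss(model, x_k, x_k1, x_k2, loss_fn=mse):
    """Two-step objective: feed the first prediction back in, sum both losses."""
    pred_1 = model(x_k)       # prediction of X(k + 1)
    pred_2 = model(pred_1)    # prediction of X(k + 2) from the predicted X(k + 1)
    return loss_fn(pred_1, x_k1) + loss_fn(pred_2, x_k2)


def toy_model(x):
    """Hypothetical model: halves its input."""
    return 0.5 * x


x0 = np.ones((2, 3))
print(pretrain_loss(toy_model, x0, 0.5 * x0))
print(finetune_loss(toy_model, x0, 0.5 * x0, 0.25 * x0))
```

Training on the summed two-step loss penalizes errors that accumulate when the model consumes its own output, which is the situation it faces during multi-step forecasting.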

This tutorial mainly implements the model pre-training part.

[1]:
import os
import numpy as np
import matplotlib.pyplot as plt

from mindspore import context
from mindspore import load_checkpoint, load_param_into_net

from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data
from mindearth.module import Trainer
from mindearth.data import Dataset, Era5Data
from mindearth import RelativeRMSELoss
from mindearth.cell import AFNONet

The src package imported below can be downloaded from FourCastNet/src.

[2]:
from src.callback import EvaluateCallBack, InferenceModule

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)

The model, data, and optimizer parameters are read from FourCastNet.yaml.

[4]:
config = load_yaml_config('FourCastNet.yaml')
config['model']['data_sink'] = True  # set the data sink feature

config['train']['distribute'] = False  # set the distribute feature
config['train']['amp_level'] = 'O2'  # set the level for mixed precision training

config['data']['num_workers'] = 1  # set the number of parallel workers
config['data']['grid_resolution'] = 1.4  # set the resolution for dataset

config['optimizer']['epochs'] = 100  # set the training epochs
config['optimizer']['finetune_epochs'] = 1  # set the finetune epochs
config['optimizer']['warmup_epochs'] = 1  # set the warmup epochs
config['optimizer']['initial_lr'] = 0.0005  # set the initial learning rate

config['summary']["valid_frequency"] = 10  # set the frequency of validation
config['summary']["summary_dir"] = './summary'  # set the directory of model's checkpoint

logger = create_logger(path="results.log")

Training Data Construction

Download the statistics, training, and validation data from dataset to ./dataset.

Set the root_dir parameter in FourCastNet.yaml to the directory that holds the dataset.

The ./dataset directory is organized as follows:

.
├── statistic
│   ├── mean.npy
│   ├── mean_s.npy
│   ├── std.npy
│   └── std_s.npy
├── train
│   └── 2015
├── train_static
│   └── 2015
├── train_surface
│   └── 2015
├── train_surface_static
│   └── 2015
├── valid
│   └── 2016
├── valid_static
│   └── 2016
├── valid_surface
│   └── 2016
├── valid_surface_static
│   └── 2016
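Before training, it can be useful to confirm that the download matches this layout. The helper below is a small sketch for that check; `missing_dataset_paths` is a hypothetical name, not a MindEarth function, and the path list is copied from the directory tree above.

```python
import os

# Relative paths taken from the directory tree above.
EXPECTED_PATHS = [
    "statistic/mean.npy", "statistic/mean_s.npy",
    "statistic/std.npy", "statistic/std_s.npy",
    "train/2015", "train_static/2015",
    "train_surface/2015", "train_surface_static/2015",
    "valid/2016", "valid_static/2016",
    "valid_surface/2016", "valid_surface_static/2016",
]


def missing_dataset_paths(root_dir):
    """Return the expected entries that do not exist under root_dir."""
    return [p for p in EXPECTED_PATHS
            if not os.path.exists(os.path.join(root_dir, p))]


missing = missing_dataset_paths("./dataset")
if missing:
    print("Missing dataset entries:", missing)
```

An empty result means every file and folder listed above is present under the given root directory.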

Model Construction

Build the AFNONet model from the data parameters and model parameters.

[5]:
data_params = config['data']
model_params = config['model']

model = AFNONet(image_size=(data_params['h_size'], data_params['w_size']),
                in_channels=data_params["feature_dims"],
                out_channels=data_params["feature_dims"],
                patch_size=data_params["patch_size"],
                encoder_depths=model_params["encoder_depths"],
                encoder_embed_dim=model_params["encoder_embed_dim"],
                mlp_ratio=model_params["mlp_ratio"],
                dropout_rate=model_params["dropout_rate"])

Loss Function

FourCastNet uses relative root mean squared error for model training.
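As a rough illustration of what a relative root mean squared error measures, the NumPy sketch below divides the L2 norm of the error by the L2 norm of the label for each sample and averages over the batch. This definition is an assumption made for illustration; the exact reduction used by RelativeRMSELoss in MindEarth may differ, and `relative_rmse` and `eps` are hypothetical names.

```python
import numpy as np


def relative_rmse(pred, label, eps=1e-12):
    """Per-sample ||pred - label|| / ||label||, averaged over the batch."""
    batch = pred.shape[0]
    diff = (pred - label).reshape(batch, -1)
    ref = label.reshape(batch, -1)
    diff_norm = np.sqrt(np.sum(diff ** 2, axis=1))
    ref_norm = np.sqrt(np.sum(ref ** 2, axis=1))
    return float(np.mean(diff_norm / (ref_norm + eps)))  # eps guards against division by zero


label = np.ones((2, 4, 4))
print(relative_rmse(label, label))
print(relative_rmse(1.1 * label, label))
```

Normalizing by the magnitude of the label makes the loss scale-free, so variables with very different physical ranges contribute comparably.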

[6]:
loss_fn = RelativeRMSELoss()

Model Training

In this tutorial, we inherit from Trainer and override the get_callback member function so that inference can be run on the validation dataset during training.

With MindSpore version >= 1.8.1, neural networks can be trained in a functional programming style. MindEarth provides a training interface for this purpose.

[7]:
class FCNTrainer(Trainer):
    def __init__(self, config, model, loss_fn, logger):
        super(FCNTrainer, self).__init__(config, model, loss_fn, logger)
        self.pred_cb = self.get_callback()

    def get_callback(self):
        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)
        return pred_cb

trainer = FCNTrainer(config, model, loss_fn, logger)
2023-08-21 07:34:55,267 - pretrain.py[line:211] - INFO: steps_per_epoch: 404
[8]:
trainer.train()
epoch: 1 step: 404, loss is 0.5348429
Train epoch time: 136480.515 ms, per step time: 337.823 ms
epoch: 2 step: 404, loss is 0.35937342
Train epoch time: 60902.627 ms, per step time: 150.749 ms
epoch: 3 step: 404, loss is 0.33921248
Train epoch time: 60737.844 ms, per step time: 150.341 ms
...
epoch: 98 step: 404, loss is 0.15447393
Train epoch time: 61055.706 ms, per step time: 151.128 ms
epoch: 99 step: 404, loss is 0.15696357
Train epoch time: 60850.156 ms, per step time: 150.619 ms
epoch: 100 step: 404, loss is 0.15654306
Train epoch time: 60944.369 ms, per step time: 150.852 ms
2023-09-07 04:27:02,837 - forecast.py[line:209] - INFO: ================================Start Evaluation================================
2023-09-07 04:28:25,277 - forecast.py[line:177] - INFO: t = 6 hour:
2023-09-07 04:28:25,277 - forecast.py[line:188] - INFO:  RMSE of Z500: 154.07894852240838, T2m: 2.0995438696856965, T850: 1.3081689948838815, U10: 1.527248748050362
2023-09-07 04:28:25,278 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9989880649296732, T2m: 0.9930711917863625, T850: 0.9954355203713009, U10: 0.9615764420500764
2023-09-07 04:28:25,279 - forecast.py[line:177] - INFO: t = 72 hour:
2023-09-07 04:28:25,279 - forecast.py[line:188] - INFO:  RMSE of Z500: 885.3778200063341, T2m: 4.586325958437852, T850: 4.2593739999338736, U10: 4.75655467109408
2023-09-07 04:28:25,280 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9598951919101183, T2m: 0.9658168304842388, T850: 0.9501612262744354, U10: 0.6175327930007481
2023-09-07 04:28:25,281 - forecast.py[line:177] - INFO: t = 120 hour:
2023-09-07 04:28:25,281 - forecast.py[line:188] - INFO:  RMSE of Z500: 1291.3199606908572, T2m: 6.734047767054735, T850: 5.6420206614200294, U10: 5.637643311177468
2023-09-07 04:28:25,282 - forecast.py[line:189] - INFO:  ACC  of Z500: 0.9150022892106006, T2m: 0.9294266102808937, T850: 0.9148957221265037, U10: 0.47971871343985495
2023-09-07 04:28:25,283 - forecast.py[line:237] - INFO: ================================End Evaluation================================

Model Evaluation and Visualization

After training, we use the 100th checkpoint for inference.

[9]:
pred_time_index = 0

params = load_checkpoint('./summary/ckpt/step_1/FourCastNet_1-100_404.ckpt')
load_param_into_net(model, params)
inference_module = InferenceModule(model, config, logger)
[10]:
def plt_data(pred, label, root_dir, index=0):
    """ Visualize the forecast results """
    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))
    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))
    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))
    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))
    pred, label = pred[index].asnumpy(), label.asnumpy()[..., index, :, :]
    plt.figure(num='e_imshow', figsize=(100, 50), dpi=100)

    plt.subplot(4, 3, 1)
    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500
    plt.subplot(4, 3, 2)
    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500
    plt.subplot(4, 3, 3)
    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error')  # Z500

    plt.subplot(4, 3, 4)
    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850
    plt.subplot(4, 3, 5)
    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850
    plt.subplot(4, 3, 6)
    plt_global_field_data(label - pred, 'T850', std, mean, 'Error')  # T850

    plt.subplot(4, 3, 7)
    plt_global_field_data(label, 'U10', std_s, mean_s,
                          'Ground Truth', is_surface=True)  # U10
    plt.subplot(4, 3, 8)
    plt_global_field_data(pred, 'U10', std_s, mean_s,
                          'Pred', is_surface=True)  # U10
    plt.subplot(4, 3, 9)
    plt_global_field_data(label - pred, 'U10', std_s,
                          mean_s, 'Error', is_surface=True)  # U10

    plt.subplot(4, 3, 10)
    plt_global_field_data(label, 'T2M', std_s, mean_s,
                          'Ground Truth', is_surface=True)  # T2M
    plt.subplot(4, 3, 11)
    plt_global_field_data(pred, 'T2M', std_s, mean_s,
                          'Pred', is_surface=True)  # T2M
    plt.subplot(4, 3, 12)
    plt_global_field_data(label - pred, 'T2M', std_s,
                          mean_s, 'Error', is_surface=True)  # T2M

    plt.savefig('pred_result.png', bbox_inches='tight')
    plt.show()

[11]:

test_dataset_generator = Era5Data(data_params=config["data"], run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=config["data"]['num_workers'], shuffle=False)
test_dataset = test_dataset.create_dataset(config["data"]['batch_size'])
data = next(test_dataset.create_dict_iterator())
inputs = data['inputs']
labels = data['labels']
pred = inference_module.forecast(inputs)
plt_data(pred, labels, config['data']['root_dir'])

The predictions of the 100th checkpoint, the ground truth, and their error are visualized below.

plot result