FourCastNet: Medium-range Global Weather Forecasting Based on FNO
Overview
FourCastNet (Fourier ForeCasting Neural Network) is a data-driven global weather forecast model developed by researchers from NVIDIA, Lawrence Berkeley National Laboratory, the University of Michigan, Ann Arbor, and Rice University. It provides medium-range forecasts of key global weather variables at a resolution of 0.25°, which corresponds to a spatial resolution of approximately 30 km x 30 km near the equator and a global grid of 720 x 1440 pixels. Compared with traditional numerical weather prediction (NWP) models, FourCastNet is about 45,000 times faster, generating a week-long forecast in under 2 seconds, while achieving accuracy comparable to that of the most advanced NWP model, the ECMWF Integrated Forecasting System (IFS). It is the first AI weather forecast model that can be directly compared with the IFS.
This tutorial introduces the research background and technical path of FourCastNet, and shows how to train the model and run fast inference with MindEarth. More information can be found in the paper.
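The resolution figures quoted in the overview can be checked with a few lines of arithmetic. The sketch below is illustrative only: the helper names are hypothetical, the Earth is treated as a sphere with a mean radius of 6371 km, and the grid is assumed to have 720 latitude rows as stated in the text.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius, good enough for a rough estimate


def grid_shape(resolution_deg):
    """Return (n_lat, n_lon) for a regular latitude-longitude grid."""
    n_lat = int(round(180 / resolution_deg))
    n_lon = int(round(360 / resolution_deg))
    return n_lat, n_lon


def cell_width_km(resolution_deg, latitude_deg=0.0):
    """Approximate east-west width of one grid cell at the given latitude."""
    circumference = 2 * math.pi * EARTH_RADIUS_KM * math.cos(math.radians(latitude_deg))
    return circumference * resolution_deg / 360


print(grid_shape(0.25))               # (720, 1440)
print(round(cell_width_km(0.25), 1))  # 27.8
```

The cell width at the equator comes out near 28 km, which the overview rounds to 30 km. Cells become narrower in the east-west direction toward the poles.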
Technology Path
MindEarth solves the problem as follows:
Training Data Construction.
Model Construction.
Loss Function.
Model Training.
Model Evaluation and Visualization.
FourCastNet
To achieve high-resolution prediction, FourCastNet uses the AFNO (Adaptive Fourier Neural Operator) model. The network architecture is designed for high-resolution input: it uses ViT as the backbone and incorporates the Fourier Neural Operator (FNO) proposed by Zongyi Li et al. The model learns mappings between function spaces, which allows it to solve families of nonlinear partial differential equations.
The Vision Transformer (ViT) architecture and its variants have become the state of the art in computer vision over the past few years, performing well on many tasks. This performance is mainly attributed to the multi-head self-attention mechanism, which models global dependencies among the features in each layer. However, the computational cost of self-attention during training and inference grows quadratically with the number of tokens (or patches), so it increases sharply as the input resolution increases.
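To make the quadratic growth concrete, the following sketch counts tokens and token pairs for the 720 x 1440 grid. The patch sizes are hypothetical values chosen for illustration and are not taken from FourCastNet.yaml.

```python
def num_tokens(height, width, patch_size):
    """Number of non-overlapping square patches that tile the input."""
    return (height // patch_size) * (width // patch_size)


def attention_pairs(n_tokens):
    """Self-attention compares every token with every token: N * N pairs."""
    return n_tokens * n_tokens


for patch_size in (16, 8, 4):  # hypothetical patch sizes
    n = num_tokens(720, 1440, patch_size)
    print(f"patch={patch_size:2d}  tokens={n:6d}  pairs={attention_pairs(n):,}")
```

Halving the patch size quadruples the number of tokens and multiplies the number of attention pairs by 16.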
The key idea of the AFNO model is to perform the spatial mixing operation in the Fourier domain: it transforms the features from the spatial domain to the frequency domain with a Fourier transform, applies a globally learnable filter to the frequency-domain features, and thereby mixes information across tokens. This reduces the complexity of spatial mixing to O(N log N), where N is the number of tokens.
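The Fourier-domain mixing described above can be sketched in NumPy. This is a simplified illustration, not the AFNONet implementation in mindearth.cell: it uses a plain element-wise filter, the filter is fixed rather than learned, and the function name and shapes are assumptions made for the example.

```python
import numpy as np


def fourier_token_mixing(tokens, spectral_filter):
    """Mix tokens globally: FFT -> element-wise filter -> inverse FFT.

    tokens:          real array of shape (h, w, channels), one vector per patch
    spectral_filter: complex array of shape (h, w // 2 + 1, channels)
    """
    h, w, _ = tokens.shape
    freq = np.fft.rfft2(tokens, axes=(0, 1))   # spatial domain -> frequency domain
    freq = freq * spectral_filter              # a learnable filter in a real model
    return np.fft.irfft2(freq, s=(h, w), axes=(0, 1))  # back to the token grid


rng = np.random.default_rng(0)
tokens = rng.standard_normal((8, 16, 4))
identity_filter = np.ones((8, 16 // 2 + 1, 4), dtype=complex)
mixed = fourier_token_mixing(tokens, identity_filter)
print(np.allclose(mixed, tokens))  # True: an all-ones filter leaves the tokens unchanged
```

Because every frequency coefficient depends on all tokens, one multiplication in the frequency domain mixes information across the whole grid, and the FFT keeps the cost at O(N log N).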
The following figure shows the FourCastNet network architecture.
Model training consists of three steps:
Pre-training: As shown in Figure (a) above, in the pre-training step, the AFNO model is trained in a supervised manner using the training dataset to learn the mapping from X(k) to X(k + 1).
Fine-tuning: As shown in Figure (b) above, the model first predicts X(k + 1) from X(k) and then uses the predicted X(k + 1) as input to predict X(k + 2). The loss is computed for both predictions, and the model is optimized using the sum of the two loss values.
Precipitation forecast: As shown in Figure (c) above, precipitation is predicted by a separate model attached after the backbone model. This decouples the precipitation prediction task from the prediction of the basic meteorological variables. In addition, the trained precipitation model can be combined with other prediction models (traditional NWP, etc.).
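The two-step objective used in the fine-tuning stage can be sketched as follows. This is a minimal illustration under stated assumptions: toy_model is a hypothetical stand-in for the trained network, and relative_l2 is a simple relative error chosen for the example rather than the exact loss used by MindEarth.

```python
import numpy as np


def relative_l2(pred, target):
    """Relative L2 error between a prediction and its target."""
    return np.linalg.norm(pred - target) / np.linalg.norm(target)


def two_step_loss(model, x_k, x_k1, x_k2):
    """Sum of the losses for X(k+1) and X(k+2), rolling the model forward twice."""
    pred_k1 = model(x_k)      # first step: X(k) -> X(k+1)
    pred_k2 = model(pred_k1)  # second step uses the model's own prediction as input
    return relative_l2(pred_k1, x_k1) + relative_l2(pred_k2, x_k2)


def toy_model(x):
    """Hypothetical stand-in for the network: decays the field by 10% per step."""
    return 0.9 * x


x_k = np.ones((2, 4, 8))
x_k1 = 0.9 * x_k
x_k2 = 0.81 * x_k
print(two_step_loss(toy_model, x_k, x_k1, x_k2))  # close to zero for this toy case
```

Feeding the first prediction back into the model penalizes errors that accumulate over consecutive steps, which matters when forecasts are rolled out over many steps.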
This tutorial mainly implements the model pre-training part.
[1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from mindearth.utils import load_yaml_config, create_logger, plt_global_field_data
from mindearth.module import Trainer
from mindearth.data import Dataset, Era5Data
from mindearth import RelativeRMSELoss
from mindearth.cell import AFNONet
The src package imported below can be downloaded from FourCastNet/src.
[2]:
from src.callback import EvaluateCallBack, InferenceModule
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
The model, data, and optimizer parameters are read from FourCastNet.yaml.
[4]:
config = load_yaml_config('FourCastNet.yaml')
config['model']['data_sink'] = True # set the data sink feature
config['train']['distribute'] = False # set the distribute feature
config['train']['amp_level'] = 'O2' # set the level for mixed precision training
config['data']['num_workers'] = 1 # set the number of parallel workers
config['data']['grid_resolution'] = 1.4 # set the resolution for dataset
config['optimizer']['epochs'] = 100 # set the training epochs
config['optimizer']['finetune_epochs'] = 1 # set the finetune epochs
config['optimizer']['warmup_epochs'] = 1 # set the warmup epochs
config['optimizer']['initial_lr'] = 0.0005 # set the initial learning rate
config['summary']["valid_frequency"] = 10 # set the frequency of validation
config['summary']["summary_dir"] = './summary' # set the directory of model's checkpoint
logger = create_logger(path="results.log")
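The configuration sets initial_lr, warmup_epochs, and epochs, but this tutorial does not state how they are combined into a learning-rate schedule. As an illustration only, the sketch below assumes a linear warmup followed by cosine decay. The function warmup_cosine_lr is hypothetical and is not a MindEarth API; the value of 404 steps per epoch is taken from the training log in this tutorial.

```python
import math


def warmup_cosine_lr(step, initial_lr, warmup_steps, total_steps):
    """Assumed schedule: linear warmup to initial_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return initial_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * initial_lr * (1 + math.cos(math.pi * progress))


steps_per_epoch = 404                 # from the training log
warmup_steps = 1 * steps_per_epoch    # warmup_epochs = 1
total_steps = 100 * steps_per_epoch   # epochs = 100

for step in (0, warmup_steps - 1, total_steps // 2, total_steps):
    print(step, warmup_cosine_lr(step, 0.0005, warmup_steps, total_steps))
```

Under this assumption the learning rate reaches 0.0005 at the end of the first epoch and decays smoothly to zero at the last step.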
Training Data Construction
Download the statistics, training, and validation datasets from dataset to ./dataset.
Modify the root_dir parameter in FourCastNet.yaml, which sets the dataset directory.
The ./dataset directory has the following structure:
.
├── statistic
│ ├── mean.npy
│ ├── mean_s.npy
│ ├── std.npy
│ └── std_s.npy
├── train
│ └── 2015
├── train_static
│ └── 2015
├── train_surface
│ └── 2015
├── train_surface_static
│ └── 2015
├── valid
│ └── 2016
├── valid_static
│ └── 2016
├── valid_surface
│ └── 2016
├── valid_surface_static
│ └── 2016
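Before training, it can help to confirm that the downloaded data matches the layout above. The helper below is a hypothetical convenience function, not part of MindEarth; the list of expected entries is copied from the directory tree shown in this section.

```python
import os

# Entries copied from the directory tree above, relative to root_dir.
EXPECTED_ENTRIES = [
    "statistic/mean.npy",
    "statistic/mean_s.npy",
    "statistic/std.npy",
    "statistic/std_s.npy",
    "train/2015",
    "train_static/2015",
    "train_surface/2015",
    "train_surface_static/2015",
    "valid/2016",
    "valid_static/2016",
    "valid_surface/2016",
    "valid_surface_static/2016",
]


def missing_dataset_entries(root_dir):
    """Return the expected files and directories that are absent under root_dir."""
    return [entry for entry in EXPECTED_ENTRIES
            if not os.path.exists(os.path.join(root_dir, entry))]


missing = missing_dataset_entries("./dataset")
print("dataset layout OK" if not missing else f"missing entries: {missing}")
```

An empty result means every entry from the tree is present; it does not verify the contents of the files.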
Model Construction
Initialize the AFNONet model with the data and model parameters.
[5]:
data_params = config['data']
model_params = config['model']
model = AFNONet(image_size=(data_params['h_size'], data_params['w_size']),
                in_channels=data_params["feature_dims"],
                out_channels=data_params["feature_dims"],
                patch_size=data_params["patch_size"],
                encoder_depths=model_params["encoder_depths"],
                encoder_embed_dim=model_params["encoder_embed_dim"],
                mlp_ratio=model_params["mlp_ratio"],
                dropout_rate=model_params["dropout_rate"])
Loss Function
FourCastNet uses relative root mean squared error for model training.
[6]:
loss_fn = RelativeRMSELoss()
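For intuition, a relative root mean squared error can be written in NumPy as below. This sketch assumes one common definition (the per-sample L2 norm of the error divided by the L2 norm of the label, averaged over the batch); the reduction used inside MindEarth's RelativeRMSELoss is not shown in this tutorial and may differ.

```python
import numpy as np


def relative_rmse(pred, label):
    """Per-sample ||pred - label|| / ||label||, averaged over the batch (assumed form)."""
    batch_size = pred.shape[0]
    diff = (pred - label).reshape(batch_size, -1)
    ref = label.reshape(batch_size, -1)
    per_sample = np.sqrt((diff ** 2).sum(axis=1)) / np.sqrt((ref ** 2).sum(axis=1))
    return per_sample.mean()


label = np.full((2, 3, 4, 4), 2.0)   # (batch, channels, height, width)
pred = 1.1 * label                   # every value is 10% too large
print(relative_rmse(pred, label))    # approximately 0.1
```

Dividing by the norm of the label makes the loss insensitive to the overall scale of the target fields.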
Model Training
In this tutorial, we inherit from Trainer and override the get_callback member function so that inference can be run on the validation dataset during training.
With MindSpore version >= 1.8.1, neural networks can be trained in a functional programming style. MindEarth provides a training interface for this purpose.
[7]:
class FCNTrainer(Trainer):
    def __init__(self, config, model, loss_fn, logger):
        super(FCNTrainer, self).__init__(config, model, loss_fn, logger)
        self.pred_cb = self.get_callback()

    def get_callback(self):
        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)
        return pred_cb


trainer = FCNTrainer(config, model, loss_fn, logger)
2023-08-21 07:34:55,267 - pretrain.py[line:211] - INFO: steps_per_epoch: 404
[8]:
trainer.train()
epoch: 1 step: 404, loss is 0.5348429
Train epoch time: 136480.515 ms, per step time: 337.823 ms
epoch: 2 step: 404, loss is 0.35937342
Train epoch time: 60902.627 ms, per step time: 150.749 ms
epoch: 3 step: 404, loss is 0.33921248
Train epoch time: 60737.844 ms, per step time: 150.341 ms
...
epoch: 98 step: 404, loss is 0.15447393
Train epoch time: 61055.706 ms, per step time: 151.128 ms
epoch: 99 step: 404, loss is 0.15696357
Train epoch time: 60850.156 ms, per step time: 150.619 ms
epoch: 100 step: 404, loss is 0.15654306
Train epoch time: 60944.369 ms, per step time: 150.852 ms
2023-09-07 04:27:02,837 - forecast.py[line:209] - INFO: ================================Start Evaluation================================
2023-09-07 04:28:25,277 - forecast.py[line:177] - INFO: t = 6 hour:
2023-09-07 04:28:25,277 - forecast.py[line:188] - INFO: RMSE of Z500: 154.07894852240838, T2m: 2.0995438696856965, T850: 1.3081689948838815, U10: 1.527248748050362
2023-09-07 04:28:25,278 - forecast.py[line:189] - INFO: ACC of Z500: 0.9989880649296732, T2m: 0.9930711917863625, T850: 0.9954355203713009, U10: 0.9615764420500764
2023-09-07 04:28:25,279 - forecast.py[line:177] - INFO: t = 72 hour:
2023-09-07 04:28:25,279 - forecast.py[line:188] - INFO: RMSE of Z500: 885.3778200063341, T2m: 4.586325958437852, T850: 4.2593739999338736, U10: 4.75655467109408
2023-09-07 04:28:25,280 - forecast.py[line:189] - INFO: ACC of Z500: 0.9598951919101183, T2m: 0.9658168304842388, T850: 0.9501612262744354, U10: 0.6175327930007481
2023-09-07 04:28:25,281 - forecast.py[line:177] - INFO: t = 120 hour:
2023-09-07 04:28:25,281 - forecast.py[line:188] - INFO: RMSE of Z500: 1291.3199606908572, T2m: 6.734047767054735, T850: 5.6420206614200294, U10: 5.637643311177468
2023-09-07 04:28:25,282 - forecast.py[line:189] - INFO: ACC of Z500: 0.9150022892106006, T2m: 0.9294266102808937, T850: 0.9148957221265037, U10: 0.47971871343985495
2023-09-07 04:28:25,283 - forecast.py[line:237] - INFO: ================================End Evaluation================================
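The evaluation log reports RMSE and ACC (anomaly correlation coefficient) for each variable, but the formulas are not given in this tutorial. The sketch below assumes the latitude-weighted definitions commonly used for global forecasts; the function names are hypothetical and the computation in EvaluateCallBack may differ.

```python
import numpy as np


def latitude_weights(n_lat):
    """cos(latitude) weights at cell centers, normalized to a mean of 1."""
    lats = 90.0 - (np.arange(n_lat) + 0.5) * 180.0 / n_lat
    weights = np.cos(np.deg2rad(lats))
    return weights / weights.mean()


def weighted_rmse(pred, truth, weights):
    """Latitude-weighted RMSE over a (n_lat, n_lon) field."""
    return np.sqrt(((pred - truth) ** 2 * weights[:, None]).mean())


def weighted_acc(pred, truth, climatology, weights):
    """Latitude-weighted correlation between predicted and true anomalies."""
    p = pred - climatology
    t = truth - climatology
    w = weights[:, None]
    return (w * p * t).sum() / np.sqrt((w * p ** 2).sum() * (w * t ** 2).sum())


rng = np.random.default_rng(0)
n_lat, n_lon = 8, 16
weights = latitude_weights(n_lat)
climatology = rng.standard_normal((n_lat, n_lon))
truth = climatology + rng.standard_normal((n_lat, n_lon))
pred = truth + 0.1 * rng.standard_normal((n_lat, n_lon))
print(weighted_rmse(pred, truth, weights), weighted_acc(pred, truth, climatology, weights))
```

An ACC of 1 means the predicted anomalies match the true anomalies up to a positive scale factor, which is consistent with the logged ACC values decreasing as the lead time grows from 6 to 120 hours.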
Model Evaluation and Visualization
After training, we use the 100th checkpoint for inference.
[9]:
pred_time_index = 0
params = load_checkpoint('./summary/ckpt/step_1/FourCastNet_1-100_404.ckpt')
load_param_into_net(model, params)
inference_module = InferenceModule(model, config, logger)
[10]:
def plt_data(pred, label, root_dir, index=0):
    """ Visualize the forecast results """
    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))
    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))
    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))
    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))
    pred, label = pred[index].asnumpy(), label.asnumpy()[..., index, :, :]
    plt.figure(num='e_imshow', figsize=(100, 50), dpi=100)

    plt.subplot(4, 3, 1)
    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500
    plt.subplot(4, 3, 2)
    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500
    plt.subplot(4, 3, 3)
    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error')  # Z500

    plt.subplot(4, 3, 4)
    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850
    plt.subplot(4, 3, 5)
    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850
    plt.subplot(4, 3, 6)
    plt_global_field_data(label - pred, 'T850', std, mean, 'Error')  # T850

    plt.subplot(4, 3, 7)
    plt_global_field_data(label, 'U10', std_s, mean_s,
                          'Ground Truth', is_surface=True)  # U10
    plt.subplot(4, 3, 8)
    plt_global_field_data(pred, 'U10', std_s, mean_s,
                          'Pred', is_surface=True)  # U10
    plt.subplot(4, 3, 9)
    plt_global_field_data(label - pred, 'U10', std_s,
                          mean_s, 'Error', is_surface=True)  # U10

    plt.subplot(4, 3, 10)
    plt_global_field_data(label, 'T2M', std_s, mean_s,
                          'Ground Truth', is_surface=True)  # T2M
    plt.subplot(4, 3, 11)
    plt_global_field_data(pred, 'T2M', std_s, mean_s,
                          'Pred', is_surface=True)  # T2M
    plt.subplot(4, 3, 12)
    plt_global_field_data(label - pred, 'T2M', std_s,
                          mean_s, 'Error', is_surface=True)  # T2M

    plt.savefig('pred_result.png', bbox_inches='tight')
    plt.show()
[11]:
test_dataset_generator = Era5Data(data_params=config["data"], run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=config["data"]['num_workers'], shuffle=False)
test_dataset = test_dataset.create_dataset(config["data"]['batch_size'])
data = next(test_dataset.create_dict_iterator())
inputs = data['inputs']
labels = data['labels']
pred = inference_module.forecast(inputs)
plt_data(pred, labels, config['data']['root_dir'])
The predictions of the 100th checkpoint, the ground truth, and their error are visualized below.