FuXi: Medium-range Global Weather Forecasting Based on Cascade Architecture
Overview
FuXi is a data-driven global weather forecast model developed by researchers from Fudan University. It provides medium-range forecasts of key global weather variables at a resolution of 0.25°. This corresponds to a spatial resolution of approximately 25 km × 25 km near the equator and a global grid of 720 × 1440 pixels. Compared with previous ML-based weather forecast models, FuXi and its cascade architecture achieved excellent results in comparisons with ECMWF forecasts.
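The grid size and the approximate cell spacing follow from the resolution with simple arithmetic. The sketch below assumes a regular latitude-longitude grid spanning 180° × 360° and uses about 111.32 km per degree of longitude at the equator. The exact equatorial spacing comes out near 27.8 km, which is commonly rounded to "about 25 km". ERA5's native 0.25° grid has 721 rows when both poles are included; this tutorial's config uses 720.

```python
# Grid implied by a regular 0.25° latitude-longitude grid.
RESOLUTION_DEG = 0.25
KM_PER_DEG_EQUATOR = 111.32  # approximate length of 1° of longitude at the equator

n_lat = int(180 / RESOLUTION_DEG)   # latitude rows spanning 180°
n_lon = int(360 / RESOLUTION_DEG)   # longitude columns spanning 360°
spacing_km = RESOLUTION_DEG * KM_PER_DEG_EQUATOR

print(n_lat, n_lon)          # 720 1440
print(round(spacing_km, 1))  # 27.8
```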
This tutorial introduces the research background and technical approach of FuXi, and shows how to train the model and run fast inference with MindSpore Earth. More information can be found in the paper. The ERA5_0_25_tiny400 dataset, at 0.25° resolution, is used to walk through the workflow of this tutorial in detail.
FuXi
The basic FuXi model architecture consists of three main components, as shown in the figure: Cube Embedding, U-Transformer, and a fully connected layer. The input combines upper-air and surface variables into a data cube of dimensions 69×720×1440, and each forecast step advances the state by one time step.
The space-time cube embedding reduces the high-dimensional input to C×180×360. Its main purpose is to reduce the spatial and temporal dimensions of the input and remove redundant information. The U-Transformer then processes the embedded data, a simple fully connected layer produces the predictions, and the output is reshaped back to 69×720×1440.
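The channel count and the grid sizes at each stage can be checked with a little arithmetic. This sketch takes pressure_level_num, level_feature_size and surface_feature_size from the data section of the config that is logged later in this tutorial; reading 69 as 13 levels × 5 variables + 4 surface variables is an inference from those keys. The stage names in the dictionary are just labels.

```python
# Values taken from the data section of the logged FuXi.yaml config.
pressure_level_num = 13    # upper-air pressure levels
level_feature_size = 5     # upper-air variables per level
surface_feature_size = 4   # surface variables
h_size, w_size = 720, 1440

feature_dims = pressure_level_num * level_feature_size + surface_feature_size

# Spatial grid size after each stage described above.
stages = {
    "input": (h_size, w_size),
    "cube_embedding": (h_size // 4, w_size // 4),   # 4x4 spatial patches
    "down_block": (h_size // 8, w_size // 8),       # further halved by stride 2
}
print(feature_dims)  # 69
for name, (h, w) in stages.items():
    print(name, (h, w), h * w, "grid points")
```

The number of grid points drops from 1,036,800 at the input to 64,800 after the cube embedding and 16,200 after the Down Block, which is where the savings in attention cost come from.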
Cube Embedding
To reduce the spatial and temporal dimensions of the input data and speed up training, the cube embedding method is applied.
Specifically, the space-time cube embedding uses a three-dimensional (3D) convolutional layer whose kernel size and stride are both 2×4×4, which reduces the input to \(\frac{T}{2}×\frac{H}{4}×\frac{W}{4}\), and whose number of output channels is C. After the cube embedding, layer normalization (LayerNorm) is applied to improve training stability. The resulting data cube has dimensions C×180×360.
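Because the kernel size equals the stride, every output cell sees exactly one non-overlapping 2×4×4 block, so the 3D convolution reduces to a reshape followed by a contraction. The NumPy sketch below illustrates that equivalence on a toy cube; it is not the MindSpore Earth implementation, and it assumes a (C_in, T, H, W) layout, two input time steps, and a LayerNorm over channels without learnable scale and shift.

```python
import numpy as np

def cube_embedding(x, weight, bias, patch=(2, 4, 4), eps=1e-5, normalize=True):
    """Non-overlapping 3D convolution (kernel == stride == patch) + LayerNorm.

    x:      (C_in, T, H, W) input cube
    weight: (C_out, C_in, pt, ph, pw) convolution kernel
    bias:   (C_out,)
    Returns (C_out, T // pt, H // ph, W // pw).
    """
    c_in, t, h, w = x.shape
    pt, ph, pw = patch
    # Split each of T, H, W into (number of blocks, offset inside block).
    blocks = x.reshape(c_in, t // pt, pt, h // ph, ph, w // pw, pw)
    # Each output cell sees one block, so the convolution is a contraction
    # over input channels and the three in-block offsets.
    out = np.einsum("cTaHbWd,ocabd->oTHW", blocks, weight)
    out = out + bias[:, None, None, None]
    if normalize:
        # LayerNorm over the channel axis at every (t, h, w) location.
        mean = out.mean(axis=0, keepdims=True)
        var = out.var(axis=0, keepdims=True)
        out = (out - mean) / np.sqrt(var + eps)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 2, 8, 8))          # toy cube: 3 variables, 2 time steps
weight = rng.standard_normal((16, 3, 2, 4, 4))
bias = np.zeros(16)
print(cube_embedding(x, weight, bias).shape)    # (16, 1, 2, 2)
```

With FuXi-sized inputs (69 variables, T = 2, H = 720, W = 1440) the same function would return C×1×180×360, i.e. C×180×360 once the singleton time axis is dropped. A real model would use a framework 3D convolution layer instead of NumPy.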
U-Transformer
The U-Transformer also includes the downsampling and upsampling blocks of the U-Net model. The downsampling block, labeled Down Block in the figure, reduces the data dimensions to C×90×180, which lowers the computational and memory cost of self-attention. The Down Block consists of a 3×3 2D convolutional layer with a stride of 2 and a residual block with two 3×3 convolutional layers, followed by a group normalization (GN) layer and a sigmoid-weighted linear unit (SiLU) activation. SiLU computes \(σ(x)×x\), i.e. the sigmoid of the input multiplied by the input itself.
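The normalization and activation pair used in the Down Block can be written in a few lines of NumPy. This is a sketch for a single (C, H, W) feature map without learnable GN parameters; the number of groups FuXi uses is not stated here, so num_groups=2 is an arbitrary choice for illustration.

```python
import numpy as np

def silu(x):
    """SiLU: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def group_norm(x, num_groups, eps=1e-5):
    """GroupNorm for a (C, H, W) array, without learnable scale/shift.

    Channels are split into num_groups groups; each group is normalized
    with the mean and variance over its channels and all spatial positions.
    """
    c, h, w = x.shape
    g = x.reshape(num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(1, 2, 3), keepdims=True)
    var = g.var(axis=(1, 2, 3), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(c, h, w)

rng = np.random.default_rng(0)
feat = rng.standard_normal((8, 4, 4)) * 3.0 + 1.0
out = silu(group_norm(feat, num_groups=2))
print(out.shape)              # (8, 4, 4)
print(silu(np.array([0.0])))  # [0.]
```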
The upsampling block, labeled Up Block in the figure, uses the same residual block as the Down Block together with a 2D transposed convolution (deconvolution) with a kernel size of 2 and a stride of 2. The Up Block scales the data back to C×180×360. In addition, a skip connection concatenates the output of the Down Block with the output of the Transformer blocks before they are fed into the Up Block.
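The 180×360 → 90×180 → 180×360 round trip can be checked with the standard output-size formulas for a strided convolution and a transposed convolution (dilation 1, no output padding). A padding of 1 for the 3×3 convolution is an assumption; it is the usual choice that makes a stride-2 layer halve even sizes exactly.

```python
def conv2d_out(size, kernel, stride, padding=0):
    """Output length of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride, padding=0):
    """Output length of a 2D transposed convolution along one axis."""
    return (size - 1) * stride - 2 * padding + kernel

h, w = 180, 360
# Down Block: 3x3 convolution with stride 2 (padding=1 is an assumption).
h_down, w_down = conv2d_out(h, 3, 2, 1), conv2d_out(w, 3, 2, 1)
# Up Block: 2x2 transposed convolution with stride 2.
h_up, w_up = conv_transpose2d_out(h_down, 2, 2), conv_transpose2d_out(w_down, 2, 2)
print((h_down, w_down), (h_up, w_up))  # (90, 180) (180, 360)
```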
The middle part is built from 18 repeated Swin Transformer V2 blocks. By using residual post-normalization instead of pre-normalization, and scaled cosine attention instead of the original dot-product self-attention, Swin Transformer V2 addresses several problems that arise when training and applying large-scale Swin Transformer models, such as training instability.
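The two Swin Transformer V2 changes can be sketched for a single attention head over the tokens of one window. In Swin Transformer V2 the temperature tau is a learnable scalar and a relative position bias is added to the logits; here tau is fixed to 0.1 and the bias is omitted, so this is an illustration of the idea rather than a complete block.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def scaled_cosine_attention(q, k, v, tau=0.1):
    """Single-head attention with cosine-similarity logits divided by tau.

    q, k, v: (N, d) tokens of one window. Relative position bias omitted.
    """
    qn = q / np.linalg.norm(q, axis=-1, keepdims=True)
    kn = k / np.linalg.norm(k, axis=-1, keepdims=True)
    logits = qn @ kn.T / tau          # each entry lies in [-1/tau, 1/tau]
    return softmax(logits) @ v

def post_norm_block(x, sublayer):
    """Residual post-normalization: x + LayerNorm(sublayer(x))."""
    return x + layer_norm(sublayer(x))

rng = np.random.default_rng(0)
tokens = rng.standard_normal((16, 8))     # 16 tokens in a window, width 8
wq, wk, wv = (rng.standard_normal((8, 8)) for _ in range(3))

def attn(t):
    return scaled_cosine_attention(t @ wq, t @ wk, t @ wv)

out = post_norm_block(tokens, attn)
print(out.shape)  # (16, 8)
```

Because the logits are cosine similarities, their magnitude is bounded by 1/tau regardless of how large the activations grow, which is one reason the attention is more stable in large models; normalizing the sublayer output before the residual addition similarly keeps the main branch from growing layer by layer.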
Technology Path
MindSpore Earth solves the problem as follows:
Data Construction.
Model Construction.
Loss Function.
Model Training.
Model Evaluation and Visualization.
import os
import random
import matplotlib.pyplot as plt
import numpy as np
from mindspore import set_seed
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
The following src modules can be downloaded from fuxi/src.
from mindearth.utils import load_yaml_config, plt_global_field_data, make_dir
from mindearth.data import Dataset, Era5Data
from src import init_model, get_logger
from src import MAELossForMultiLabel, FuXiTrainer, CustomWithLossCell, InferenceModule
set_seed(0)
np.random.seed(0)
random.seed(0)
The model, data, and optimizer parameters are read from the config file.
config = load_yaml_config("configs/FuXi.yaml")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
Data Construction
Download the statistic, training, and validation datasets from ERA5_0_25_tiny400 to ./dataset.
The ./dataset directory has the following structure:
.
├── statistic
│ ├── mean.npy
│ ├── mean_s.npy
│ ├── std.npy
│ ├── std_s.npy
│ └── climate_0.25.npy
├── train
│ └── 2015
├── train_static
│ └── 2015
├── train_surface
│ └── 2015
├── train_surface_static
│ └── 2015
├── valid
│ └── 2016
├── valid_static
│ └── 2016
├── valid_surface
│ └── 2016
├── valid_surface_static
│ └── 2016
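Before training, it can save time to confirm that the download matches the layout above. The helper below is hypothetical (it is not part of MindSpore Earth); its expected entries are transcribed from the directory tree, and the default root ./dataset matches the download location mentioned above.

```python
from pathlib import Path

# Expected layout, transcribed from the directory tree above.
STATISTIC_FILES = ["mean.npy", "mean_s.npy", "std.npy", "std_s.npy", "climate_0.25.npy"]
DATA_DIRS = {
    "train": "2015", "train_static": "2015",
    "train_surface": "2015", "train_surface_static": "2015",
    "valid": "2016", "valid_static": "2016",
    "valid_surface": "2016", "valid_surface_static": "2016",
}

def missing_dataset_entries(root="./dataset"):
    """Return relative paths from the expected layout that do not exist under root."""
    root = Path(root)
    expected = [Path("statistic") / name for name in STATISTIC_FILES]
    expected += [Path(split) / year for split, year in DATA_DIRS.items()]
    return [str(p) for p in expected if not (root / p).exists()]

missing = missing_dataset_entries()
print("dataset OK" if not missing else f"missing: {missing}")
```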
Model Construction
Model initialization sets the number of Swin Transformer blocks and the training parameters from the config.
make_dir(os.path.join(config['summary']["summary_dir"], "image"))
logger_obj = get_logger(config)
fuxi_model = init_model(config, "train")
2024-01-29 12:34:41,485 - utils.py[line:91] - INFO: {'name': 'FuXi', 'depths': 18, 'in_channels': 96, 'out_channels': 192}
2024-01-29 12:34:41,487 - utils.py[line:91] - INFO: {'name': 'era5', 'root_dir': '/data4/cq/ERA5_0_25/', 'feature_dims': 69, 'pressure_level_num': 13, 'level_feature_size': 5, 'surface_feature_size': 4, 'h_size': 720, 'w_size': 1440, 'data_sink': False, 'batch_size': 1, 't_in': 1, 't_out_train': 1, 't_out_valid': 20, 't_out_test': 20, 'train_interval': 6, 'valid_interval': 6, 'test_interval': 6, 'pred_lead_time': 6, 'data_frequency': 6, 'train_period': [2015, 2015], 'valid_period': [2016, 2016], 'test_period': [2017, 2017], 'num_workers': 1, 'grid_resolution': 0.25}
2024-01-29 12:34:41,488 - utils.py[line:91] - INFO: {'name': 'adam', 'initial_lr': 0.00025, 'finetune_lr': 1e-05, 'finetune_epochs': 1, 'warmup_epochs': 1, 'weight_decay': 0.0, 'loss_weight': 0.25, 'gamma': 0.5, 'epochs': 100}
2024-01-29 12:34:41,489 - utils.py[line:91] - INFO: {'summary_dir': './summary', 'eval_interval': 10, 'save_checkpoint_steps': 10, 'keep_checkpoint_max': 50, 'plt_key_info': True, 'key_info_timestep': [6, 72, 120], 'ckpt_path': '/data3/cq'}
2024-01-29 12:34:41,490 - utils.py[line:91] - INFO: {'name': 'oop', 'distribute': False, 'amp_level': 'O2', 'load_ckpt': True}
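The logged config also determines which autoregressive steps the evaluation reports. The sketch below assumes that pred_lead_time is the number of hours advanced per step, which is consistent with the 6-, 72- and 120-hour lead times printed in the training log.

```python
# Values copied from the logged data and summary sections of the config.
pred_lead_time = 6                # hours advanced per autoregressive step (assumed meaning)
t_out_valid = 20                  # number of autoregressive steps during validation
key_info_timestep = [6, 72, 120]  # lead times (hours) reported in the logs

max_lead_hours = t_out_valid * pred_lead_time
steps = [hours // pred_lead_time for hours in key_info_timestep]
print(max_lead_hours)  # 120
print(steps)           # [1, 12, 20]
```

Under this reading, the 120-hour metrics come from the last of the 20 validation steps.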
Loss Function
FuXi is trained with a custom mean absolute error (MAE) loss.
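On a regular latitude-longitude grid, rows near the poles cover much less area than rows near the equator, so MAE losses for global fields are commonly weighted by cos(latitude). The sketch below shows that idea; it is an assumption for illustration, since the internals of MAELossForMultiLabel (including how the config's loss_weight enters) are not shown in this tutorial.

```python
import numpy as np

def latitude_weights(h_size):
    """cos(latitude) weights for h_size rows at cell centres, normalized to mean 1."""
    res = 180.0 / h_size
    lat = 90.0 - res * (np.arange(h_size) + 0.5)   # degrees, north to south
    w = np.cos(np.deg2rad(lat))
    return w / w.mean()

def lat_weighted_mae(pred, label):
    """Latitude-weighted MAE over arrays shaped (..., C, H, W)."""
    w = latitude_weights(pred.shape[-2])[:, None]   # broadcast over W
    return float(np.mean(w * np.abs(pred - label)))

rng = np.random.default_rng(0)
label = rng.standard_normal((69, 72, 144))   # coarse toy grid
pred = label + 0.5                            # uniform 0.5 error everywhere
print(lat_weighted_mae(pred, label))          # ≈ 0.5
```

Normalizing the weights to mean 1 keeps the loss on the same scale as an unweighted MAE: a uniform error of 0.5 still gives a loss of about 0.5.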
data_params = config.get('data')
optimizer_params = config.get('optimizer')
loss_fn = MAELossForMultiLabel(data_params=data_params, optimizer_params=optimizer_params)
loss_cell = CustomWithLossCell(backbone=fuxi_model, loss_fn=loss_fn)
Model Training
In this tutorial, we inherit from the Trainer class, override the get_solver member function to build the custom loss function, and override the get_callback member function to run inference on the test dataset during training.
MindSpore Earth provides training and inference interfaces that require MindSpore version 2.0.0 or later.
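The inherit-and-override approach can be illustrated without MindSpore. BaseTrainer and CustomTrainer below are hypothetical stand-ins, and their plain-Python return values are placeholders; the real Trainer's get_solver and get_callback return MindSpore objects whose details are not shown in this tutorial. The point is only that train() calls the hooks, so a subclass changes behavior by overriding them.

```python
class BaseTrainer:
    """Hypothetical stand-in for a trainer that builds its pieces via hooks."""

    def __init__(self, model, loss_cell):
        self.model = model
        self.loss_cell = loss_cell

    def get_solver(self):
        return {"network": self.model}          # default: no custom loss

    def get_callback(self):
        return []                               # default: no extra callbacks

    def train(self):
        solver = self.get_solver()
        callbacks = self.get_callback()
        return solver, callbacks                # a real trainer would run epochs here


class CustomTrainer(BaseTrainer):
    """Overrides both hooks, mirroring the approach described above."""

    def get_solver(self):
        # Train through the loss cell that wraps the model and custom loss.
        return {"network": self.loss_cell}

    def get_callback(self):
        # Add a periodic evaluation step on top of the default callbacks.
        return super().get_callback() + ["evaluate_every_n_epochs"]


solver, callbacks = CustomTrainer("fuxi_model", "loss_cell").train()
print(solver, callbacks)  # {'network': 'loss_cell'} ['evaluate_every_n_epochs']
```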
trainer = FuXiTrainer(config, fuxi_model, loss_cell, logger_obj)
trainer.train()
2023-12-29 08:46:01,280 - pretrain.py[line:215] - INFO: steps_per_epoch: 67
epoch: 1 step: 67, loss is 0.37644267
Train epoch time: 222879.994 ms, per step time: 3326.567 ms
epoch: 2 step: 67, loss is 0.26096737
Train epoch time: 216913.275 ms, per step time: 3237.512 ms
epoch: 3 step: 67, loss is 0.2587443
Train epoch time: 217870.765 ms, per step time: 3251.802 ms
epoch: 4 step: 67, loss is 0.2280185
Train epoch time: 218111.564 ms, per step time: 3255.396 ms
epoch: 5 step: 67, loss is 0.20605856
Train epoch time: 216881.674 ms, per step time: 3237.040 ms
epoch: 6 step: 67, loss is 0.20178188
Train epoch time: 218800.354 ms, per step time: 3265.677 ms
epoch: 7 step: 67, loss is 0.21064804
Train epoch time: 217554.571 ms, per step time: 3247.083 ms
epoch: 8 step: 67, loss is 0.20392722
Train epoch time: 217324.330 ms, per step time: 3243.647 ms
epoch: 9 step: 67, loss is 0.19890495
Train epoch time: 218374.032 ms, per step time: 3259.314 ms
epoch: 10 step: 67, loss is 0.2064792
Train epoch time: 217399.318 ms, per step time: 3244.766 ms
2023-12-29 09:22:23,927 - forecast.py[line:223] - INFO: ================================Start Evaluation================================
2023-12-29 09:24:51,246 - forecast.py[line:241] - INFO: test dataset size: 1
2023-12-29 09:24:51,248 - forecast.py[line:191] - INFO: t = 6 hour:
2023-12-29 09:24:51,250 - forecast.py[line:202] - INFO: RMSE of Z500: 313.254855370194, T2m: 2.911020155335285, T850: 1.6009748653510902, U10: 1.8822629694594444
2023-12-29 09:24:51,251 - forecast.py[line:203] - INFO: ACC of Z500: 0.9950579247892839, T2m: 0.98743929296225, T850: 0.9930489273077082, U10: 0.9441216196638477
2023-12-29 09:24:51,252 - forecast.py[line:191] - INFO: t = 72 hour:
2023-12-29 09:24:51,253 - forecast.py[line:202] - INFO: RMSE of Z500: 1176.8557892319443, T2m: 7.344694139181644, T850: 6.165706260104667, U10: 5.953978905254709
2023-12-29 09:24:51,254 - forecast.py[line:203] - INFO: ACC of Z500: 0.9271318752961824, T2m: 0.9236962494086007, T850: 0.9098796075852417, U10: 0.5003382663349598
2023-12-29 09:24:51,255 - forecast.py[line:191] - INFO: t = 120 hour:
2023-12-29 09:24:51,256 - forecast.py[line:202] - INFO: RMSE of Z500: 1732.662048442734, T2m: 9.891472332990181, T850: 8.233521390723434, U10: 7.434774900830313
2023-12-29 09:24:51,256 - forecast.py[line:203] - INFO: ACC of Z500: 0.8421506711992445, T2m: 0.8468635778030965, T850: 0.8467625693884427, U10: 0.3787509969898105
2023-12-29 09:24:51,257 - forecast.py[line:256] - INFO: ================================End Evaluation================================
......
epoch: 91 step: 67, loss is 0.13158562
Train epoch time: 191498.866 ms, per step time: 2858.192 ms
epoch: 92 step: 67, loss is 0.12776905
Train epoch time: 218376.797 ms, per step time: 3259.355 ms
epoch: 93 step: 67, loss is 0.12682373
Train epoch time: 217263.432 ms, per step time: 3242.738 ms
epoch: 94 step: 67, loss is 0.12594032
Train epoch time: 217970.325 ms, per step time: 3253.288 ms
epoch: 95 step: 67, loss is 0.12149178
Train epoch time: 217401.066 ms, per step time: 3244.792 ms
epoch: 96 step: 67, loss is 0.12223453
Train epoch time: 218616.899 ms, per step time: 3265.344 ms
epoch: 97 step: 67, loss is 0.12046164
Train epoch time: 218616.899 ms, per step time: 3263.949 ms
epoch: 98 step: 67, loss is 0.1172382
Train epoch time: 216666.521 ms, per step time: 3233.829 ms
epoch: 99 step: 67, loss is 0.11799482
Train epoch time: 218090.233 ms, per step time: 3255.078 ms
epoch: 100 step: 67, loss is 0.11112012
Train epoch time: 218108.888 ms, per step time: 3255.357 ms
2023-12-29 10:00:44,043 - forecast.py[line:223] - INFO: ================================Start Evaluation================================
2023-12-29 10:02:59,291 - forecast.py[line:241] - INFO: test dataset size: 1
2023-12-29 10:02:59,293 - forecast.py[line:191] - INFO: t = 6 hour:
2023-12-29 10:02:59,294 - forecast.py[line:202] - INFO: RMSE of Z500: 159.26790471459077, T2m: 1.7593914514223792, T850: 1.2225771108909576, U10: 1.3952338408157166
2023-12-29 10:02:59,295 - forecast.py[line:203] - INFO: ACC of Z500: 0.996888905697735, T2m: 0.9882202464019967, T850: 0.994542681351491, U10: 0.9697411543132562
2023-12-29 10:02:59,297 - forecast.py[line:191] - INFO: t = 72 hour:
2023-12-29 10:02:59,298 - forecast.py[line:202] - INFO: RMSE of Z500: 937.2960233810791, T2m: 5.177728653933931, T850: 4.831667457069809, U10: 5.30111109022694
2023-12-29 10:02:59,299 - forecast.py[line:203] - INFO: ACC of Z500: 0.9542952919181137, T2m: 0.9557775651851869, T850: 0.9371537322317006, U10: 0.5895038993694246
2023-12-29 10:02:59,300 - forecast.py[line:191] - INFO: t = 120 hour:
2023-12-29 10:02:59,301 - forecast.py[line:202] - INFO: RMSE of Z500: 1200.9140481697198, T2m: 6.913749261896835, T850: 6.530332262562704, U10: 6.3855645042672835
2023-12-29 10:02:59,303 - forecast.py[line:203] - INFO: ACC of Z500: 0.9257611031529911, T2m: 0.9197160039098073, T850: 0.8867113860499101, U10: 0.47483364671406136
2023-12-29 10:02:59,304 - forecast.py[line:256] - INFO: ================================End Evaluation================================
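The evaluation log reports RMSE and ACC (anomaly correlation coefficient) per variable and lead time. The sketch below uses the latitude-weighted definitions popularized by WeatherBench, with anomalies taken relative to a climatology (the dataset ships a climate_0.25.npy file for this purpose). This is an assumption about the metric definitions; MindSpore Earth's own implementation may differ in details.

```python
import numpy as np

def lat_weights(h_size):
    """Normalized cos(latitude) weights at cell centres (mean 1), shape (H, 1)."""
    lat = 90.0 - (180.0 / h_size) * (np.arange(h_size) + 0.5)
    w = np.cos(np.deg2rad(lat))
    return (w / w.mean())[:, None]

def weighted_rmse(pred, truth):
    """Latitude-weighted RMSE for (H, W) fields in physical units."""
    w = lat_weights(pred.shape[0])
    return float(np.sqrt(np.mean(w * (pred - truth) ** 2)))

def weighted_acc(pred, truth, climatology):
    """Latitude-weighted anomaly correlation coefficient for (H, W) fields."""
    w = lat_weights(pred.shape[0])
    p_anom = pred - climatology
    t_anom = truth - climatology
    num = np.sum(w * p_anom * t_anom)
    den = np.sqrt(np.sum(w * p_anom ** 2) * np.sum(w * t_anom ** 2))
    return float(num / den)

rng = np.random.default_rng(0)
clim = rng.standard_normal((72, 144))
truth = clim + rng.standard_normal((72, 144))
pred = truth + 0.1 * rng.standard_normal((72, 144))
print(weighted_rmse(pred, truth), weighted_acc(pred, truth, clim))
```

RMSE is in the units of the variable (for example, the Z500 values in the log are geopotential), while ACC is dimensionless, with 1 meaning the predicted anomalies match the observed anomalies perfectly.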
Model Evaluation and Visualization
After training, we load the checkpoint saved at epoch 100 for inference. The predictions, the ground truth, and their errors are visualized below.
params = load_checkpoint('./FuXi_depths_18_in_channels_96_out_channels_192_recompute_True_adam_oop/ckpt/step_1/FuXi-100_67.ckpt')
load_param_into_net(fuxi_model, params)
inference_module = InferenceModule(fuxi_model, config, logger_obj)
data_params = config.get("data")
test_dataset_generator = Era5Data(data_params=data_params, run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=data_params.get('num_workers'), shuffle=False)
test_dataset = test_dataset.create_dataset(data_params.get('batch_size'))
data = next(test_dataset.create_dict_iterator())
inputs = data['inputs']
labels = data['labels']
labels = labels[..., 0, :, :]
labels = labels.transpose(0, 2, 1)
labels = labels.reshape(labels.shape[0], labels.shape[1], data_params.get("h_size"), data_params.get("w_size")).asnumpy()
pred = inference_module.forecast(inputs)
pred = pred[0].transpose(1, 0)
pred = pred.reshape(pred.shape[0], data_params.get("h_size"), data_params.get("w_size")).asnumpy()
pred = np.expand_dims(pred, axis=0)
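The indexing in the previous cell turns flattened grid points back into 2D fields. The toy NumPy sketch below replays the label steps and cross-checks them against explicit indexing. The (batch, lead steps, H*W, variables) layout with row-major flattening is inferred from that indexing, not documented in this tutorial.

```python
import numpy as np

# Toy sizes standing in for (batch, lead steps, H, W, variables).
B, T, H, W, C = 1, 3, 4, 6, 5
rng = np.random.default_rng(0)
# Assumed label layout: (B, T, H*W, C), grid points flattened row by row.
labels = rng.standard_normal((B, T, H * W, C))

# Same steps as the cell above: first lead step -> (B, C, H*W) -> (B, C, H, W).
first = labels[..., 0, :, :]                  # (B, H*W, C)
field = first.transpose(0, 2, 1).reshape(B, C, H, W)

# Cross-check against building the field by explicit indexing.
expected = np.empty((B, C, H, W))
for i in range(H):
    for j in range(W):
        expected[:, :, i, j] = labels[:, 0, i * W + j, :]
print(np.array_equal(field, expected))  # True
```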
def plt_key_info_comparison(pred, label, root_dir):
    """ Visualize the comparison of forecast results """
    # Normalization statistics: upper-air (std, mean) and surface (std_s, mean_s)
    std = np.load(os.path.join(root_dir, 'statistic/std.npy'))
    mean = np.load(os.path.join(root_dir, 'statistic/mean.npy'))
    std_s = np.load(os.path.join(root_dir, 'statistic/std_s.npy'))
    mean_s = np.load(os.path.join(root_dir, 'statistic/mean_s.npy'))
    # 4 x 3 grid: one row per variable; columns are ground truth, prediction, error
    plt.figure(num='e_imshow', figsize=(100, 50))
    plt.subplot(4, 3, 1)
    plt_global_field_data(label, 'Z500', std, mean, 'Ground Truth')  # Z500
    plt.subplot(4, 3, 2)
    plt_global_field_data(pred, 'Z500', std, mean, 'Pred')  # Z500
    plt.subplot(4, 3, 3)
    plt_global_field_data(label - pred, 'Z500', std, mean, 'Error', is_error=True)  # Z500
    plt.subplot(4, 3, 4)
    plt_global_field_data(label, 'T850', std, mean, 'Ground Truth')  # T850
    plt.subplot(4, 3, 5)
    plt_global_field_data(pred, 'T850', std, mean, 'Pred')  # T850
    plt.subplot(4, 3, 6)
    plt_global_field_data(label - pred, 'T850', std, mean, 'Error', is_error=True)  # T850
    plt.subplot(4, 3, 7)
    plt_global_field_data(label, 'U10', std_s, mean_s, 'Ground Truth', is_surface=True)  # U10
    plt.subplot(4, 3, 8)
    plt_global_field_data(pred, 'U10', std_s, mean_s, 'Pred', is_surface=True)  # U10
    plt.subplot(4, 3, 9)
    plt_global_field_data(label - pred, 'U10', std_s, mean_s, 'Error', is_surface=True, is_error=True)  # U10
    plt.subplot(4, 3, 10)
    plt_global_field_data(label, 'T2M', std_s, mean_s, 'Ground Truth', is_surface=True)  # T2M
    plt.subplot(4, 3, 11)
    plt_global_field_data(pred, 'T2M', std_s, mean_s, 'Pred', is_surface=True)  # T2M
    plt.subplot(4, 3, 12)
    plt_global_field_data(label - pred, 'T2M', std_s, mean_s, 'Error', is_surface=True, is_error=True)  # T2M
    plt.savefig('key_info_comparison1.png', bbox_inches='tight')
    plt.show()
plt_key_info_comparison(pred, labels, data_params.get('root_dir'))
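The plotting function receives normalized fields together with the matching std and mean arrays, and it passes label - pred with is_error=True for the error panels. How plt_global_field_data handles that flag is not shown here; one plausible reason for a separate flag is that, when a difference of normalized fields is converted back to physical units, the mean cancels and only the std scaling remains. The statistics below are made-up values used only to check that identity.

```python
import numpy as np

rng = np.random.default_rng(0)
mean = np.array([5.0e4, 270.0])      # made-up per-variable statistics
std = np.array([3.0e3, 15.0])
z_label = rng.standard_normal((2, 4, 4))   # normalized (C, H, W) fields
z_pred = rng.standard_normal((2, 4, 4))

def denormalize(z, mean, std):
    """Invert z = (x - mean) / std per variable (channel axis first)."""
    return z * std[:, None, None] + mean[:, None, None]

# Difference of physical fields vs. scaling the normalized difference by std only.
error_physical = denormalize(z_label, mean, std) - denormalize(z_pred, mean, std)
error_scaled = (z_label - z_pred) * std[:, None, None]
print(np.allclose(error_physical, error_scaled))  # True
```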