mindearth.module.WeatherForecast

class mindearth.module.WeatherForecast(model, config, logger)

Base class for weather forecast model inference. All user-defined forecast models should inherit from this class for inference. The class can be used in a trainer callback, or for standalone inference after a checkpoint has been loaded. When called, it runs inference with the given model through the user-defined forecast member function. The attribute t_out defines the number of forward inference passes the model makes.
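The forecast loop in the Examples section is autoregressive: each of the t_out passes feeds the previous prediction back in as the next input. The framework-free sketch below illustrates only that loop, using plain Python lists. `rollout`, `Field` and `toy_model` are hypothetical names, `toy_model` stands in for a `mindspore.nn.Cell`, and this is not the MindEarth implementation.

```python
from typing import Callable, List

Field = List[float]  # toy stand-in for a weather state tensor


def rollout(model: Callable[[Field], Field], inputs: Field, t_out: int) -> List[Field]:
    """Run t_out forward passes, feeding each prediction back as the next input."""
    preds: List[Field] = []
    state = inputs
    for _ in range(t_out):
        state = model(state)  # one forward inference pass
        preds.append(state)
    return preds


def toy_model(x: Field) -> Field:
    # Hypothetical "network": advances every value by 1.0 per pass
    return [v + 1.0 for v in x]


if __name__ == "__main__":
    print(rollout(toy_model, [0.0, 10.0], t_out=3))
    # [[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
```

Because the output of pass k is the input of pass k + 1, errors can compound over long rollouts, which is why t_out is usually reported alongside the metrics.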

Parameters
  • model (mindspore.nn.Cell) – The network used for inference.

  • config (dict) – Configuration of the model, dataset, training details, etc.

  • logger (logging.RootLogger) – Logger of the training process.

Note

  • The member function forecast must be overridden to define the forward inference process of the model.
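One way to express an "override this method" contract in plain Python is an abstract base class. In the sketch below, `ForecastBase`, `StepForecast`, `MissingForecast` and `is_concrete` are hypothetical names, not MindEarth API, and `abc.abstractmethod` is a swapped-in mechanism: with it, a subclass that omits forecast cannot be instantiated at all. The MindEarth base class (whose forecast is documented as static) may enforce the override differently.

```python
import abc
from typing import Any, Callable, List, Optional


class ForecastBase(abc.ABC):
    """Hypothetical base class; subclasses must override forecast."""

    def __init__(self, model: Callable[[Any], Any], t_out: int) -> None:
        self.model = model
        self.t_out = t_out

    @abc.abstractmethod
    def forecast(self, inputs: Any, labels: Optional[Any] = None) -> List[Any]:
        """Define the forward inference process in a subclass."""


class StepForecast(ForecastBase):
    def forecast(self, inputs: Any, labels: Optional[Any] = None) -> List[Any]:
        preds = []
        for _ in range(self.t_out):
            inputs = self.model(inputs)  # feed each prediction back in
            preds.append(inputs)
        return preds


class MissingForecast(ForecastBase):
    pass  # forgot to override forecast


def is_concrete(cls: type) -> bool:
    # abc records abstract methods that were not overridden in __abstractmethods__
    return not getattr(cls, "__abstractmethods__", frozenset())


if __name__ == "__main__":
    print(StepForecast(lambda x: x * 2, t_out=3).forecast(1))  # [2, 4, 8]
    print(is_concrete(MissingForecast))  # False
```

With this design, `MissingForecast(lambda x: x, t_out=1)` raises `TypeError` at construction time, so a missing override surfaces before any data is loaded.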

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data, Dataset
>>> from mindearth.module import WeatherForecast
...
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super(Net, self).__init__()
...         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
...         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
...     def construct(self, x):
...         x = x.transpose(0, 2, 3, 1)  # move axis 1 (channels) last so Dense acts on it
...         x = self.fc1(x)
...         x = self.fc2(x)
...         x = x.transpose(0, 3, 1, 2)  # restore the original axis order
...         return x
...
>>> class InferenceModule(WeatherForecast):
...     def __init__(self, model, config, logger):
...         super(InferenceModule, self).__init__(model, config, logger)
...
...     def forecast(self, inputs, labels=None):
...         pred_lst = []
...         for _ in range(self.t_out):
...             pred = self.model(inputs)
...             pred_lst.append(pred)
...             inputs = pred  # feed the prediction back as the next input
...         return pred_lst
...
>>> config = {
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
...
>>> model = Net(in_channels=config['data']['feature_dims'], out_channels=config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config, logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers=config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)
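With the example configuration, the length of a test forecast can be estimated from two data settings. This sketch assumes that pred_lead_time is the lead time of one forward pass in hours and that t_out at test time equals t_out_test; neither is confirmed on this page, and `lead_times_hours` and `forecast_horizon_hours` are hypothetical helpers.

```python
from typing import List


def lead_times_hours(pred_lead_time: int, t_out: int) -> List[int]:
    # Assumption: pass k (1-based) predicts pred_lead_time * k hours ahead
    return [pred_lead_time * k for k in range(1, t_out + 1)]


def forecast_horizon_hours(pred_lead_time: int, t_out: int) -> int:
    # The horizon is the lead time of the final pass
    return pred_lead_time * t_out


# Values copied from the example config["data"]
data_cfg = {"pred_lead_time": 6, "t_out_test": 20}
steps = lead_times_hours(data_cfg["pred_lead_time"], data_cfg["t_out_test"])
horizon = forecast_horizon_hours(data_cfg["pred_lead_time"], data_cfg["t_out_test"])
print(steps[:3], horizon, horizon / 24)  # [6, 12, 18] 120 5.0
```

Under these assumptions, 20 passes of 6 hours each cover a 120-hour (5-day) forecast.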

compute_total_rmse_acc(dataset, generator_flag)

Compute the total root mean square error (RMSE) and accuracy (ACC) for the dataset.

This function iterates over the dataset, calculates the RMSE and accuracy for each batch, and accumulates the results to compute the total RMSE and accuracy over the entire dataset.

Parameters
  • dataset (Dataset) – The dataset object to compute metrics for.

  • generator_flag (bool) – A flag indicating whether to use a data generator or not.

Returns

A tuple containing the total accuracy and RMSE for the dataset.

Raises

NotImplementedError – If an unsupported data source is specified.
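The per-batch accumulation described above can be sketched in plain Python. This assumes that ACC denotes the anomaly correlation coefficient commonly used in weather forecasting and that the totals are means of per-batch scores; the actual method may also weight by latitude and report per variable or per lead time. `batch_rmse`, `batch_acc`, `total_acc_rmse` and the climatology argument `clim` are hypothetical names, and the return order (accuracy, RMSE) follows the Returns entry above.

```python
import math
from typing import Iterable, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (prediction, label)


def batch_rmse(pred: Sequence[float], label: Sequence[float]) -> float:
    return math.sqrt(sum((p - y) ** 2 for p, y in zip(pred, label)) / len(pred))


def batch_acc(pred: Sequence[float], label: Sequence[float], clim: Sequence[float]) -> float:
    # Anomaly correlation: correlate forecast and observed departures from climatology
    f_anom = [p - c for p, c in zip(pred, clim)]
    o_anom = [y - c for y, c in zip(label, clim)]
    num = sum(f * o for f, o in zip(f_anom, o_anom))
    den = math.sqrt(sum(f * f for f in f_anom) * sum(o * o for o in o_anom))
    return num / den


def total_acc_rmse(batches: Iterable[Batch], clim: Sequence[float]) -> Tuple[float, float]:
    """Average per-batch ACC and RMSE over the whole dataset; returns (acc, rmse)."""
    acc_sum = rmse_sum = 0.0
    n = 0
    for pred, label in batches:
        acc_sum += batch_acc(pred, label, clim)
        rmse_sum += batch_rmse(pred, label)
        n += 1
    if n == 0:
        raise ValueError("dataset produced no batches")
    return acc_sum / n, rmse_sum / n


if __name__ == "__main__":
    demo = [([1.0, 1.0], [1.0, 1.0]), ([1.0, -1.0], [2.0, -2.0])]
    print(total_acc_rmse(demo, clim=[0.0, 0.0]))  # (1.0, 0.5)
```

Averaging per-batch RMSE is not the same as computing one RMSE over all samples pooled together; which of the two the library reports is not stated on this page.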

eval(dataset, generator_flag=False)

Evaluate the model on the test or validation dataset.

Parameters
  • dataset (mindspore.dataset) – The dataset to evaluate on, containing inputs and labels.

  • generator_flag (bool) – Whether to use a data generator. The value is passed through to the compute_total_rmse_acc method. Default: False.

static forecast(inputs, labels=None)

The forecast function of the model. Subclasses override it to define the forward inference process.

Parameters
  • inputs (Tensor) – The input data of the model.

  • labels (Tensor) – Ground-truth values of the samples. Default: None.