mindearth.module.WeatherForecast

class mindearth.module.WeatherForecast(model, config, logger)

Base class for weather forecast model inference. All user-defined forecast models should inherit from this class for inference. The class can be invoked from a trainer callback or during standalone inference after loading a checkpoint. At inference time, the given model is run through the user-defined forecast member function. t_out_test defines the number of forward inference passes the model makes.
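The rollout described above, where each prediction becomes the next input and t_out_test controls how many forward passes are made, can be sketched without MindSpore. This is a minimal illustration only: the function name autoregressive_rollout and the scalar toy model are hypothetical and are not part of the mindearth API.

```python
from typing import Callable, List, TypeVar

State = TypeVar("State")


def autoregressive_rollout(model: Callable[[State], State], inputs: State,
                           t_out_test: int) -> List[State]:
    """Hypothetical helper: apply `model` t_out_test times, feeding each
    prediction back in as the next input, and collect every prediction."""
    if t_out_test < 1:
        raise ValueError("t_out_test must be at least 1")
    preds: List[State] = []
    state = inputs
    for _ in range(t_out_test):
        state = model(state)  # one forward pass advances the state by one step
        preds.append(state)
    return preds


# Toy "model" (hypothetical): advances a scalar state by 1 per step.
steps = autoregressive_rollout(lambda x: x + 1, 0, t_out_test=3)
print(steps)  # [1, 2, 3]
```

The output list has exactly t_out_test entries, one per forward pass, which matches the list returned by the forecast method in the example further down this page.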

Parameters
  • model (mindspore.nn.Cell) – The network to run inference with.

  • config (dict) – Configurations of the model, dataset, training details, etc.

  • logger (logging.RootLogger) – Logger for the training process.

Note

  • The member function, forecast, must be overridden to define the forward inference process of the model.
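The override requirement follows the template-method pattern: the base class drives inference and delegates the per-sample logic to forecast. The sketch below uses Python's abc module to enforce the rule at instantiation time. ForecastTemplate, Rollout and run are hypothetical names, and whether WeatherForecast itself uses abc (rather than, for example, raising an error when forecast is called) is not stated on this page.

```python
from abc import ABC, abstractmethod


class ForecastTemplate(ABC):
    """Hypothetical base class illustrating the override contract; not MindEarth code."""

    def __init__(self, model, t_out_test):
        self.model = model
        self.t_out_test = t_out_test

    @abstractmethod
    def forecast(self, inputs, labels=None):
        """Subclasses define the forward inference process here."""

    def run(self, inputs):
        # The base class drives inference but delegates the model-specific logic.
        return self.forecast(inputs)


class Rollout(ForecastTemplate):
    def forecast(self, inputs, labels=None):
        preds = []
        for _ in range(self.t_out_test):
            inputs = self.model(inputs)  # feed each prediction back as input
            preds.append(inputs)
        return preds


try:
    ForecastTemplate(model=abs, t_out_test=1)  # forecast not overridden
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)

print(Rollout(lambda x: x * 2, t_out_test=3).run(1))  # [2, 4, 8]
```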

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data, Dataset
>>> from mindearth.module import WeatherForecast
...
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super(Net, self).__init__()
...         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
...         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
...     def construct(self, x):
...         # (batch, channels, h, w) -> (batch, h, w, channels) so Dense acts on the channel axis
...         x = x.transpose(0, 2, 3, 1)
...         x = self.fc1(x)
...         x = self.fc2(x)
...         # restore the channel-first layout
...         x = x.transpose(0, 3, 1, 2)
...         return x
...
>>> class InferenceModule(WeatherForecast):
...     def __init__(self, model, config, logger):
...         super(InferenceModule, self).__init__(model, config, logger)
...
...     def forecast(self, inputs, labels=None):
...         pred_lst = []
...         # roll the model forward t_out_test times, feeding each prediction back as input
...         for _ in range(self.t_out_test):
...             pred = self.model(inputs)
...             pred_lst.append(pred)
...             inputs = pred
...         return pred_lst
...
>>> config={
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
...
>>> model = Net(in_channels=config['data']['feature_dims'], out_channels=config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config, logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers=config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)
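With the example configuration above (pred_lead_time: 6, t_out_test: 20), the forecast horizon can be worked out by hand. The sketch assumes that each forward pass advances the forecast by pred_lead_time hours; that interpretation of pred_lead_time is an assumption, not something this page states. lead_times_hours is a hypothetical helper.

```python
def lead_times_hours(data_cfg: dict) -> list:
    """Hypothetical helper: lead time (hours) reached after each forward pass,
    ASSUMING each pass advances the forecast by data_cfg['pred_lead_time'] hours."""
    step = data_cfg["pred_lead_time"]
    return [step * (i + 1) for i in range(data_cfg["t_out_test"])]


data_cfg = {"pred_lead_time": 6, "t_out_test": 20}  # values from the example config
hours = lead_times_hours(data_cfg)
print(hours[0], hours[-1], hours[-1] / 24)  # 6 120 5.0
```

Under that assumption, 20 passes of 6 hours each cover a 120-hour (5-day) forecast.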
eval(dataset)

Evaluate the model on a test or validation dataset.

Parameters

  • dataset (mindspore.dataset) – The dataset to evaluate on; each sample contains inputs and labels.
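Evaluation pairs the per-step predictions from forecast with the labels for each sample. The framework-free sketch below averages a per-sample RMSE at each lead step. The metric, the flattened-list data layout and the names evaluate and rmse are all assumptions for illustration; this page does not say which metrics eval reports or how it aggregates them.

```python
import math
from typing import Callable, Iterable, List, Sequence, Tuple

Field = Sequence[float]  # hypothetical: one forecast field flattened to a list


def rmse(pred: Field, target: Field) -> float:
    """Root-mean-square error between two equally sized fields."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target))


def evaluate(forecast: Callable[[Field], List[Field]],
             dataset: Iterable[Tuple[Field, List[Field]]]) -> List[float]:
    """Hypothetical eval loop: mean per-sample RMSE at each lead step."""
    totals: List[float] = []
    count = 0
    for inputs, labels in dataset:
        preds = forecast(inputs)  # one prediction per lead step
        if not totals:
            totals = [0.0] * len(preds)
        for step, (pred, label) in enumerate(zip(preds, labels)):
            totals[step] += rmse(pred, label)
        count += 1
    return [t / count for t in totals] if count else []


# Usage with a persistence "model" that repeats its input for two lead steps.
sample = ([1.0, 2.0], [[1.0, 2.0], [3.0, 2.0]])
scores = evaluate(lambda x: [x, x], [sample])
print(scores)  # [0.0, 1.4142135623730951]
```

Averaging per-sample RMSE is one design choice; pooling squared errors across all samples before taking the square root gives a different number, and weather models often weight grid points by latitude as well.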

static forecast(inputs, labels=None)

The forecast function of the model. Subclasses must override it to define the forward inference process.

Parameters
  • inputs (Tensor) – Input data of the model.

  • labels (Tensor) – Ground-truth values of the samples. Default: None.