mindearth.module.WeatherForecast
- class mindearth.module.WeatherForecast(model, config, logger)
Base class for weather forecast model inference. All user-defined forecast models should inherit from this class for inference. The class can be called in a trainer callback, or used on its own for inference after loading a checkpoint. When called, it runs inference with the given model through the user-defined forecast member function. t_out_test sets the number of forward inference passes the model performs.
- Parameters
model (mindspore.nn.Cell) – The network for training.
config (dict) – The configurations of the model, dataset, training details, etc.
logger (logging.RootLogger) – Logger of the training process.
Note
The member function, forecast, must be overridden to define the forward inference process of the model.
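The forward inference process that forecast is expected to define is typically an autoregressive rollout: each prediction is fed back as the next input, t_out_test times. The following is a minimal, framework-free sketch of that loop. The rollout helper and the toy step function are hypothetical names used only for illustration and are not part of mindearth.

```python
from typing import Callable, List


def rollout(step: Callable[[float], float], inputs: float, t_out_test: int) -> List[float]:
    """Apply `step` autoregressively `t_out_test` times and collect every prediction."""
    preds = []
    for _ in range(t_out_test):
        pred = step(inputs)  # one forward inference pass
        preds.append(pred)
        inputs = pred        # the prediction becomes the next input
    return preds


# Toy stand-in for a network (hypothetical): it simply doubles its input.
preds = rollout(lambda x: 2 * x, 1.0, t_out_test=3)
print(preds)  # [2.0, 4.0, 8.0]
```

In a real subclass, step would correspond to self.model and the loop would live inside the overridden forecast method.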
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data, Dataset
>>> from mindearth.module import WeatherForecast
...
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super(Net, self).__init__()
...         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
...         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
...     def construct(self, x):
...         x = x.transpose(0, 2, 3, 1)
...         x = self.fc1(x)
...         x = self.fc2(x)
...         x = x.transpose(0, 3, 1, 2)
...         return x
...
>>> class InferenceModule(WeatherForecast):
...     def __init__(self, model, config, logger):
...         super(InferenceModule, self).__init__(model, config, logger)
...
...     def forecast(self, inputs, labels=None):
...         pred_lst = []
...         for t in range(self.t_out_test):
...             pred = self.model(inputs)
...             pred_lst.append(pred)
...             inputs = pred
...         return pred_lst
...
>>> config = {
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
...
>>> model = Net(in_channels=config['data']['feature_dims'], out_channels=config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config, logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers=config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)
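As a rough illustration of what the example configuration implies, the sketch below lists the forecast lead time reached after each forward pass. It assumes that pred_lead_time is the number of hours covered by one pass and that pass t (counting from 0) corresponds to a lead time of pred_lead_time * (t + 1); this interpretation is an assumption and is not stated in the text above. The variable names are hypothetical.

```python
# Values copied from the "data" section of the example config.
config_data = {"t_out_test": 20, "pred_lead_time": 6}

# Assumption: each forward pass advances the forecast by pred_lead_time hours.
lead_times = [
    config_data["pred_lead_time"] * (t + 1)
    for t in range(config_data["t_out_test"])
]

print(lead_times[0], lead_times[-1], len(lead_times))  # 6 120 20
```

Under this assumption, the 20 passes in the example would cover lead times from 6 to 120 hours.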