mindearth.module.WeatherForecast
- class mindearth.module.WeatherForecast(model, config, logger)[source]
Base class for weather forecast model inference. All user-defined forecast models should inherit from this class for inference. The class can be called in a trainer callback or used for inference directly by loading a checkpoint. When called, it runs inference with the input model through the custom forecast member function. t_out defines the number of forward inference passes performed by the model.
- Parameters
model (mindspore.nn.Cell) – The network for training.
config (dict) – The configurations of model, dataset, train details, etc.
logger (logging.RootLogger) – Logger of the training process.
Note
The member function, forecast, must be overridden to define the forward inference process of the model.
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data, Dataset
>>> from mindearth.module import WeatherForecast
...
>>> class Net(nn.Cell):
>>>     def __init__(self, in_channels, out_channels):
>>>         super(Net, self).__init__()
>>>         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
>>>         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
>>>     def construct(self, x):
>>>         x = x.transpose(0, 2, 3, 1)
>>>         x = self.fc1(x)
>>>         x = self.fc2(x)
>>>         x = x.transpose(0, 3, 1, 2)
>>>         return x
...
>>> class InferenceModule(WeatherForecast):
>>>     def __init__(self, model, config, logger):
>>>         super(InferenceModule, self).__init__(model, config, logger)
...
>>>     def forecast(self, inputs, labels=None):
>>>         pred_lst = []
>>>         for t in range(self.t_out):
>>>             pred = self.model(inputs)
>>>             pred_lst.append(pred)
>>>             inputs = pred
>>>         return pred_lst
...
>>> config={
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
...
>>> model = Net(in_channels=config['data']['feature_dims'], out_channels=config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config, logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers=config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)
- compute_total_rmse_acc(dataset, generator_flag)[source]
Compute the total Root Mean Square Error (RMSE) and Accuracy for the dataset.
This function iterates over the dataset, calculates the RMSE and accuracy for each batch, and accumulates the results to compute the total RMSE and accuracy over the entire dataset.
- Parameters
dataset (mindspore.dataset) – The dataset for evaluation, including inputs and labels.
generator_flag (bool) – A flag indicating whether the dataset is provided by a data generator.
- Returns
A tuple containing the total accuracy and RMSE for the dataset.
- Raises
NotImplementedError – If an unsupported data source is specified.
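A minimal usage sketch (not part of the original documentation), assuming infer_module and test_dataset have been constructed as in the Examples above; unpacking the result into accuracy and RMSE follows the order stated in the Returns description.
>>> # Sketch: reuses infer_module and test_dataset from the class Examples above.
>>> # The method iterates over the whole dataset and accumulates per-batch accuracy and RMSE.
>>> acc, rmse = infer_module.compute_total_rmse_acc(test_dataset, False)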
- eval(dataset, generator_flag=False)[source]
Evaluate the model on the test dataset or validation dataset.
- Parameters
dataset (mindspore.dataset) – The dataset for evaluation, including inputs and labels.
generator_flag (bool) – A flag indicating whether the dataset is provided by a data generator; it is passed through to the compute_total_rmse_acc method. Default: False.
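For reference, a hedged sketch of a typical call, assuming the test dataset was built from the mindspore.dataset pipeline shown in the class Examples above (so generator_flag keeps its default of False):
>>> # Sketch: evaluate the loaded model on the test dataset.
>>> # generator_flag=False assumes the data comes from a mindspore.dataset pipeline
>>> # rather than a raw data generator (see the parameter description above).
>>> infer_module.eval(test_dataset, generator_flag=False)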