mindearth.module.WeatherForecast
- class mindearth.module.WeatherForecast(model, config, logger)
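WeatherForecast is the base class for inference of weather forecasting models. Every user-defined forecast inference module should inherit from WeatherForecast. It can be invoked from a training callback, or for standalone inference after the model parameters have been loaded. When invoked, it runs inference using the custom forecast method defined for the given model. t_out_test is the number of forward inference passes the model performs.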
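The rollout that t_out_test controls can be sketched in plain Python. This is a minimal illustration, not MindEarth code: `rollout` and `toy_model` are hypothetical names, and a list of floats stands in for the input Tensor. It assumes, as the `forecast` override in the example on this page does, that each prediction is fed back as the next input.

```python
from typing import Callable, List, Sequence


def rollout(model: Callable[[List[float]], List[float]],
            inputs: Sequence[float],
            t_out_test: int) -> List[List[float]]:
    """Hypothetical helper: run t_out_test forward passes, feeding each output back in."""
    preds: List[List[float]] = []
    state = list(inputs)
    for _ in range(t_out_test):
        state = model(state)   # one forward inference pass
        preds.append(state)    # keep the prediction for every lead time
    return preds


def toy_model(x: List[float]) -> List[float]:
    # Hypothetical stand-in for a trained network: advances each value by 1.0.
    return [v + 1.0 for v in x]


steps = rollout(toy_model, [0.0, 10.0], t_out_test=3)
print(steps)  # [[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
```

The list returned has exactly t_out_test entries, one per forward pass, which is why errors from early steps propagate into later ones in this kind of autoregressive scheme.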
- Parameters:
model (mindspore.nn.Cell) - The network used for training.
config (dict) - Input configuration, for example model, data, and training parameters.
logger (logging.RootLogger) - The logger used during training.
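A sketch of the nested shape config takes, assuming only what this page shows: top-level sections keyed by name, with t_out_test under "data". `REQUIRED_SECTIONS` and `check_config` are hypothetical; which sections WeatherForecast actually reads is not stated here, so treat the check as illustrative.

```python
from typing import Any, Dict

# Hypothetical minimal config; keys are a subset of those in the example on this page.
config: Dict[str, Dict[str, Any]] = {
    "model": {"name": "Net"},
    "data": {"feature_dims": 69, "t_in": 1, "t_out_test": 20},
    "train": {"distribute": False, "run_mode": "test", "load_ckpt": True},
}

# Assumption: the sections a forecast module needs; not taken from MindEarth itself.
REQUIRED_SECTIONS = ("model", "data", "train")


def check_config(cfg: Dict[str, Dict[str, Any]]) -> int:
    """Hypothetical helper: fail early on missing sections, then return t_out_test."""
    missing = [name for name in REQUIRED_SECTIONS if name not in cfg]
    if missing:
        raise KeyError(f"config is missing sections: {missing}")
    return int(cfg["data"]["t_out_test"])


print(check_config(config))  # 20
```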
Note
The member function forecast must be overridden to define the forward process of model inference.
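The override rule can be modeled with Python's `abc` module. `ForecastBase` is a hypothetical stand-in for WeatherForecast (the real class may enforce the rule differently, for example by raising NotImplementedError only when forecast is called), and `PersistenceForecast` is a toy subclass that repeats the current state for every lead time.

```python
from abc import ABC, abstractmethod


class ForecastBase(ABC):
    """Hypothetical stand-in for WeatherForecast: models only the 'override forecast' rule."""

    def __init__(self, model, t_out_test):
        self.model = model
        self.t_out_test = t_out_test

    @abstractmethod
    def forecast(self, inputs, labels=None):
        """Subclasses must define the forward inference process here."""
        raise NotImplementedError


class PersistenceForecast(ForecastBase):
    """Toy subclass: repeats the input state for every one of the t_out_test lead times."""

    def forecast(self, inputs, labels=None):
        return [list(inputs) for _ in range(self.t_out_test)]


try:
    ForecastBase(model=None, t_out_test=2)  # abstract: forecast is not overridden
except TypeError as err:
    print("cannot instantiate base:", err)

print(PersistenceForecast(model=None, t_out_test=2).forecast([1.0, 2.0]))
# [[1.0, 2.0], [1.0, 2.0]]
```

Making forecast abstract moves the failure to construction time, so a subclass that forgets the override is caught before any data is loaded.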
- Supported Platforms:
Ascend
GPU
Examples:
>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data, Dataset
>>> from mindearth.module import WeatherForecast
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super(Net, self).__init__()
...         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
...         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
...     def construct(self, x):
...         x = x.transpose(0, 2, 3, 1)
...         x = self.fc1(x)
...         x = self.fc2(x)
...         x = x.transpose(0, 3, 1, 2)
...         return x
...
>>> class InferenceModule(WeatherForecast):
...     def __init__(self, model, config, logger):
...         super(InferenceModule, self).__init__(model, config, logger)
...
...     def forecast(self, inputs, labels=None):
...         pred_lst = []
...         for _ in range(self.t_out_test):
...             pred = self.model(inputs)
...             pred_lst.append(pred)
...             inputs = pred
...         return pred_lst
...
>>> config = {
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
>>> model = Net(in_channels=config['data']['feature_dims'], out_channels=config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config, logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers=config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)