mindearth.module.WeatherForecast

查看源文件
class mindearth.module.WeatherForecast(model, config, logger)[源代码]

WeatherForecast类是气象预测模型推理的基类。 所有用户自定义的预测模型推理都应该继承WeatherForecast类。 WeatherForecast类可以在训练回调或推理通过加载模型参数后被调用。 通过调用WeatherForecast类,模型可以根据输入模型的自定义预测方法执行推理。t_out_test表示模型前向推理的次数。

参数:
  • model (mindspore.nn.Cell) - 用于训练的网络。

  • config (dict) - 输入参数。例如,模型参数、数据参数、训练参数。

  • logger (logging.RootLogger) - 训练过程中的日志模块。

说明

需要重写其中的成员函数 forecast 用于定义模型推理的前向过程。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import logging
>>> from mindspore import Tensor, nn
>>> import mindspore
>>> from mindearth.data import Era5Data,Dataset
>>> from mindearth.module import WeatherForecast
...
>>> class Net(nn.Cell):
>>>     def __init__(self, in_channels, out_channels):
>>>         super(Net, self).__init__()
>>>         self.fc1 = nn.Dense(in_channels, 128, weight_init='ones')
>>>         self.fc2 = nn.Dense(128, out_channels, weight_init='ones')
...
>>>     def construct(self, x):
>>>         x = x.transpose(0, 2, 3, 1)
>>>         x = self.fc1(x)
>>>         x = self.fc2(x)
>>>         x = x.transpose(0, 3, 1, 2)
>>>         return x
...
>>> class InferenceModule(WeatherForecast):
>>>     def __init__(self, model, config, logger):
>>>         super(InferenceModule, self).__init__(model, config, logger)
...
>>>     def forecast(self, inputs, labels=None):
>>>         pred_lst = []
>>>         for t in range(self.t_out):
>>>             pred = self.model(inputs)
>>>             pred_lst.append(pred)
>>>             inputs = pred
>>>         return pred_lst
...
>>> config={
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'train_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '/path/to/ckpt'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': True
...     }
... }
...
>>> model = Net(in_channels = config['data']['feature_dims'], out_channels = config['data']['feature_dims'])
>>> infer_module = InferenceModule(model, config,logging.getLogger())
>>> test_dataset_generator = Era5Data(data_params=config['data'], run_mode='test')
>>> test_dataset = Dataset(test_dataset_generator, distribute=config['train']['distribute'],
...                        num_workers = config['data']['num_workers'], shuffle=False)
>>> test_dataset = test_dataset.create_dataset(1)
>>> infer_module.eval(test_dataset)
eval(dataset)[源代码]

根据验证集数据或测试集数据执行模型推理。

参数:
  • dataset (mindspore.dataset) - 模型推理数据集,包括输入值和样本值。

  • generator_flag (bool, 可选) - 用于向 "compute_total_rmse_acc" 方法传递一个参数。指示是否使用数据生成器。

static forecast(inputs, labels=None)[源代码]

模型的预测方法。

参数:
  • inputs (Tensor) - 模型的输入数据。

  • labels (Tensor) - 样本真实数据。默认值: None

compute_total_rmse_acc(dataset, generator_flag)[源代码]

计算数据集的总体均方根误差(RMSE)和准确率。

该函数遍历数据集,为每个批次计算RMSE和准确率, 并累加结果以计算整个数据集的总体RMSE和准确率。

参数:
  • dataset (Dataset) - 用于计算指标的数据集对象。

  • generator_flag (bool) - 一个标志,指示是否使用数据生成器。

返回:
  • 包含数据集的总体准确率和RMSE的元组。

异常:
  • NotImplementedError - 如果指定了不支持的数据源。