mindearth.module.Trainer

查看源文件
class mindearth.module.Trainer(config, model, loss_fn, logger=None, weather_data_source='ERA5', loss_scale=DynamicLossScaleManager())[源代码]

Trainer类是气象预测模型训练的基类。 所有用户自定义的预测模型训练都应该继承Trainer类。 Trainer类根据模型输入、损失函数和相关参数生成了datasets, optimizer, callbacks, 和solver模块。例如,如果需要训练自定义模型时,可以重写get_dataset(), get_optimizer()或其他方法来满足自定义需求,或者直接实例化Trainer类。 然后可以使用Trainer.train()方法开始训练模型。

参数:
  • config (dict) - 输入参数。例如,模型参数、数据参数、训练参数。

  • model (mindspore.nn.Cell) - 用于训练的网络。

  • loss_fn (mindspore.nn.Cell) - 损失函数。

  • logger (logging.RootLogger, 可选) - 训练过程中的日志模块。默认值: None

  • weatherdata_type (str, 可选) - 数据的类型。默认值: Era5Data

  • loss_scale (mindspore.amp.LossScaleManager, 可选) - 使用混合精度时,用于管理损失缩放系数的类。默认值: mindspore.amp.DynamicLossScaleManager()

异常:
  • TypeError - 如果 modelloss_fn 不是mindspore.nn.Cell。

  • NotImplementedError - 如果 get_callback 的方法没有实现。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> from mindearth.module import Trainer
>>> from mindearth.core import RelativeRMSELoss
...
>>> class Net(nn.Cell):
>>>     def __init__(self, input_dim, output_dim):
>>>         super(Net, self).__init__()
>>>         self.fc1 = nn.Dense(input_dim, 128, weight_init='ones')
>>>         self.fc2 = nn.Dense(128, output_dim, weight_init='ones')
...
>>>     def construct(self, x):
>>>         x = x.transpose(0, 2, 3, 1)
>>>         x = self.fc1(x)
>>>         x = self.fc2(x)
>>>         x = x.transpose(0, 3, 1, 2)
>>>         return x
...
>>> loss_fn = RelativeRMSELoss()
>>> config={
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'train_interval': 1,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '.'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': False
...     }
... }
...
>>> model = Net(input_dim = config['data']['feature_dims'], output_dim = config['data']['feature_dims'])
>>> trainer = Trainer(config, model, loss_fn)
>>> trainer.train()
get_callback()[源代码]

用于定义模型的回调类。用户必须自定义重写该方法。

get_checkpoint()[源代码]

获得模型的checkpoint实例。

返回:

Callback,模型的checkpoint实例.

get_dataset()[源代码]

获得训练数据集和验证数据集。

返回:

Dataset,训练数据集。 Dataset,验证数据集。

get_optimizer()[源代码]

获得模型训练的优化器。

返回:

Optimizer,模型的优化器。

get_solver()[源代码]

获得模型训练的求解器。

返回:

Model,模型的求解器。

train()[源代码]

执行模型训练。

get_data_generator()[源代码]

生成用于训练和验证数据集的数据生成器。

该函数根据指定的天气数据源创建数据生成器。 支持 'ERA5' 和 'DemSR' 数据源,对于不支持的数据源将引发错误。

返回:
  • 包含训练和验证数据生成器的元组。

异常:
  • NotImplementedError - 如果指定了不支持的数据源。