mindearth.module.Trainer
- class mindearth.module.Trainer(config, model, loss_fn, logger=None, weather_data_source='ERA5', loss_scale=DynamicLossScaleManager())[源代码]
Trainer类是气象预测模型训练的基类。 所有用户自定义的预测模型训练都应该继承Trainer类。 Trainer类根据模型输入、损失函数和相关参数生成了datasets, optimizer, callbacks, 和solver模块。例如,如果需要训练自定义模型时,可以重写get_dataset(), get_optimizer()或其他方法来满足自定义需求,或者直接实例化Trainer类。 然后可以使用Trainer.train()方法开始训练模型。
- 参数:
config (dict) - 输入参数。例如,模型参数、数据参数、训练参数。
model (mindspore.nn.Cell) - 用于训练的网络。
loss_fn (mindspore.nn.Cell) - 损失函数。
logger (logging.RootLogger, 可选) - 训练过程中的日志模块。默认值:
None
。weatherdata_type (str, 可选) - 数据的类型。默认值:
Era5Data
。loss_scale (mindspore.amp.LossScaleManager, 可选) - 使用混合精度时,用于管理损失缩放系数的类。默认值:
mindspore.amp.DynamicLossScaleManager()
。
- 异常:
TypeError - 如果 model 或 loss_fn 不是mindspore.nn.Cell。
NotImplementedError - 如果 get_callback 的方法没有实现。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore >>> from mindspore import Tensor, nn >>> from mindearth.module import Trainer >>> from mindearth.core import RelativeRMSELoss ... >>> class Net(nn.Cell): >>> def __init__(self, input_dim, output_dim): >>> super(Net, self).__init__() >>> self.fc1 = nn.Dense(input_dim, 128, weight_init='ones') >>> self.fc2 = nn.Dense(128, output_dim, weight_init='ones') ... >>> def construct(self, x): >>> x = x.transpose(0, 2, 3, 1) >>> x = self.fc1(x) >>> x = self.fc2(x) >>> x = x.transpose(0, 3, 1, 2) >>> return x ... >>> loss_fn = RelativeRMSELoss() >>> config={ ... "model": { ... 'name': 'Net' ... }, ... "data": { ... 'name': 'era5', ... 'root_dir': './dataset', ... 'feature_dims': 69, ... 't_in': 1, ... 't_out_train': 1, ... 't_out_valid': 20, ... 't_out_test': 20, ... 'train_interval': 1, ... 'valid_interval': 1, ... 'test_interval': 1, ... 'pred_lead_time': 6, ... 'data_frequency': 6, ... 'train_period': [2015, 2015], ... 'valid_period': [2016, 2016], ... 'test_period': [2017, 2017], ... 'patch': True, ... 'patch_size': 8, ... 'batch_size': 8, ... 'num_workers': 1, ... 'grid_resolution': 1.4, ... 'h_size': 128, ... 'w_size': 256 ... }, ... "optimizer": { ... 'name': 'adam', ... 'weight_decay': 0.0, ... 'epochs': 200, ... 'finetune_epochs': 1, ... 'warmup_epochs': 1, ... 'initial_lr': 0.0005 ... }, ... "summary": { ... 'save_checkpoint_steps': 1, ... 'keep_checkpoint_max': 10, ... 'valid_frequency': 10, ... 'summary_dir': '/path/to/summary', ... 'ckpt_path': '.' ... }, ... "train": { ... 'name': 'oop', ... 'distribute': False, ... 'device_id': 2, ... 'amp_level': 'O2', ... 'run_mode': 'test', ... 'load_ckpt': False ... } ... } ... >>> model = Net(input_dim = config['data']['feature_dims'], output_dim = config['data']['feature_dims']) >>> trainer = Trainer(config, model, loss_fn) >>> trainer.train()