mindearth.module.Trainer
- class mindearth.module.Trainer(config, model, loss_fn, logger=None, weather_data_source='ERA5', loss_scale=DynamicLossScaleManager())[source]
Base class for weather forecast model training. All user-defined forecast models should inherit from this class during training. The class builds the datasets, optimizer, callbacks, and solver components from the input model, loss function, and related configurations. To customize training, override member functions such as get_dataset() or get_optimizer(), or instantiate the class directly; then call Trainer.train() to start model training (see Examples below).
- Parameters
config (dict) – configurations of model, dataset, train details, etc.
model (mindspore.nn.Cell) – network for training.
loss_fn (mindspore.nn.Cell) – loss function.
logger (logging.RootLogger, optional) – logger of the training process. Default: None.
weather_data_source (str, optional) – the source of the weather data. Default: 'ERA5'.
loss_scale (mindspore.amp.LossScaleManager, optional) – the class of loss scale manager when using mixed precision. Default: mindspore.amp.DynamicLossScaleManager().
- Raises
TypeError – If model or loss_fn is not an instance of mindspore.nn.Cell.
NotImplementedError – If the member function get_callback is not implemented.
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> from mindearth.module import Trainer
>>> from mindearth.core import RelativeRMSELoss
...
>>> class Net(nn.Cell):
>>>     def __init__(self, input_dim, output_dim):
>>>         super(Net, self).__init__()
>>>         self.fc1 = nn.Dense(input_dim, 128, weight_init='ones')
>>>         self.fc2 = nn.Dense(128, output_dim, weight_init='ones')
...
>>>     def construct(self, x):
>>>         x = x.transpose(0, 2, 3, 1)
>>>         x = self.fc1(x)
>>>         x = self.fc2(x)
>>>         x = x.transpose(0, 3, 1, 2)
>>>         return x
...
>>> loss_fn = RelativeRMSELoss()
>>> config = {
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'train_interval': 1,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '.'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': False
...     }
... }
...
>>> model = Net(input_dim=config['data']['feature_dims'], output_dim=config['data']['feature_dims'])
>>> trainer = Trainer(config, model, loss_fn)
>>> trainer.train()
- get_callback()[source]
Builds a Callback instance for training. Override this member function in a subclass to perform custom operations; the base class does not implement it (see Raises above).
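A minimal sketch of providing this member function in a subclass, assuming a hypothetical mindspore.train.Callback subclass MyEvalCallback (not part of mindearth):

>>> from mindspore.train import Callback
>>> class MyEvalCallback(Callback):
>>>     def on_train_epoch_end(self, run_context):
>>>         # Hypothetical hook: run validation or custom logging at epoch end.
>>>         pass
...
>>> class MyTrainer(Trainer):
>>>     def get_callback(self):
>>>         # Return the custom callback used during training.
>>>         return MyEvalCallback()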
- get_checkpoint()[source]
Get the checkpoint callback of the model.
- Returns
Callback, the checkpoint callback of the model.
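For example, with the trainer constructed in the Examples above, the checkpoint callback can be fetched directly (a usage sketch; the exact callback type depends on the summary configuration):

>>> ckpt_cb = trainer.get_checkpoint()  # checkpoint callback built from config['summary']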
- get_dataset()[source]
Get the training and validation datasets.
- Returns
Dataset, the training dataset.
Dataset, the validation dataset.
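A minimal sketch of overriding this member function in a subclass, assuming hypothetical helpers build_train_dataset and build_valid_dataset that return mindspore.dataset objects, and assuming the constructor stores the configuration as self.config:

>>> class MyTrainer(Trainer):
>>>     def get_dataset(self):
>>>         # build_train_dataset/build_valid_dataset are hypothetical; replace
>>>         # them with your own dataset construction logic.
>>>         train_dataset = build_train_dataset(self.config)
>>>         valid_dataset = build_valid_dataset(self.config)
>>>         return train_dataset, valid_dataset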