mindearth.module.Trainer

class mindearth.module.Trainer(config, model, loss_fn, logger=None, weather_data_source='ERA5', loss_scale=DynamicLossScaleManager())

Base class for training weather forecast models. All user-defined forecast models should inherit from this class for training. The class builds the datasets, optimizer, callbacks, and solver from the input model, loss function, and related configuration. To customize training, override get_dataset(), get_optimizer(), or other member functions to suit your needs; otherwise, instantiate the class directly. Then call Trainer.train() to start model training.
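
The override mechanism described above can be illustrated with a minimal pure-Python sketch. SketchTrainer and MyTrainer are hypothetical stand-ins, not the mindearth implementation: the string placeholders, the order in which components are built, and the return value of train() are all assumptions made so that the sketch runs without MindSpore.

```python
class SketchTrainer:
    """Stand-in illustrating the pattern; NOT the mindearth Trainer."""

    def __init__(self, config, model, loss_fn):
        self.config = config
        self.model = model
        self.loss_fn = loss_fn
        # Each component comes from an overridable member function
        # (the build order here is an assumption).
        self.train_dataset, self.valid_dataset = self.get_dataset()
        self.optimizer = self.get_optimizer()
        self.callback = self.get_callback()
        self.solver = self.get_solver()

    def get_dataset(self):
        return "train_dataset", "valid_dataset"

    def get_optimizer(self):
        return "default_optimizer"

    def get_callback(self):
        return "default_callback"

    def get_solver(self):
        return "solver"

    def train(self):
        # Placeholder: report which components would be used.
        return {"optimizer": self.optimizer, "callback": self.callback}


class MyTrainer(SketchTrainer):
    def get_optimizer(self):
        # Replace only the optimizer; the other components keep their defaults.
        return "sgd(lr={})".format(self.config["optimizer"]["initial_lr"])


trainer = MyTrainer({"optimizer": {"initial_lr": 0.0005}}, model=None, loss_fn=None)
result = trainer.train()
```

Because the constructor asks each get_* method for its component, a subclass changes one piece of the pipeline by overriding a single method.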

Parameters
  • config (dict) – configuration of the model, dataset, training details, etc.

  • model (mindspore.nn.Cell) – network for training.

  • loss_fn (mindspore.nn.Cell) – loss function.

  • logger (logging.RootLogger, optional) – logger of the training process. Default: None.

  • weather_data_source (str, optional) – the weather data source. 'ERA5' and 'DemSR' are supported. Default: 'ERA5'.

  • loss_scale (mindspore.amp.LossScaleManager, optional) – the loss scale manager to use with mixed precision. Default: mindspore.amp.DynamicLossScaleManager().

Raises
  • TypeError – If model or loss_fn is not a mindspore.nn.Cell.

  • NotImplementedError – If the member function get_callback is not implemented.
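
The two error conditions listed above might be triggered as in the following sketch. Cell is a stand-in for mindspore.nn.Cell, and CheckedTrainer and error_name are hypothetical names; the real class may perform its checks differently or at a different point.

```python
class Cell:
    """Stand-in for mindspore.nn.Cell so the sketch runs without MindSpore."""


class CheckedTrainer:
    """Hypothetical class that mirrors the documented Raises entries."""

    def __init__(self, model, loss_fn):
        for name, value in (("model", model), ("loss_fn", loss_fn)):
            if not isinstance(value, Cell):
                raise TypeError(
                    "{} must be a Cell, got {}".format(name, type(value).__name__)
                )
        self.model = model
        self.loss_fn = loss_fn

    def get_callback(self):
        # Subclasses are expected to override this method.
        raise NotImplementedError("get_callback must be implemented")


def error_name(action):
    """Run action and return the name of the exception it raises, or None."""
    try:
        action()
    except (TypeError, NotImplementedError) as err:
        return type(err).__name__
    return None


bad_model = error_name(lambda: CheckedTrainer("not a cell", Cell()))
no_callback = error_name(lambda: CheckedTrainer(Cell(), Cell()).get_callback())
```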

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> from mindearth.module import Trainer
>>> from mindearth.core import RelativeRMSELoss
...
>>> class Net(nn.Cell):
...     def __init__(self, input_dim, output_dim):
...         super(Net, self).__init__()
...         self.fc1 = nn.Dense(input_dim, 128, weight_init='ones')
...         self.fc2 = nn.Dense(128, output_dim, weight_init='ones')
...
...     def construct(self, x):
...         x = x.transpose(0, 2, 3, 1)
...         x = self.fc1(x)
...         x = self.fc2(x)
...         x = x.transpose(0, 3, 1, 2)
...         return x
...
>>> loss_fn = RelativeRMSELoss()
>>> config = {
...     "model": {
...         'name': 'Net'
...     },
...     "data": {
...         'name': 'era5',
...         'root_dir': './dataset',
...         'feature_dims': 69,
...         't_in': 1,
...         't_out_train': 1,
...         't_out_valid': 20,
...         't_out_test': 20,
...         'train_interval': 1,
...         'valid_interval': 1,
...         'test_interval': 1,
...         'pred_lead_time': 6,
...         'data_frequency': 6,
...         'train_period': [2015, 2015],
...         'valid_period': [2016, 2016],
...         'test_period': [2017, 2017],
...         'patch': True,
...         'patch_size': 8,
...         'batch_size': 8,
...         'num_workers': 1,
...         'grid_resolution': 1.4,
...         'h_size': 128,
...         'w_size': 256
...     },
...     "optimizer": {
...         'name': 'adam',
...         'weight_decay': 0.0,
...         'epochs': 200,
...         'finetune_epochs': 1,
...         'warmup_epochs': 1,
...         'initial_lr': 0.0005
...     },
...     "summary": {
...         'save_checkpoint_steps': 1,
...         'keep_checkpoint_max': 10,
...         'valid_frequency': 10,
...         'summary_dir': '/path/to/summary',
...         'ckpt_path': '.'
...     },
...     "train": {
...         'name': 'oop',
...         'distribute': False,
...         'device_id': 2,
...         'amp_level': 'O2',
...         'run_mode': 'test',
...         'load_ckpt': False
...     }
... }
...
>>> model = Net(input_dim=config['data']['feature_dims'], output_dim=config['data']['feature_dims'])
>>> trainer = Trainer(config, model, loss_fn)
>>> trainer.train()
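
The configuration in the example has five top-level sections: model, data, optimizer, summary, and train. A small pre-flight check such as the one below can catch a missing section before training starts. missing_sections is a hypothetical helper, not part of mindearth, and it is an assumption that Trainer needs all five sections in every use.

```python
# Top-level sections used by the example configuration above.
REQUIRED_SECTIONS = ("model", "data", "optimizer", "summary", "train")


def missing_sections(config):
    """Return the expected top-level sections absent from config, in order."""
    return [name for name in REQUIRED_SECTIONS if name not in config]


partial = {"model": {"name": "Net"}, "data": {"name": "era5"}}
print(missing_sections(partial))  # prints ['optimizer', 'summary', 'train']
```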
get_callback()

Build the callback used during training. Override this method to add custom operations; a NotImplementedError is raised if it is not implemented.

get_checkpoint()

Get the checkpoint callback of the model.

Returns

Callback, The checkpoint callback of the model.

get_data_generator()

Generate data generators for training and validation datasets.

The function creates data generators based on the specified weather data source. It supports 'ERA5' and 'DemSR' data sources, and will raise an error for unsupported sources.

Returns

A tuple containing the training and validation data generators.

Raises

NotImplementedError – If an unsupported data source is specified.
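
The source-based selection described above can be sketched as a lookup that raises NotImplementedError for unknown names. The factory functions and their string return values are hypothetical placeholders, and the standalone function signature is an assumption; the real method takes no arguments and builds actual data generators.

```python
def make_era5_generators(config):
    # Placeholder for building ERA5 training and validation generators.
    return ("era5_train", "era5_valid")


def make_demsr_generators(config):
    # Placeholder for building DemSR training and validation generators.
    return ("demsr_train", "demsr_valid")


GENERATOR_FACTORIES = {
    "ERA5": make_era5_generators,
    "DemSR": make_demsr_generators,
}


def get_data_generator(weather_data_source, config):
    """Return (train_generator, valid_generator) for a supported source."""
    try:
        factory = GENERATOR_FACTORIES[weather_data_source]
    except KeyError:
        raise NotImplementedError(
            "unsupported data source: {}".format(weather_data_source)
        ) from None
    return factory(config)


train_gen, valid_gen = get_data_generator("ERA5", config={})
```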

get_dataset()

Get the training and validation datasets.

Returns

Dataset, the training dataset.

Dataset, the validation dataset.

get_optimizer()

Get the training optimizer.

Returns

Optimizer, the optimizer of the model.

get_solver()

Get the model solver for training.

Returns

Model, the model solver for training.

train()

Execute model training.