# Source code for mindspore_rl.agent.trainer

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Implementation of trainer base class.
"""
import os

from mindspore_rl.utils.callback import CallbackParam, CallbackManager, TimeCallback
import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net


# Integer identifiers for the init / collect / eval phases. They are not
# referenced anywhere in this module -- presumably consumed by callers
# elsewhere in the package (TODO confirm).
INIT, COLLECT, EVAL = 1, 2, 3


class Trainer(nn.Cell):
    r"""
    The trainer base class. It is a process class that provides the basic mode of training.

    Subclasses are expected to override :meth:`train_one_episode`,
    :meth:`evaluate` and :meth:`trainable_variables`; the base implementations
    raise ``NotImplementedError``.

    Note:
        Reference to `dqn_trainer.py
        <https://gitee.com/mindspore/docs/blob/master/docs/reinforcement/docs/source_en/dqn.md
        #defining-the-dqntrainer-class>`_.

    Args:
        msrl(MSRL): the function handler class.
    """

    def __init__(self, msrl):
        super(Trainer, self).__init__()
        self.msrl = msrl
        # Filled by `_init_or_restore` with the result of `trainable_variables()`
        # once a checkpoint path is supplied; stays empty otherwise.
        self.vars = {}

    def train(self, episodes, callbacks=None, ckpt_path=None):
        """
        The train method provides a standard training process,
        including the whole loop and callbacks. Users can inherit or overwrite as needed.

        Args:
            episodes(int): the number of training episodes.
            callbacks(Optional[list[Callback]]): List of callback objects. Default: ``None``.
                The list passed in is not modified; a reordered copy is used internally.
            ckpt_path(Optional[str]): The checkpoint file path to init or restore net. Default: ``None``.

        Raises:
            RuntimeError: If `train_one_episode` returns a sequence whose length is not 3 or 4.
        """
        cb_params = CallbackParam()
        cb_params.episodes_num = episodes

        # Move TimeCallback to the first to exclude the time of other callbacks.
        # Build a new list instead of mutating the caller's list while iterating
        # it, and leave `None` (the default) or any non-list value untouched so
        # that `CallbackManager` decides how to handle it.
        if isinstance(callbacks, list):
            timers = [cb for cb in callbacks if isinstance(cb, TimeCallback)]
            non_timers = [cb for cb in callbacks if not isinstance(cb, TimeCallback)]
            callbacks = timers + non_timers

        # 1 Using `CallbackManager` to traverse each callback.
        with CallbackManager(callbacks) as callback_list:

            # 2 Init or restore the variables if the checkpoint files exist.
            self._init_or_restore(ckpt_path)
            cb_params.cur_episode = 0
            if self.vars:
                cb_params.vars = self.vars

            callback_list.begin(cb_params)

            # 3 Get `evaluate` function if meet the conditions.
            # `eval_rate` is not set in this method, so it can only have been
            # added to `cb_params` by a callback during `begin`.
            if 'eval_rate' in cb_params and cb_params.eval_rate > 0:
                cb_params.evaluate = self.evaluate

            for i in range(episodes):
                callback_list.episode_begin(cb_params)

                # 4 Get the result of `train_one_episode` func, and deal with three situation:
                # a) Default using: Three objects in tuple, each stand for `loss`, `rewards` and `steps`.
                # b) User defined: Four objects in tuple, the first three is same as default using, the last
                #    one `others` can be tuple or single one as user defined.
                # c) Other situation: Runtime error.
                ans = self.train_one_episode()
                others = []
                if len(ans) == 3:
                    loss, rewards, steps = ans
                elif len(ans) == 4:
                    loss, rewards, steps, others = ans
                else:
                    raise RuntimeError(
                        "The output number of function `train_one_episode` must be 3 or 4, "
                        "and represent for `loss, rewards, steps, [optional]others.` in order")

                cb_params.loss = loss
                cb_params.total_rewards = rewards
                cb_params.steps = steps
                cb_params.others = others
                callback_list.episode_end(cb_params)
                cb_params.cur_episode = i + 1

            callback_list.end(cb_params)

    def train_one_episode(self):
        """
        The interface of train one episode function in train.

        The output of this function must be `loss, rewards, steps, [optional]others`, in that order.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "Method train_one_episode should be overridden by subclass. "
            "and the output must be constricted as `loss, rewards, steps, [optional]others` in order")

    def evaluate(self):
        """
        The interface of the evaluate function for evaluate in train.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Method evaluate should be overridden by subclass.")

    def _load_ckpt(self, ckpt_path=None, name=None, net=None):
        """
        Load a checkpoint into `net`.

        Args:
            ckpt_path(str): Either a checkpoint file, or a directory that contains
                a sub-directory called `name` holding ``.ckpt`` files.
            name(str): Key of the variable; used as the sub-directory name and as a
                substring filter on the checkpoint file names.
            net: The object passed to `load_param_into_net` as the target network.

        Raises:
            RuntimeError: If `load_param_into_net` reports parameters that were not loaded.
        """
        # 1 Deal with the input file or input path.
        if os.path.isfile(ckpt_path):
            ckpt_file = ckpt_path
        else:
            dir_path = os.path.join(ckpt_path, name)
            # 2 If ckpt file does not exist, just return. Otherwise, get the newest one.
            if not os.path.isdir(dir_path):
                return
            ckpt_files = [
                os.path.join(dir_path, filename)
                for filename in os.listdir(dir_path)
                if os.path.splitext(filename)[-1] == '.ckpt' and name in filename
            ]
            # The directory may exist without any matching checkpoint in it;
            # treat that the same as "nothing to restore" instead of indexing
            # into an empty list.
            if not ckpt_files:
                return
            ckpt_file = sorted(ckpt_files, key=os.path.getmtime)[-1]

        # 3 Load the checkpoint.
        if os.path.exists(ckpt_file):
            print("Load file ", ckpt_file)
            param_dict = load_checkpoint(ckpt_file)
            _, not_load = load_param_into_net(net, param_dict)
            if not_load:
                raise RuntimeError("Load params into net failed!")
        else:
            print("Warning: missing ckpt file for ", name)

    def _init_or_restore(self, ckpt_path=None):
        """
        Init or restore the variables.

        When `ckpt_path` is truthy, `self.vars` is refreshed from
        `trainable_variables()` and a checkpoint is loaded for every entry.
        Otherwise nothing happens.
        """
        if ckpt_path:
            self.vars = self.trainable_variables()
            for key, val in self.vars.items():
                self._load_ckpt(ckpt_path, key, val)

    def trainable_variables(self):
        """
        The variables for saving to checkpoint.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Method trainable_variables should be overridden by subclass.")

    def load_and_eval(self, ckpt_path=None):
        """
        The interface of the eval function for offline. A checkpoint must be provided.

        Args:
            ckpt_path (string): The checkpoint file to restore net. Default: ``None``.

        Raises:
            RuntimeError: If `ckpt_path` is ``None``.
        """
        if ckpt_path is None:
            raise RuntimeError("Please provide a ckpt_path.")
        self._init_or_restore(ckpt_path)
        reward = self.evaluate()
        # `evaluate` is expected to return an object exposing `asnumpy()`
        # whose result can be formatted with `:.3f`.
        reward = reward.asnumpy()
        print("-----------------------------------------")
        print(f"Evaluate result is {reward:.3f}, checkpoint file in {ckpt_path}")
        print("-----------------------------------------")