# mindspore_rl.utils.utils source code

# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Utils.
"""
import os
from importlib import import_module
import yaml
import numpy as np
from typing import Tuple, Union


def _update_dict(dest, src) -> None:
    """Recursively merge ``src`` into ``dest`` in place.

    For every key in ``src``:

    * if ``dest[key]`` and ``src[key]`` are both dicts, they are merged
      recursively, so nested entries of ``dest`` that ``src`` does not mention
      are kept;
    * if ``dest[key]`` is a dict but ``src[key]`` is ``None``, nothing changes
      (``key:`` with no value in YAML loads as ``None``; treat it as
      "no overrides", like a top-level ``None``);
    * if ``dest[key]`` is a dict but ``src[key]`` is some other non-dict value,
      ``dest[key]`` is updated with it via ``dict.update``;
    * otherwise ``dest[key]`` is replaced by ``src[key]`` (the object itself is
      stored, not a copy).

    Args:
        dest (dict): the dict to update; mutated in place.
        src (dict | None): the overrides to apply; ``None`` is a no-op.
    """
    if src is None:
        return
    for key in src:
        value = src[key]
        current = dest.get(key)
        if isinstance(current, dict) and (value is None or isinstance(value, dict)):
            # Recurse into the matching sub-dict pair. The former implementation
            # recursed on the *parent* pair once per dict-valued child, which
            # re-walked the whole level and blew up exponentially with depth.
            _update_dict(current, value)
        elif isinstance(current, dict):
            current.update(value)
        else:
            dest[key] = value


def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    ``safe_load`` is used so the file cannot construct arbitrary Python objects.
    An empty file (which ``safe_load`` returns as ``None``) yields ``{}`` so
    callers can use ``.get`` unconditionally.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def _import_from(module_name, attr_name):
    """Import ``module_name`` and return its attribute ``attr_name``.

    Raises:
        ValueError: if the module cannot be imported or has no such attribute;
            the original exception is chained as ``__cause__``.
    """
    try:
        module = import_module(module_name)
        return getattr(module, attr_name)
    except Exception as err:  # importing runs arbitrary module code, so any Exception is possible
        raise ValueError(f"Import {module_name} failed") from err


def update_config(config, env_yaml, algo_yaml) -> None:
    r'''
    Update the config by the provided yamls.

    Eg: see `mindspore_rl/algorithm/dqn/config.py`, `mindspore_rl/example/env_yaml/`
    and `mindspore_rl/example/algo_yaml/` for usage.

    The environment yaml may provide:

    * ``env``: written to ``name`` of both ``collect_env_params`` and
      ``eval_env_params``. If absent, ``name`` is set to ``None``.
    * ``collect_env_params`` / ``eval_env_params``: deep-merged into the
      config dicts of the same name.
    * ``env_class`` + ``env_type``: a module path and an attribute name. The
      imported attribute becomes the ``type`` of
      ``algorithm_config['collect_environment']`` and
      ``algorithm_config['eval_environment']``.

    The algorithm yaml may provide:

    * ``algorithm_config``, ``policy_params``, ``trainer_params`` and
      ``learner_params``: each is deep-merged into the config attribute of the
      same name.
    * ``learner_class`` + ``learner_type``: imported and stored as
      ``algorithm_config['learner']['type']``.

    If a given yaml path does not exist, a message is printed and the function
    returns immediately. In particular, a missing ``env_yaml`` means
    ``algo_yaml`` is not applied either.

    Args:
        config: the config to be updated in place. It is an object (the usage
            example above is a config module) exposing ``collect_env_params``,
            ``eval_env_params``, ``algorithm_config``, ``policy_params``,
            ``trainer_params`` and ``learner_params`` as attributes.
        env_yaml (str): the environment yaml file, or a falsy value to skip it.
        algo_yaml (str): the algorithm yaml file, or a falsy value to skip it.

    Raises:
        ValueError: if ``env_class``/``env_type`` or
            ``learner_class``/``learner_type`` cannot be imported.
    '''
    if env_yaml:
        if not os.path.exists(env_yaml):
            print(f"File {env_yaml} does not exist.")
            return
        data = _load_yaml(env_yaml)
        config.collect_env_params['name'] = data.get('env')
        config.eval_env_params['name'] = data.get('env')
        _update_dict(config.collect_env_params, data.get('collect_env_params'))
        _update_dict(config.eval_env_params, data.get('eval_env_params'))
        if data.get('env_class') and data.get('env_type'):
            env_type = _import_from(data.get('env_class'), data.get('env_type'))
            # Kept outside the import's try-block so a missing config key
            # surfaces as a KeyError instead of a misleading "Import ... failed".
            config.algorithm_config['collect_environment']['type'] = env_type
            config.algorithm_config['eval_environment']['type'] = env_type

    if algo_yaml:
        if not os.path.exists(algo_yaml):
            print(f"File {algo_yaml} does not exist.")
            return
        data = _load_yaml(algo_yaml)
        # Each yaml section is merged into the config attribute of the same name.
        for section in ('algorithm_config', 'policy_params', 'trainer_params', 'learner_params'):
            if data.get(section):
                _update_dict(getattr(config, section), data.get(section))
        if data.get('learner_class') and data.get('learner_type'):
            learner_type = _import_from(data.get('learner_class'), data.get('learner_type'))
            config.algorithm_config['learner']['type'] = learner_type


def check_type(input_type, items, debug_str):
    """Check that ``items`` (or every leaf of it) is an instance of ``input_type``.

    Lists and tuples are walked recursively, so nested sequences are accepted
    as long as all of their non-sequence leaves match.

    Args:
        input_type (type | tuple): the accepted type(s), as passed to ``isinstance``.
        items: a single item, or a (possibly nested) list/tuple of items.
        debug_str (str): a label for ``items`` used in the error message.

    Returns:
        tuple: ``(items, count)``. ``items`` is returned unchanged. ``count`` is
        ``len(items)`` for a list/tuple (top level only, not the number of
        leaves) or ``1`` for a single item.

    Raises:
        TypeError: if any leaf item is not an instance of ``input_type``.
    """
    if isinstance(items, (list, tuple)):
        # A plain loop: the recursive calls are made only for their checks,
        # their return values are not needed.
        for item in items:
            check_type(input_type, item, debug_str)
        return items, len(items)
    if isinstance(items, input_type):
        return items, 1
    raise TypeError(f"input item {debug_str} expects {input_type}, but got {type(items)}")


def check_valid_return_value(return_value: Union[Tuple, np.ndarray], debug_str: str) -> int:
    """Count the ``np.ndarray`` leaves of a (possibly nested) tuple of outputs.

    Args:
        return_value: a ``np.ndarray`` or a tuple whose elements are, recursively,
            ``np.ndarray`` or tuples.
        debug_str (str): a label for the producer of ``return_value``, used in
            the error message.

    Returns:
        int: the number of ``np.ndarray`` leaves (``0`` for an empty tuple).

    Raises:
        TypeError: if any element is neither a tuple nor a ``np.ndarray``.
    """
    if isinstance(return_value, tuple):
        return sum(check_valid_return_value(item, debug_str) for item in return_value)
    if isinstance(return_value, np.ndarray):
        return 1
    raise TypeError(f"For {debug_str}, its output must be tuple or np.ndarray, but got {type(return_value)}")