# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Utils.
"""
import os
from importlib import import_module
import yaml
import numpy as np
from typing import Tuple, Union
def _update_dict(dest, src) -> None:
    """Recursively merge ``src`` into ``dest`` in place.

    When a key holds a dict in both ``src`` and ``dest``, the two dicts are
    merged key by key, at any nesting depth. Every other value from ``src``
    (scalar, list, ``None``, or a dict whose counterpart in ``dest`` is
    missing or not a dict) is assigned directly and replaces whatever
    ``dest`` held under that key.

    Args:
        dest (dict): the dictionary to update; it is modified in place.
        src (dict or None): the overriding values. ``None`` leaves ``dest``
            untouched.
    """
    if src is None:
        return
    for key, value in src.items():
        current = dest.get(key)
        if isinstance(value, dict) and isinstance(current, dict):
            # A single recursive call merges the whole sub-tree for this key.
            _update_dict(current, value)
        else:
            # The value is stored by reference, not copied.
            dest[key] = value
def _read_yaml_mapping(path):
    """Return the parsed content of the YAML file ``path``, or None if it is missing.

    A missing file is reported with a printed notice instead of an exception.
    An empty document, for which ``yaml.safe_load`` returns ``None``, is
    turned into an empty dict so the caller can use ``.get`` unconditionally.
    """
    if not os.path.exists(path):
        print(f"File {path} does not exist.")
        return None
    with open(path) as f:
        content = yaml.safe_load(f)
    return {} if content is None else content


def update_config(config, env_yaml, algo_yaml) -> None:
    r"""
    Update the config by the provided yamls. Eg: see `mindspore_rl/algorithm/dqn/config.py`,
    `mindspore_rl/example/env_yaml/` and `mindspore_rl/example/algo_yaml/` for usage.

    The environment yaml is applied first, then the algorithm yaml. If a given
    yaml path does not exist, a notice is printed and the function returns
    immediately; a missing environment yaml therefore also skips the
    algorithm yaml.

    Args:
        config: the config to be updated. The attributes `collect_env_params`,
            `eval_env_params`, `algorithm_config`, `policy_params`,
            `trainer_params` and `learner_params` are read and modified in
            place, each one only when the corresponding yaml content needs it.
        env_yaml (str): the environment yaml file. A falsy value skips this step.
        algo_yaml (str): the algorithm yaml file. A falsy value skips this step.

    Raises:
        ValueError: if the class named by `env_class`/`env_type` or by
            `learner_class`/`learner_type` cannot be imported or stored into
            `config.algorithm_config`. The underlying error is chained as the cause.
    """
    if env_yaml:
        data = _read_yaml_mapping(env_yaml)
        if data is None:
            return
        # Both environments take the name from the 'env' entry (None if absent).
        env_name = data.get('env')
        config.collect_env_params['name'] = env_name
        config.eval_env_params['name'] = env_name
        _update_dict(config.collect_env_params, data.get('collect_env_params'))
        _update_dict(config.eval_env_params, data.get('eval_env_params'))
        if data.get('env_class') and data.get('env_type'):
            try:
                env_module = import_module(data.get('env_class'))
                env_type = getattr(env_module, data.get('env_type'))
                config.algorithm_config['collect_environment']['type'] = env_type
                config.algorithm_config['eval_environment']['type'] = env_type
            except Exception as err:
                # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
                # through; callers still get ValueError, now with the real cause.
                raise ValueError(f"Import {data.get('env_class')} failed") from err

    if algo_yaml:
        data = _read_yaml_mapping(algo_yaml)
        if data is None:
            return
        # Each section is merged only when the yaml provides a non-empty value.
        for section in ('algorithm_config', 'policy_params', 'trainer_params', 'learner_params'):
            if data.get(section):
                _update_dict(getattr(config, section), data.get(section))
        if data.get('learner_class') and data.get('learner_type'):
            try:
                learner_module = import_module(data.get('learner_class'))
                learner_type = getattr(learner_module, data.get('learner_type'))
                config.algorithm_config['learner']['type'] = learner_type
            except Exception as err:
                raise ValueError(f"Import {data.get('learner_class')} failed") from err
def check_type(input_type, items, debug_str):
    """Validate that ``items`` is, or consists only of, ``input_type`` instances.

    A list or tuple is treated as a container and validated element by
    element, recursively, so nested lists/tuples are accepted. Any other
    object must itself be an instance of ``input_type``.

    Args:
        input_type (type or tuple of types): the accepted type(s), in the form
            ``isinstance`` takes.
        items: a single object, or a list/tuple of objects, to validate.
        debug_str (str): label used in the error message to identify the input.

    Returns:
        tuple: ``(items, count)``. ``count`` is ``len(items)`` for a list or
        tuple (top-level length only) and 1 for a single object.

    Raises:
        TypeError: if ``items``, or any element of a list/tuple, is not an
            instance of ``input_type``.
    """
    if isinstance(items, (list, tuple)):
        # Plain loop: the recursive calls matter only for the exception they may raise.
        for item in items:
            check_type(input_type, item, debug_str)
        return items, len(items)
    if isinstance(items, input_type):
        return items, 1
    raise TypeError(f"input item {debug_str} expects {input_type}, but got {type(items)}")
def check_valid_return_value(return_value: Union[Tuple, np.ndarray], debug_str: str) -> int:
    """Count the ``np.ndarray`` leaves contained in ``return_value``.

    Args:
        return_value (Union[Tuple, np.ndarray]): a single array, or a tuple
            (possibly nested) whose leaves are all arrays.
        debug_str (str): label used in the error message to identify the source
            of the value.

    Returns:
        int: 1 for a single array; for a tuple, the sum of the counts of its
        elements (0 for an empty tuple).

    Raises:
        TypeError: if ``return_value``, or any nested element, is neither a
            tuple nor an ``np.ndarray``.
    """
    if isinstance(return_value, tuple):
        # Each element contributes the number of arrays found beneath it.
        per_element_counts = [check_valid_return_value(element, debug_str) for element in return_value]
        return sum(per_element_counts)
    if isinstance(return_value, np.ndarray):
        return 1
    raise TypeError(f"For {debug_str}, its output must be tuple or np.ndarray, but got {type(return_value)}")