- class mindspore_rl.utils.SoftUpdate(factor, update_interval, behavior_params, target_params)[源代码]
采用滑动凭据方式更新目标网络的参数。
设目标网络参数为 \(target\_param\),行为网络参数为 \(behavior\_param\), 滑动平均系数为 \(factor\)。 则 \(target\_param = (1. - factor) * behavior\_param + factor * target\_param\)。
- 参数:
factor (float) - 滑动平均系数,范围[0, 1]。
update_interval (int) - 目标网络参数更新间隔。
behavior_params (list) - 行为网络参数列表。
target_params (list) - 目标网络参数列表。
样例:
>>> import numpy as np >>> import mindspore.nn as nn >>> from mindspore.common.parameter import ParameterTuple >>> from mindspore_rl.utils import SoftUpdate >>> class Net(nn.Cell): >>> def __init__(self): >>> super(Net, self).__init__() >>> self.behavior_params = ParameterTuple(nn.Dense(10, 20).trainable_params()) >>> self.target_params = ParameterTuple(nn.Dense(10, 20).trainable_params()) >>> self.updater = SoftUpdate(0.9, 2, self.behavior_params, self.target_params) >>> def construct(self): >>> return self.updater() >>> net = Net() >>> for _ in range(10): >>> net() >>> np.allclose(net.behavior_params[0].asnumpy(), net.target_params[0].asnumpy(), atol=1e-5) True