class mindspore_rl.utils.SoftUpdate(factor, update_interval, behavior_params, target_params)[源代码]

采用滑动凭据方式更新目标网络的参数。

设目标网络参数为 \(target\_param\),行为网络参数为 \(behavior\_param\), 滑动平均系数为 \(factor\)。 则 \(target\_param = (1. - factor) * behavior\_param + factor * target\_param\)

参数:
  • factor (float) - 滑动平均系数,范围[0, 1]。

  • update_interval (int) - 目标网络参数更新间隔。

  • behavior_params (list) - 行为网络参数列表。

  • target_params (list) - 目标网络参数列表。

样例:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.common.parameter import ParameterTuple
>>> from mindspore_rl.utils import SoftUpdate
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.behavior_params = ParameterTuple(nn.Dense(10, 20).trainable_params())
>>>         self.target_params = ParameterTuple(nn.Dense(10, 20).trainable_params())
>>>         self.updater = SoftUpdate(0.9, 2, self.behavior_params, self.target_params)
>>>     def construct(self):
>>>         return self.updater()
>>> net = Net()
>>> for _ in range(10):
>>>     net()
>>> np.allclose(net.behavior_params[0].asnumpy(), net.target_params[0].asnumpy(), atol=1e-5)
True