Per-sample-gradients

在线运行下载Notebook下载样例代码查看源文件

计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。

在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。

MindSpore为我们提供了更高效的方法来计算per-sample-gradients。

我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,

\[V(S_{t}) = V(S_{t}) + \alpha (R_{t+1} + \gamma V(S_{t+1}) - V(S_{t}))\]

其中\(V(S_{t})\)是当前的值函数估计值,\(\alpha\)是学习率,\(R_{t+1}\)是在状态\(S_{t}\)下执行动作后获得的奖励,\(\gamma\)是折扣因子,\(V(S_{t+1})\)是下一个状态\(S_{t+1}\)的值函数估计值,\(R_{t+1} + \gamma V(S_{t+1})\)被称为TD目标,\(R_{t+1} + \gamma V(S_{t+1}) - V(S_{t})\)被称为TD偏差。

通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。

在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。

下面对该方法进行介绍,假设在状态\(s_{t}\)时的估计值\(v_{\theta}\)由一个线性函数进行参数化。

[1]:
from mindspore import ops, Tensor, vmap, jit, grad


value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])

考虑如下场景,从状态\(s_{t}\)转换到状态\(s_{t+1}\),且在这个过程中,我们观察到的奖励为\(r_{t+1}\)

[2]:
s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])

参数\({\theta}\)的更新量的计算公式为:

\[\Delta{\theta}=(r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t}))\nabla v_{\theta}(s_{t})\]

参数\({\theta}\)的更新量并不是任何损失函数的梯度,然而,它可以被认为是下面的伪损失函数的梯度(假设忽略目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)对计算\(L(\theta)\)关于\({\theta}\)的梯度的影响),

\[L(\theta) = [r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t})]^{2}\]

计算参数\({\theta}\)的更新量(计算\(L(\theta)\)关于\({\theta}\)的梯度)时,我们需要使用ops.stop_gradient消除目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)对计算\({\theta}\)梯度的影响,这可以使得在求导过程中,目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)不对\({\theta}\)求导,以得到参数\({\theta}\)正确的更新量。

我们给出伪损失函数\(L(\theta)\)在MindSpore中的实现,

[3]:
def td_loss(theta, s_tm1, r_t, s_t):
    v_t = value_fn(theta, s_t)
    target = r_tp1 + value_fn(theta, s_tp1)
    return (ops.stop_gradient(target) - v_t) ** 2

td_loss传入grad中,计算td_loss关于theta的梯度,即theta的更新量。

[4]:
td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
[-4. -8. -0.]

td_update仅根据一个样本,计算td_loss关于参数\({\theta}\)的梯度,我们可以使用vmap对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。

[5]:
batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])

per_sample_grads = vmap(td_update)
batch_theta = ops.stack([theta, theta])
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]

在上面的例子中,我们需要手动地为per_sample_grads传递一批量的theta,但实际上,我们可以仅传入单个的theta,为了实现这一点,我们对vmap传入参数in_axes,在in_axes中,参数theta对应的位置被设置为None,其他参数对应的位置被设置为0。这使得我们仅需向除theta以外的其他参数添加一个额外的轴。

[6]:
inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]

到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用jit调用inefficiecient_per_sample_grads,这会将inefficiecient_per_sample_grads编译为一张可调用的MindSpore图,这会提升它的运行效率。

[7]:
efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)
delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]