Per-sample-gradients
计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。
在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。
MindSpore为我们提供了更高效的方法来计算per-sample-gradients。
我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,
其中\(V(S_{t})\)是当前的值函数估计值,\(\alpha\)是学习率,\(R_{t+1}\)是在状态\(S_{t}\)下执行动作后获得的奖励,\(\gamma\)是折扣因子,\(V(S_{t+1})\)是下一个状态\(S_{t+1}\)的值函数估计值,\(R_{t+1} + \gamma V(S_{t+1})\)被称为TD目标,\(R_{t+1} + \gamma V(S_{t+1}) - V(S_{t})\)被称为TD偏差。
通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。
在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。
下面对该方法进行介绍,假设在状态\(s_{t}\)时的估计值\(v_{\theta}\)由一个线性函数进行参数化。
[1]:
from mindspore import ops, Tensor, vmap, jit, grad
value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])
考虑如下场景,从状态\(s_{t}\)转换到状态\(s_{t+1}\),且在这个过程中,我们观察到的奖励为\(r_{t+1}\)。
[2]:
s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])
参数\({\theta}\)的更新量的计算公式为:
参数\({\theta}\)的更新量并不是任何损失函数的梯度,然而,它可以被认为是下面的伪损失函数的梯度(假设忽略目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)对计算\(L(\theta)\)关于\({\theta}\)的梯度的影响),
计算参数\({\theta}\)的更新量(计算\(L(\theta)\)关于\({\theta}\)的梯度)时,我们需要使用ops.stop_gradient
消除目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)对计算\({\theta}\)梯度的影响,这可以使得在求导过程中,目标值\(r_{t+1} + v_{\theta}(s_{t+1})\)不对\({\theta}\)求导,以得到参数\({\theta}\)正确的更新量。
我们给出伪损失函数\(L(\theta)\)在MindSpore中的实现,
[3]:
def td_loss(theta, s_tm1, r_t, s_t):
v_t = value_fn(theta, s_t)
target = r_tp1 + value_fn(theta, s_tp1)
return (ops.stop_gradient(target) - v_t) ** 2
将td_loss
传入grad
中,计算td_loss关于theta
的梯度,即theta
的更新量。
[4]:
td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
[-4. -8. -0.]
td_update
仅根据一个样本,计算td_loss关于参数\({\theta}\)的梯度,我们可以使用vmap
对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。
[5]:
batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])
per_sample_grads = vmap(td_update)
batch_theta = ops.stack([theta, theta])
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8. 0.]
[-4. -8. 0.]]
在上面的例子中,我们需要手动地为per_sample_grads
传递一批量的theta
,但实际上,我们可以仅传入单个的theta
,为了实现这一点,我们对vmap
传入参数in_axes
,在in_axes
中,参数theta
对应的位置被设置为None
,其他参数对应的位置被设置为0
。这使得我们仅需向除theta
以外的其他参数添加一个额外的轴。
[6]:
inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8. 0.]
[-4. -8. 0.]]
到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用jit
调用inefficiecient_per_sample_grads
,这会将inefficiecient_per_sample_grads
编译为一张可调用的MindSpore图,这会提升它的运行效率。
[7]:
efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)
delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8. 0.]
[-4. -8. 0.]]