文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

PR

小问题,全程线上修改...

一键搞定!

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

Per-sample-gradients

在线运行下载Notebook下载样例代码查看源文件

计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。

在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。

MindSpore为我们提供了更高效的方法来计算per-sample-gradients。

我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,

V(St)=V(St)+α(Rt+1+γV(St+1)V(St))

其中V(St)是当前的值函数估计值,α是学习率,Rt+1是在状态St下执行动作后获得的奖励,γ是折扣因子,V(St+1)是下一个状态St+1的值函数估计值,Rt+1+γV(St+1)被称为TD目标,Rt+1+γV(St+1)V(St)被称为TD偏差。

通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。

在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。

下面对该方法进行介绍,假设在状态st时的估计值vθ由一个线性函数进行参数化。

[1]:
from mindspore import ops, Tensor, vmap, jit, grad


value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])

考虑如下场景,从状态st转换到状态st+1,且在这个过程中,我们观察到的奖励为rt+1

[2]:
s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])

参数θ的更新量的计算公式为:

Δθ=(rt+1+vθ(st+1)vθ(st))vθ(st)

参数θ的更新量并不是任何损失函数的梯度,然而,它可以被认为是下面的伪损失函数的梯度(假设忽略目标值rt+1+vθ(st+1)对计算L(θ)关于θ的梯度的影响),

L(θ)=[rt+1+vθ(st+1)vθ(st)]2

计算参数θ的更新量(计算L(θ)关于θ的梯度)时,我们需要使用ops.stop_gradient消除目标值rt+1+vθ(st+1)对计算θ梯度的影响,这可以使得在求导过程中,目标值rt+1+vθ(st+1)不对θ求导,以得到参数θ正确的更新量。

我们给出伪损失函数L(θ)在MindSpore中的实现,

[3]:
def td_loss(theta, s_tm1, r_t, s_t):
    v_t = value_fn(theta, s_t)
    target = r_tp1 + value_fn(theta, s_tp1)
    return (ops.stop_gradient(target) - v_t) ** 2

td_loss传入grad中,计算td_loss关于theta的梯度,即theta的更新量。

[4]:
td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
[-4. -8. -0.]

td_update仅根据一个样本,计算td_loss关于参数θ的梯度,我们可以使用vmap对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。

[5]:
batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])

per_sample_grads = vmap(td_update)
batch_theta = ops.stack([theta, theta])
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]

在上面的例子中,我们需要手动地为per_sample_grads传递一批量的theta,但实际上,我们可以仅传入单个的theta,为了实现这一点,我们对vmap传入参数in_axes,在in_axes中,参数theta对应的位置被设置为None,其他参数对应的位置被设置为0。这使得我们仅需向除theta以外的其他参数添加一个额外的轴。

[6]:
inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]

到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用jit调用inefficiecient_per_sample_grads,这会将inefficiecient_per_sample_grads编译为一张可调用的MindSpore图,这会提升它的运行效率。

[7]:
efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)
delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-4. -8.  0.]
 [-4. -8.  0.]]