Per-sample-gradients
Calculating per-sample-gradients means calculating the gradient of each individual sample in a batch. When training a neural network, most deep learning frameworks compute a single gradient aggregated over the whole batch of samples and use it to update the network parameters. Per-sample-gradients measure the effect of each individual sample on the network parameters more precisely, which can help us improve the training of the model.
Calculating per-sample-gradients is troublesome in many deep learning frameworks because these frameworks directly accumulate the gradients over the entire batch. Using such frameworks, a simple workaround is to compute, for each sample in the batch, the loss between its prediction and its label and then the gradient of that loss with respect to the network parameters, but this method is clearly very inefficient.
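For concreteness, the naive approach might look like the following sketch. The model, loss_fn, parameter w, and the toy data are illustrative assumptions rather than part of the example developed below; the point is simply that it requires one Python-level gradient call per sample.

import numpy as np
from mindspore import Tensor, grad, ops

# Illustrative toy setup (assumed for this sketch only).
w = Tensor(np.array([0.5, -0.5, 1.0], dtype=np.float32))
xs = Tensor(np.random.randn(8, 3).astype(np.float32))   # a batch of 8 samples
ys = Tensor(np.random.randn(8).astype(np.float32))

def loss_fn(w, x, y):
    # Squared error of a linear prediction for a single sample.
    return (ops.tensor_dot(w, x, axes=1) - y) ** 2

grad_fn = grad(loss_fn)  # gradient with respect to w (the first argument)
# One gradient call per sample: correct, but slow for large batches.
naive_per_sample_grads = [grad_fn(w, xs[i], ys[i]) for i in range(xs.shape[0])]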
MindSpore provides us with a more efficient way to calculate per-sample-gradients.
We illustrate the efficient method of computing per-sample-gradients with the example of the TD(0) (Temporal Difference) algorithm, a reinforcement learning algorithm based on temporal differences that learns an optimal policy without a model of the environment. In the TD(0) algorithm, the value function estimate is updated according to the currently observed reward. The TD(0) update is formulated as follows:

$$v(s_t) \leftarrow v(s_t) + \alpha \left[ r_{t+1} + v(s_{t+1}) - v(s_t) \right]$$

where $v(s_t)$ is the value estimate of the current state $s_t$, $v(s_{t+1})$ is the value estimate of the next state $s_{t+1}$, $r_{t+1}$ is the reward observed after transitioning from $s_t$ to $s_{t+1}$, and $\alpha$ is the learning rate (the discount factor is taken to be 1 here).

By continuously updating the value function estimates with the TD(0) rule, the optimal policy can be learned incrementally so as to maximize the reward gained in the environment.
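For intuition, a single tabular TD(0) update step can be written in a few lines of plain Python; the value table, learning rate, and transition below are illustrative assumptions:

# A single tabular TD(0) update step with illustrative values.
values = {"s0": 0.0, "s1": 0.5}          # current value estimates v(s)
alpha = 0.1                               # learning rate
s, s_next, reward = "s0", "s1", 1.0       # observed transition and reward
td_error = reward + values[s_next] - values[s]
values[s] += alpha * td_error             # v("s0") becomes 0.15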
By combining jit, vmap and grad in MindSpore, we obtain a more efficient way to compute per-sample-gradients. The method is described below, assuming that the value estimate $v(s)$ of a state $s$ is given by the linear function value_fn with parameter theta.
from mindspore import ops, Tensor, vmap, jit, grad

# Linear value function v(s) = theta . s
value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])
Consider the following scenario: we transition from state $s_t$ to state $s_{t+1}$ and observe the reward $r_{t+1}$.
s_t = Tensor([2., 1., -2.])    # current state
r_tp1 = Tensor(2.)             # observed reward
s_tp1 = Tensor([1., 2., 0.])   # next state
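As a quick sanity check (an added illustration, not required for the method), we can evaluate the linear value function at both states; these numbers reappear in the gradient computed below:

# Value estimates for the current and next state.
print(value_fn(theta, s_t))    # expected: 0.0  (0.2*2 - 0.2*1 + 0.1*(-2))
print(value_fn(theta, s_tp1))  # expected: -0.2 (0.2*1 - 0.2*2 + 0.1*0)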
The update to the parameter $\theta$ is:

$$\Delta\theta = \bigl(r_{t+1} + v(s_{t+1}, \theta) - v(s_t, \theta)\bigr)\, \nabla_{\theta}\, v(s_t, \theta)$$

This update is not itself the gradient of a loss function, but it can be written (up to a constant factor) as the gradient of the pseudo-loss function

$$L(\theta) = \bigl(r_{t+1} + v(s_{t+1}, \theta) - v(s_t, \theta)\bigr)^2$$

provided that the dependence of the target $r_{t+1} + v(s_{t+1}, \theta)$ on the parameter $\theta$ is ignored. When computing the update of $\theta$, we can use ops.stop_gradient, which prevents gradients from flowing through the target $r_{t+1} + v(s_{t+1}, \theta)$, so that the target is treated as a constant.
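As a sanity check on the numbers that appear below, the gradient of this pseudo-loss with respect to $\theta$ can be worked out by hand for the transition above (an added check using the values already defined):

$$
\begin{aligned}
v(s_t, \theta) &= 0.2 \cdot 2 - 0.2 \cdot 1 + 0.1 \cdot (-2) = 0,\\
\text{target} &= r_{t+1} + v(s_{t+1}, \theta) = 2 + (0.2 \cdot 1 - 0.2 \cdot 2 + 0.1 \cdot 0) = 1.8,\\
\nabla_{\theta} L &= -2\,\bigl(\text{target} - v(s_t, \theta)\bigr)\, s_t = -2 \cdot 1.8 \cdot (2,\ 1,\ -2) = (-7.2,\ -3.6,\ 7.2).
\end{aligned}
$$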
The implementation of the pseudo-loss function td_loss is given below:
def td_loss(theta, s_t, r_tp1, s_tp1):
    v_t = value_fn(theta, s_t)
    target = r_tp1 + value_fn(theta, s_tp1)
    # stop_gradient treats the target as a constant when differentiating.
    return (ops.stop_gradient(target) - v_t) ** 2
Pass td_loss into grad and compute the gradient of td_loss with respect to theta, i.e., the update of theta.
td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
[-7.2 -3.6  7.2]
td_update computes the gradient of td_loss with respect to the parameter theta for a single sample. To compute per-sample-gradients for a whole batch at once, we can use vmap, which will add a batch dimension to all inputs and outputs. We can now pass in a batch of inputs and obtain a batch of outputs, where each element of the output batch is the gradient for the corresponding element of the input batch.
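As a minimal illustration of this behavior (the toy square function is an assumption used only for illustration), vmap maps a function over the leading batch axis of its input:

# vmap maps square() over the leading axis: each element is handled independently.
square = lambda x: x ** 2
batched_square = vmap(square)
print(batched_square(Tensor([1., 2., 3.])))  # expected: [1. 4. 9.]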
batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])
per_sample_grads = vmap(td_update)
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]
In the above example, we need to manually pass a batch of theta to per_sample_grads, but in reality we only need to pass a single theta. To do this, we pass the parameter in_axes to vmap, where the position corresponding to the parameter theta in in_axes is set to None and the positions corresponding to the other parameters are set to 0. This adds an additional batch axis only to the parameters other than theta.
inefficient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]
At this point, the gradient for each sample is calculated correctly, but we can make the calculation a bit faster. We wrap inefficient_per_sample_grads with jit, which compiles it into a callable MindSpore graph and improves its execution efficiency.
efficient_per_sample_grads = jit(inefficient_per_sample_grads)
delta_theta = efficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]
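To see the effect of compilation in practice, one can time both versions on a larger batch. The sketch below is a rough illustration; the batch construction and use of time.time are assumptions, and the actual numbers depend on hardware and backend.

import time
import numpy as np

# Build a larger batch by repeating the same transition (illustrative only).
big_s_t = Tensor(np.tile(s_t.asnumpy(), (1024, 1)))
big_r_tp1 = Tensor(np.tile(r_tp1.asnumpy(), 1024))
big_s_tp1 = Tensor(np.tile(s_tp1.asnumpy(), (1024, 1)))

# The first call to the jitted version includes graph compilation, so warm it up first.
efficient_per_sample_grads(theta, big_s_t, big_r_tp1, big_s_tp1)

for name, fn in (("vmap only", inefficient_per_sample_grads),
                 ("vmap + jit", efficient_per_sample_grads)):
    start = time.time()
    fn(theta, big_s_t, big_r_tp1, big_s_tp1)
    print(name, time.time() - start)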