
Per-sample-gradients

Computing per-sample gradients means computing a separate gradient for each sample in a batch. When training a neural network, most deep learning frameworks compute a single aggregated gradient over the whole batch and use it to update the network parameters. Per-sample gradients measure the effect of each individual sample on the network parameters more precisely, which can help us improve the training of the model.

Computing per-sample gradients is troublesome in many deep learning frameworks because they directly accumulate the gradients over the entire batch. With such frameworks, an obvious workaround is to loop over the batch: compute the loss between the prediction and the label for one sample at a time, then compute the gradient of that loss with respect to the network parameters. This works, but it is clearly very inefficient.
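As a sketch of this naive approach (using plain NumPy with an analytically derived gradient for a small linear model, rather than any particular framework's autodiff; the model and numbers are illustrative, not from the text):

```python
import numpy as np

# Linear model: prediction = x . theta, with a squared-error loss per sample.
# For one sample (x, y): d/dtheta (x.theta - y)^2 = 2 * (x.theta - y) * x.
theta = np.array([0.5, -1.0])
xs = np.array([[1.0, 2.0],
               [3.0, 0.0],
               [0.0, 1.0]])          # a batch of 3 samples
ys = np.array([1.0, 2.0, -1.0])

# Naive method: loop over the batch, computing one gradient per sample.
per_sample = np.stack([2 * (x @ theta - y) * x for x, y in zip(xs, ys)])
print(per_sample.shape)               # one gradient row per sample: (3, 2)

# A framework that only accumulates would instead return the batch sum:
accumulated = per_sample.sum(axis=0)
```

The loop makes one backward pass per sample, which is what makes this approach slow on large batches.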

MindSpore provides us with a more efficient way to calculate per-sample-gradients.

We illustrate the efficient method of computing per-sample gradients with the TD(0) (temporal difference) algorithm, a reinforcement learning algorithm that learns an optimal policy without a model of the environment. In the TD(0) algorithm, the value function estimate is updated according to the current reward. The TD(0) update rule is as follows:

$$V(S_t) \leftarrow V(S_t) + \alpha\left(R_{t+1} + \gamma V(S_{t+1}) - V(S_t)\right)$$

where $V(S_t)$ is the current value function estimate, $\alpha$ is the learning rate, $R_{t+1}$ is the reward obtained after performing the action in state $S_t$, $\gamma$ is the discount factor, and $V(S_{t+1})$ is the value function estimate for the next state $S_{t+1}$. The quantity $R_{t+1} + \gamma V(S_{t+1})$ is known as the TD target, and $R_{t+1} + \gamma V(S_{t+1}) - V(S_t)$ is known as the TD error.

By continuously updating the value function estimate with the TD(0) rule, the optimal policy can be learned incrementally, maximizing the reward gained from the environment.
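For concreteness, one step of the tabular TD(0) update above can be sketched in plain Python (the state names and numbers here are illustrative, not from the text):

```python
# One tabular TD(0) step: V(s) <- V(s) + alpha * (r + gamma * V(s_next) - V(s))
V = {"s0": 0.0, "s1": 1.0}   # current value estimates per state
alpha, gamma = 0.1, 0.9      # learning rate and discount factor

def td0_step(V, s, r, s_next):
    td_target = r + gamma * V[s_next]   # R_{t+1} + gamma * V(S_{t+1})
    td_error = td_target - V[s]         # the TD error
    V[s] += alpha * td_error            # move V(s) toward the TD target
    return td_error

err = td0_step(V, "s0", r=2.0, s_next="s1")
print(V["s0"])   # 0.0 + 0.1 * (2.0 + 0.9 * 1.0 - 0.0), i.e. about 0.29
```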

Combining jit, vmap and grad in MindSpore, we get a more efficient way to compute per-sample-gradients.

The method is described below. We assume that the value estimate $v_\theta(s_t)$ at state $s_t$ is a linear function parameterized by $\theta$.

from mindspore import ops, Tensor, vmap, jit, grad


value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])

Consider the following scenario: we transition from state $s_t$ to state $s_{t+1}$ and observe a reward $r_{t+1}$.

s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])

The update to the parameter $\theta$ is given by:

$$\Delta\theta = \left(r_{t+1} + v_\theta(s_{t+1}) - v_\theta(s_t)\right)\nabla v_\theta(s_t)$$

This update of the parameter $\theta$ is not the gradient of any loss function. However, it can be regarded as the gradient of the following pseudo-loss function, provided the target value $r_{t+1} + v_\theta(s_{t+1})$ is treated as a constant when differentiating $L(\theta)$ with respect to $\theta$:

$$L(\theta) = \left[r_{t+1} + v_\theta(s_{t+1}) - v_\theta(s_t)\right]^2$$

When computing the update of $\theta$ (the gradient of $L(\theta)$ with respect to $\theta$), we therefore need to exclude the target value $r_{t+1} + v_\theta(s_{t+1})$ from differentiation using ops.stop_gradient. This ensures the target value contributes nothing to the derivative with respect to $\theta$, so that we obtain the correct update of the parameter $\theta$.
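To see numerically what stop_gradient changes, we can differentiate the pseudo-loss by hand for the linear value function (plain NumPy, no autodiff; the values of theta, s_t, r_tp1, and s_tp1 are the ones defined earlier):

```python
import numpy as np

theta = np.array([0.2, -0.2, 0.1])
s_t = np.array([2., 1., -2.])
r_tp1 = 2.0
s_tp1 = np.array([1., 2., 0.])

v_t = theta @ s_t                 # v_theta(s_t)
target = r_tp1 + theta @ s_tp1    # r_{t+1} + v_theta(s_{t+1})

# With the target held constant (the stop_gradient semantics):
#   dL/dtheta = -2 * (target - v_t) * s_t
grad_with_stop = -2 * (target - v_t) * s_t

# Without stop_gradient, the target also depends on theta:
#   dL/dtheta = 2 * (target - v_t) * (s_tp1 - s_t)
grad_without_stop = 2 * (target - v_t) * (s_tp1 - s_t)

print(grad_with_stop)     # the correct TD(0) update direction
print(grad_without_stop)  # a different vector: the target leaked into the gradient
```

The two results differ, which is exactly why the implementation below wraps the target in ops.stop_gradient.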

We give the implementation of the pseudo-loss function L(θ) in MindSpore.

def td_loss(theta, s_t, r_tp1, s_tp1):
    v_t = value_fn(theta, s_t)
    target = r_tp1 + value_fn(theta, s_tp1)
    return (ops.stop_gradient(target) - v_t) ** 2

Pass td_loss into grad and compute the gradient of td_loss with respect to theta, i.e., the update of theta.

td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
[-7.2 -3.6  7.2]

td_update computes the gradient of td_loss with respect to the parameter theta for a single sample. We can vectorize this function with vmap, which adds a batch dimension to every input and output: given a batch of inputs, it produces a batch of outputs, where each element of the output batch corresponds to the element at the same position in the input batch.

batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])

per_sample_grads = vmap(td_update)
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]
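Conceptually, vmap(td_update) behaves like looping the single-sample function over the leading axis of every argument and stacking the results. The following hand-rolled substitute makes that picture concrete (plain NumPy, using the analytic gradient of the pseudo-loss instead of MindSpore's autodiff):

```python
import numpy as np

def single_sample_grad(theta, s_t, r_tp1, s_tp1):
    # Analytic gradient of the pseudo-loss, with the target held constant.
    target = r_tp1 + theta @ s_tp1
    return -2 * (target - theta @ s_t) * s_t

def manual_vmap(fn):
    # Loop-and-stack stand-in for vmap: map fn over axis 0 of every argument.
    def batched(*args):
        return np.stack([fn(*rows) for rows in zip(*args)])
    return batched

theta = np.array([0.2, -0.2, 0.1])
batched_theta = np.stack([theta, theta])
batched_s_t = np.stack([np.array([2., 1., -2.])] * 2)
batched_r_tp1 = np.array([2., 2.])
batched_s_tp1 = np.stack([np.array([1., 2., 0.])] * 2)

grads = manual_vmap(single_sample_grad)(batched_theta, batched_s_t,
                                        batched_r_tp1, batched_s_tp1)
print(grads.shape)   # (2, 3): one gradient per sample in the batch
```

Unlike this Python loop, vmap vectorizes the computation into a single batched pass, which is where the efficiency comes from.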

In the example above, we have to pass a whole batch of theta to per_sample_grads, even though in reality there is only a single theta. To pass just a single theta, we supply the in_axes argument to vmap, setting the position corresponding to theta to None and the positions corresponding to the other arguments to 0. This adds a batch axis only to the arguments other than theta.
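The effect of in_axes can be pictured with the same loop analogy (a plain-Python sketch, not MindSpore's implementation): arguments whose in_axes entry is 0 are indexed along their first axis, while arguments marked None are reused as-is for every sample.

```python
import numpy as np

def manual_vmap_in_axes(fn, in_axes):
    # in_axes entry 0: slice that argument along axis 0; None: pass it through.
    def batched(*args):
        n = next(a.shape[0] for a, ax in zip(args, in_axes) if ax == 0)
        return np.stack([
            fn(*[a if ax is None else a[i] for a, ax in zip(args, in_axes)])
            for i in range(n)
        ])
    return batched

# The None argument is shared: the same scalar is added to every row.
add = lambda shared, x: shared + x
out = manual_vmap_in_axes(add, (None, 0))(np.array(10.0),
                                          np.array([[1., 2.], [3., 4.]]))
print(out)   # [[11. 12.] [13. 14.]]
```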

inefficient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]

Up to this point, the per-sample gradients are computed correctly, but we can make the computation faster still. We wrap inefficient_per_sample_grads with jit, which compiles it into a callable MindSpore graph and improves its execution efficiency.

efficient_per_sample_grads = jit(inefficient_per_sample_grads)
delta_theta = efficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]