Recomputation


Overview

MindSpore's automatic differentiation works in reverse mode: it derives the backward pass from the forward pass. Some backward operators need the results of certain forward operators as inputs, so those results must stay in memory until the backward operators that use them have finished, and the memory they occupy cannot be reused in the meantime. This drives up peak memory usage, which is particularly significant for large models.

To solve this problem, MindSpore provides recomputation: instead of storing the results of selected forward operators, it recomputes those operators just before the backward operators that need them, so that their memory can be reused in between. This tutorial uses a sample network built from 10 stacked Block sub-networks to explain how to configure recomputation when training a model in MindSpore.

Related interfaces:

  1. mindspore.nn.Cell.recompute(): after this interface is called on a Cell, all operators inside the Cell and inside its sub-Cells, except the Cell's output operator, are recomputed when the backward part is computed. Both PyNative mode and Graph mode are supported.

  2. mindspore.ops.Primitive.recompute(): after this interface is called on a Primitive, that operator is recomputed when the backward part is computed. Only Graph mode is supported.

  3. mindspore.recompute(block, *args, **kwargs): after this function is called, block is recomputed during the backward pass. Only PyNative mode is supported.
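To make the semantics of these interfaces concrete without depending on MindSpore, here is a framework-free toy in plain Python. It mimics what block-level recomputation does conceptually: the plain version records each layer's input on a tape during the forward pass, while the recomputing version keeps only the block input and reruns the forward pass when the gradient is needed. The layer list and the names `grad_plain` and `grad_recompute` are hypothetical and are not MindSpore APIs; this is a sketch of the idea, not MindSpore's implementation.

```python
import math

# Hypothetical toy "block": a chain of scalar layers given as (forward, derivative).
LAYERS = [
    (math.sin, math.cos),                  # v -> sin(v)
    (lambda v: v * v, lambda v: 2 * v),    # v -> v^2
    (math.exp, math.exp),                  # v -> exp(v)
]


def run_forward(layers, x, counter, keep_tape):
    """Apply the layers to x; if keep_tape, record each layer's input."""
    tape = []
    for f, _ in layers:
        counter[0] += 1                    # count forward operator executions
        if keep_tape:
            tape.append(x)
        x = f(x)
    return x, tape


def backward(layers, tape):
    """Apply the chain rule in reverse order using the recorded inputs."""
    g = 1.0
    for (_, df), inp in zip(reversed(layers), reversed(tape)):
        g *= df(inp)
    return g


def grad_plain(layers, x, counter):
    # Forward pass keeps every intermediate alive until backward runs.
    out, tape = run_forward(layers, x, counter, keep_tape=True)
    return out, backward(layers, tape)


def grad_recompute(layers, x, counter):
    # Forward pass keeps only the block input; intermediates are discarded.
    out, _ = run_forward(layers, x, counter, keep_tape=False)
    # At backward time, rerun the forward pass to rebuild the tape.
    _, tape = run_forward(layers, x, counter, keep_tape=True)
    return out, backward(layers, tape)


plain_calls, recompute_calls = [0], [0]
out1, g1 = grad_plain(LAYERS, 0.5, plain_calls)
out2, g2 = grad_recompute(LAYERS, 0.5, recompute_calls)
print(g1 == g2, plain_calls[0], recompute_calls[0])  # True 3 6
```

Both versions produce the same gradient, but the recomputing one executes each forward layer twice: recomputation trades extra computation for a shorter lifetime of intermediate results.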

Basic Principle

MindSpore automatically derives the backward graph from the computation process of the forward graph, and the forward and backward graphs together form the complete computation graph. Some backward operators need the results of certain forward operators, so those results must stay in memory until the backward operators have been computed, and the memory they occupy cannot be reused by other operators in the meantime. Forward results that stay resident this long push up the peak memory footprint of the computation, especially in large-scale network models.

To reduce peak memory, recomputation does not keep the results of the forward activation layers, so that their memory can be reused; instead, these results are recomputed when the backward part is computed. MindSpore provides this recomputation capability.
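The memory effect can be illustrated with a simple counting model. The sketch below assumes a chain of layers where each layer produces `intermediates` unit-sized activations plus one unit-sized output. Without recomputation, all activations stay alive until the backward pass reaches their layer; with per-layer recomputation, only each layer's output is kept (mirroring the rule that a Cell's output operator is not recomputed) and the activations are rebuilt during the backward pass. This is an idealized model that ignores gradients, workspace, and static memory; it does not describe MindSpore's memory allocator.

```python
def peak_memory(num_layers, intermediates, recompute):
    """Peak live memory (in abstract units) for a toy chain of layers."""
    live, peak = 0, 0

    def alloc(n):
        nonlocal live, peak
        live += n
        peak = max(peak, live)

    def free(n):
        nonlocal live
        live -= n

    # Forward pass, layer by layer.
    for _ in range(num_layers):
        alloc(intermediates)      # activations inside the layer
        alloc(1)                  # the layer's output (kept in both modes)
        if recompute:
            free(intermediates)   # drop activations right away

    # Backward pass visits the layers in reverse order.
    for _ in range(num_layers):
        if recompute:
            alloc(intermediates)  # rebuild this layer's activations
        free(intermediates)       # activations are dead after this layer's backward
        free(1)                   # so is the layer's output

    assert live == 0
    return peak


print(peak_memory(10, 4, recompute=False))  # 50
print(peak_memory(10, 4, recompute=True))   # 14
```

In this model the peak grows as N·(k+1) without recomputation but only as N + k with it, where N is the number of layers and k the number of activations per layer.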

Recomputation is implemented as follows. For each forward operator that the user specifies for recomputation, MindSpore makes a copy of the operator, feeds the copy's output to the backward operator, and deletes the edge between the original forward operator and that backward operator. The copied operator must not be evaluated until the corresponding backward part is being computed, so control dependencies are inserted to enforce the execution order, as shown in the following figure:

Figure: Forward and backward graphs before and after recomputation is enabled
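The graph rewrite described above can be sketched on a toy graph of plain Python dicts. The graph format, the node names (loosely modeled on the IR excerpts later in this tutorial), and the choice to attach the control edge to the backward operator's other inputs (here, the incoming gradient) are assumptions of this sketch; MindSpore's actual pass operates on its own IR.

```python
import copy

# Toy graph: node name -> operator, data inputs, and control-dependency inputs.
GRAPH = {
    "x":         {"op": "Input",       "inputs": [],              "ctrl": []},
    "div":       {"op": "RealDiv",     "inputs": ["x"],           "ctrl": []},
    "matmul":    {"op": "BatchMatMul", "inputs": ["div", "x"],    "ctrl": []},
    "dout":      {"op": "OnesLike",    "inputs": ["matmul"],      "ctrl": []},
    "bw_matmul": {"op": "BatchMatMul", "inputs": ["div", "dout"], "ctrl": []},
}
BACKWARD = {"dout", "bw_matmul"}


def apply_recompute(graph, recompute_nodes, backward_nodes):
    new = copy.deepcopy(graph)
    for name in recompute_nodes:
        dup = name + "_recompute"
        for consumer, node in graph.items():
            if consumer not in backward_nodes or name not in node["inputs"]:
                continue
            if dup not in new:
                # 1. Copy the forward operator with the same inputs.
                new[dup] = {"op": graph[name]["op"],
                            "inputs": list(graph[name]["inputs"]),
                            "ctrl": []}
            # 2. Feed the copy to the backward operator; this deletes the
            #    edge from the original forward operator.
            new[consumer]["inputs"] = [dup if i == name else i
                                       for i in new[consumer]["inputs"]]
            # 3. Control dependency: the copy may run only after the backward
            #    operator's other inputs (here, the incoming gradient) exist.
            new[dup]["ctrl"] += [i for i in new[consumer]["inputs"] if i != dup]
    return new


def execution_order(graph):
    """Deterministic topological order honoring data and control edges."""
    order, done = [], set()
    while len(order) < len(graph):
        for n, node in graph.items():
            if n not in done and set(node["inputs"]) | set(node["ctrl"]) <= done:
                order.append(n)
                done.add(n)
                break
        else:
            raise ValueError("graph has a cycle")
    return order


rewritten = apply_recompute(GRAPH, ["div"], BACKWARD)
print(rewritten["bw_matmul"]["inputs"])  # ['div_recompute', 'dout']
print(execution_order(rewritten))
# ['x', 'div', 'matmul', 'dout', 'div_recompute', 'bw_matmul']
```

After the rewrite, the last consumer of the original `div` is the forward `matmul`, so its memory can be released right after that operator, while the copy `div_recompute` runs only once the gradient `dout` is available.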

For convenience, MindSpore provides a recompute interface not only for individual operators but also for Cells. When the user calls a Cell's recompute interface, all forward operators in that Cell are set to be recomputed (except the Cell's output operator, as described in the interface list above).

Taking the GPT-3 model as an example, the policy is to set the Cell corresponding to each layer to be recomputed, and then set the output operator of each layer to not be recomputed. The effect of recomputation on the 72-layer GPT-3 network is shown in the following figure:

Figure: Comparison of GPT-3 memory usage before and after recomputation is enabled
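The Cell-level rule (everything inside the Cell and its sub-Cells is recomputed except the Cell's output operator) applied per layer, as in the GPT-3 policy above, can be sketched with a hypothetical `ToyCell` tree. The class, the operator names, and the assumption that the last entry of `ops` is the output operator are illustrative only; this is not how MindSpore stores Cells.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCell:
    """Hypothetical stand-in for a Cell: its own operators plus sub-cells.
    By assumption, the last entry of `ops` is the cell's output operator."""
    name: str
    ops: list
    children: list = field(default_factory=list)


def ops_to_recompute(cell):
    """Operators marked when `cell` is set to recompute: everything in the
    cell and its sub-cells, except the cell's own output operator."""
    marked = []

    def visit(c):
        marked.extend(f"{c.name}.{op}" for op in c.ops)
        for child in c.children:
            visit(child)

    visit(cell)
    output_op = f"{cell.name}.{cell.ops[-1]}"
    return [op for op in marked if op != output_op]


# Per-layer policy: recompute each layer's cell, keep its output operator.
layers = [
    ToyCell(f"layer{i}", ["layernorm", "add"],
            [ToyCell(f"layer{i}/attention", ["matmul", "softmax"]),
             ToyCell(f"layer{i}/ffn", ["matmul", "gelu"])])
    for i in range(2)
]
for layer in layers:
    print(ops_to_recompute(layer))
```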

Operation Practice

The following illustrates recomputation using Ascend hardware as an example:

Sample Code Description

Download the complete sample code: recompute.

The directory structure is as follows:

└─ sample_code
    ├─ recompute
       └── example.py
    ...

example.py is the script that defines the network structure and the execution flow. To highlight the difference before and after recomputation, recomputation is not enabled by default in this sample. To enable it, refer to the configuration below.

Network Definition

The network Net consists of 10 Block sub-networks held in an nn.CellList and connected in sequence, and Grad differentiates Net to obtain the derivative with respect to the network input.

import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, Parameter
from mindspore import ops, nn


class Block(Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.transpose1 = ops.Transpose()
        self.transpose2 = ops.Transpose()
        self.transpose3 = ops.Transpose()
        self.transpose4 = ops.Transpose()
        self.real_div1 = ops.RealDiv()
        self.real_div2 = ops.RealDiv()
        self.batch_matmul1 = ops.BatchMatMul()
        self.batch_matmul2 = ops.BatchMatMul()
        self.add = ops.Add()
        self.softmax = ops.Softmax(-1)
        self.dropout = ops.Dropout(0.9)
        self.expand_dims = ops.ExpandDims()
        self.sub = ops.Sub()
        self.mul = ops.Mul()
        self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

    def construct(self, x):
        transpose1 = self.transpose1(x, (0, 2, 1, 3))
        real_div1 = self.real_div1(transpose1, Tensor(2.37891))
        transpose2 = self.transpose2(x, (0, 2, 3, 1))
        real_div2 = self.real_div2(transpose2, Tensor(2.37891))
        batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
        expand_dims = self.expand_dims(self.y, 1)
        sub = self.sub(Tensor([1.0]), expand_dims)
        mul = self.mul(sub, Tensor([-0.0001]))
        add = self.add(mul, batch_matmul1)
        soft_max = self.softmax(add)
        dropout = self.dropout(soft_max)
        transpose3 = self.transpose3(x, (0, 2, 1, 3))
        batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
        transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
        return transpose4


class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.blocks = nn.CellList()
        for _ in range(10):
            b = Block()
            self.blocks.append(b)

    def construct(self, x):
        out = x
        for i in range(10):
            out = self.blocks[i](out)
        return out


class Grad(Cell):
    def __init__(self, net):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.net = net

    def construct(self, x):
        grad_net = self.grad(self.net)
        return grad_net(x)
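Because each Block maps an (8, 128, 16, 32) input to an output of the same shape, the 10 Blocks can be chained. The shape bookkeeping can be checked without MindSpore using the small pure-Python sketch below. The helper names are hypothetical, and only the shape-changing operators are modeled. It also shows that a single float32 tensor of shape (8, 16, 128, 128), the shape of the attention scores inside a Block, occupies 8 MiB.

```python
from math import prod


def transpose(shape, perm):
    return tuple(shape[p] for p in perm)


def batch_matmul(a, b):
    # Batch dims must match; a's last dim contracts with b's second-to-last.
    assert a[:-2] == b[:-2] and a[-1] == b[-2], (a, b)
    return a[:-1] + (b[-1],)


def broadcast(a, b):
    # NumPy-style broadcasting of two shapes.
    n = max(len(a), len(b))
    a, b = (1,) * (n - len(a)) + a, (1,) * (n - len(b)) + b
    assert all(x == y or 1 in (x, y) for x, y in zip(a, b)), (a, b)
    return tuple(max(x, y) for x, y in zip(a, b))


def block_shape(x, y=(8, 128, 128)):
    """Output shape of Block.construct; element-wise ops (RealDiv, Sub, Mul,
    Softmax, Dropout) keep shapes and are skipped."""
    t1 = transpose(x, (0, 2, 1, 3))       # (8, 16, 128, 32)
    t2 = transpose(x, (0, 2, 3, 1))       # (8, 16, 32, 128)
    scores = batch_matmul(t1, t2)         # (8, 16, 128, 128)
    mask = (y[0], 1) + y[1:]              # ExpandDims(y, 1): (8, 1, 128, 128)
    scores = broadcast(mask, scores)      # Add broadcasts: (8, 16, 128, 128)
    t3 = transpose(x, (0, 2, 1, 3))       # (8, 16, 128, 32)
    out = batch_matmul(scores, t3)        # (8, 16, 128, 32)
    return transpose(out, (0, 2, 1, 3))   # (8, 128, 16, 32)


def float32_mib(shape):
    return prod(shape) * 4 / 2**20


shape = (8, 128, 16, 32)
for _ in range(10):
    shape = block_shape(shape)
print(shape)                              # (8, 128, 16, 32)
print(float32_mib((8, 16, 128, 128)))     # 8.0
```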

Executing the Network

In this step, define the network input and then call Grad to compute the derivatives. The code is as follows:

import numpy as np
from mindspore.common import Tensor

input_x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
network = Net()
grad_network = Grad(network)
output = grad_network(input_x)
print(output)

Running Script

Run the script with the following command:

GLOG_v=1 python example.py

Setting GLOG_v=1 prints INFO-level logs, which show the memory footprint of network execution:

Device MOC memory size: 32768M
MindSpore Used memory size: 30682M
MindSpore memory base address: 0x12c100000000
Total Static Memory size: 1032M
Total Dynamic memory size: 167M
Actual peak memory usage: 1199M
Dynamic memory size of this graph: 0M

The dynamic memory footprint for executing this network is 167 MB. If the environment variable MS_DEV_SAVE_GRAPHS=1 is set (export MS_DEV_SAVE_GRAPHS=1) before running the script, an xx_validate_xxx.ir file is generated in the directory where the script is executed. Opening this file, as excerpted below, shows that the result of node %38 is needed both by node %41 (a forward operator) and by node %291 (a backward operator). The memory holding the result of node %38 therefore cannot be released until node %291 has been computed (the number after % is related to the operator execution order).

The result of node %38 stays in memory for so long because backpropagation runs in the reverse order of forward propagation: the backward function corresponding to the first of the 10 Blocks in the forward pass is the last one to be executed.

%38(equiv_11_real_div1) = PrimFunc_RealDiv(%37, Tensor(shape=[], dtype=Float32, value=2.37891)) {instance name: real_div2} cnode_attrs: {checkpoint: Bool(1)} cnode_primal_attrs: {unique_id: "10058"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (), value=...>) -> (<Tensor[Float32], (8, 16, 128, 32)>)
      # Fullname with scope: (Default/net-Net/blocks-CellList/0-Block/RealDiv-op0)
...
%41(equiv_8_batch_matmul1) = PrimFunc_BatchMatMul(%38, %40, Bool(0), Bool(0)) cnode_attrs: {checkpoint: Bool(1)} cnode_primal_attrs: {unique_id: "10055"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (8, 16, 32, 128)>, <Bool, NoShape>, <Bool, NoShape>) -> (<Tensor[Float32], (8, 16, 128, 128)>)
      # Fullname with scope: (Default/net-Net/blocks-CellList/0-Block/BatchMatMul-op0)
...
%291(CNode_401) = PrimFunc_BatchMatMul(%38, %287, Bool(1), Bool(0)) cnode_attrs: {checkpoint: Bool(1)} cnode_primal_attrs: {forward_node_name: "BatchMatMul_10055", forward_unique_id: "10055"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (8, 16, 128, 128)>, <Bool, NoShape>, <Bool, NoShape>) -> (<Tensor[Float32], (8, 16, 32, 128)>)
      # Fullname with scope: (Gradients/Default/net-Net/blocks-CellList/0-Block/Grad_BatchMatMul/BatchMatMul-op38)

If recomputation is applied to the Blocks, the forward results of the first Block can be released immediately after its forward computation finishes and are recomputed only during backpropagation. This significantly shortens the time this memory is occupied and reduces the memory peak. The code for enabling recomputation is as follows:

class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.blocks = nn.CellList()
        for _ in range(10):
            b = Block()
            # Call the recompute interface on each Block to turn on recomputation
            b.recompute()
            self.blocks.append(b)

    def construct(self, x):
        out = x
        for i in range(10):
            out = self.blocks[i](out)
        return out

After enabling recomputation, run the script again:

GLOG_v=1 python example.py

As the log below shows, the dynamic memory footprint for executing this network is now reduced from 167 MB to 69 MB.

Device MOC memory size: 32768M
MindSpore Used memory size: 30680M
MindSpore memory base address: 0x12c100000000
Total Static Memory size: 1032M
Total Dynamic memory size: 69M
Actual peak memory usage: 1101M
Dynamic memory size of this graph: 0M
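To compare runs without reading the logs by eye, the INFO lines can be parsed. This is a minimal sketch assuming the log lines have the `name: <N>M` form shown above; `parse_memory_log` is a hypothetical helper, and the sample text is copied from the two logs in this tutorial.

```python
import re

BEFORE = """\
Total Static Memory size: 1032M
Total Dynamic memory size: 167M
Actual peak memory usage: 1199M
"""
AFTER = """\
Total Static Memory size: 1032M
Total Dynamic memory size: 69M
Actual peak memory usage: 1101M
"""

# Matches lines such as "Total Dynamic memory size: 167M".
LINE = re.compile(r"^(?P<key>[^:]+):\s*(?P<mb>\d+)M$", re.MULTILINE)


def parse_memory_log(text):
    """Map each 'name: <N>M' log line to its size in MB."""
    return {m["key"]: int(m["mb"]) for m in LINE.finditer(text)}


before, after = parse_memory_log(BEFORE), parse_memory_log(AFTER)
for key in before:
    saved = before[key] - after[key]
    pct = 100 * saved / before[key]
    print(f"{key}: {before[key]}M -> {after[key]}M (saved {saved}M, {pct:.1f}%)")
```

Static memory is unchanged, and the 98 MB saved in dynamic memory accounts for the entire drop in actual peak usage (1199 MB to 1101 MB).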

Opening the xx_validate_xxx.ir file again, as excerpted below, shows that the first input of the backpropagation node %429 is now node %416, a copy of the forward node %38. The memory holding the result of node %38 can therefore be released after node %41 has been computed, which improves memory reuse.

%38(equiv_93_real_div1) = PrimFunc_RealDiv(%37, Tensor(shape=[], dtype=Float32, value=2.37891)) {instance name: real_div2} cnode_attrs: {recompute_sub_graph: U64(0), recompute_id: I64(5), recompute: Bool(1), need_cse_after_recompute: Bool(1)} cnode_primal_attrs: {unique_id: "13860"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (), value=...>) -> (<Tensor[Float32], (8, 16, 128, 32)>)
      # Fullname with scope: (recompute_Default/net-Net/blocks-CellList/0-Block/RealDiv-op0)
...
%41(equiv_90_batch_matmul1) = PrimFunc_BatchMatMul(%38, %40, Bool(0), Bool(0)) {instance name: batch_matmul2} cnode_attrs: {recompute_sub_graph: U64(0), recompute_id: I64(8), recompute: Bool(1), need_cse_after_recompute: Bool(1)} cnode_primal_attrs: {unique_id: "13857"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (8, 16, 32, 128)>, <Bool, NoShape>, <Bool, NoShape>) -> (<Tensor[Float32], (8, 16, 128, 128)>)
      # Fullname with scope: (recompute_Default/net-Net/blocks-CellList/0-Block/BatchMatMul-op0)
...
%416(CNode_1062) = PrimFunc_RealDiv(%410, Tensor(shape=[], dtype=Float32, value=2.37891)) {instance name: real_div2} cnode_attrs: {recompute_sub_graph: U64(0), recompute_id: I64(5), duplicated: Bool(1), need_cse_after_recompute: Bool(1)}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (), value=...>) -> (<Tensor[Float32], (8, 16, 128, 32)>)
      # Fullname with scope: (recompute_Default/net-Net/blocks-CellList/0-Block/RealDiv-op2)
...
%429(CNode_1075) = PrimFunc_BatchMatMul(%416, %425, Bool(1), Bool(0)) {instance name: batch_matmul2} cnode_attrs: {recompute_sub_graph: U64(0), target_grad: Bool(1), checkpoint: Bool(1)} cnode_primal_attrs: {forward_node_name: "BatchMatMul_13857", forward_unique_id: "13857"}
      : (<Tensor[Float32], (8, 16, 128, 32)>, <Tensor[Float32], (8, 16, 128, 128)>, <Bool, NoShape>, <Bool, NoShape>) -> (<Tensor[Float32], (8, 16, 32, 128)>)
      # Fullname with scope: (Gradients/recompute_Default/net-Net/blocks-CellList/0-Block/Grad_BatchMatMul/BatchMatMul-op38)
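The change in lifetime can be made concrete with a rough calculation over the node numbers in the two IR excerpts. This sketch assumes that the number after % approximates the execution position (the tutorial only says it is related to the execution order) and considers only the readers named in the text; `lifetimes` is a hypothetical helper, and numbers are compared only within the same IR dump.

```python
def lifetimes(uses):
    """uses: value id -> ids of the nodes that read it.
    Returns the approximate span from definition to last use."""
    return {v: max(readers) - v for v, readers in uses.items()}


# Readers of each value as described in the text for each IR dump.
without_recompute = {38: [41, 291]}      # %38 is read by %41 and %291
with_recompute = {38: [41], 416: [429]}  # %38 by %41 only; copy %416 by %429

print(lifetimes(without_recompute))  # {38: 253}
print(lifetimes(with_recompute))     # {38: 3, 416: 13}
```

With recomputation, no single value spans from the first Block's forward pass to its backward pass: the original result dies right after the forward BatchMatMul, and the copy lives only briefly inside the backward pass.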