Recomputation
Overview
MindSpore's automatic differentiation works in reverse mode: the backward pass is derived from the forward pass. Some backward operators need the results of forward operators as inputs, so the memory holding those results cannot be reused until the corresponding backward operators have finished. This raises peak memory usage, and the effect is particularly significant for large models.
To address this problem, MindSpore provides recomputation. Instead of storing the results of forward operators, it recomputes them just before the backward operators that need them are computed, so the memory can be reused in between. This tutorial uses a small fully connected network trained on MNIST as an example to explain how to configure recomputation when training a model in MindSpore.
Related interfaces:
mindspore.nn.Cell.recompute(): Call the recompute interface of Cell. After this interface is called, all operators inside the Cell and inside its sub-Cells, except the output operator of that Cell, are recomputed when the reverse part is computed. Both PyNative mode and Graph mode are supported.
mindspore.ops.Primitive.recompute(): Call the recompute interface of Primitive. After this interface is called, the operator is recomputed when the reverse part is computed. Only Graph mode is supported.
mindspore.recompute(block, *args, **kwargs): Call the recompute function interface. After this function is called, the block is recomputed. Only PyNative mode is supported.
Basic Principle
MindSpore automatically derives the reverse graph from the forward graph, and the forward graph and the reverse graph together form the complete compute graph. Some reverse operators take the results of forward operators as inputs, so those results must stay in memory until the reverse operators have been computed, and the memory they occupy cannot be reused by other operators in the meantime. These long-lived forward results push up the peak memory footprint of the computation, especially in large-scale network models.
To reduce the memory peak, the recompute technique does not keep the results of the forward activation layers, so their memory can be reused. When the reverse part is computed, the results of the forward activation layers are computed again. MindSpore provides this recompute capability.
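The trade described above can be illustrated with a toy example in plain Python, without MindSpore. It is only a sketch under simple assumptions: a single layer y = relu(w * x) with loss = sum(y), and hypothetical helper names (relu, relu_grad, grad_w_stored, grad_w_recomputed) that are not MindSpore APIs. It shows that dropping the activation and rebuilding it before the backward step yields the same gradient at the cost of one extra forward computation.

```python
# Toy model: y = relu(w * x), loss = sum(y). All names here are illustrative.
def relu(values):
    return [v if v > 0.0 else 0.0 for v in values]

def relu_grad(dout, out):
    # The gradient passes through only where the forward output was positive.
    return [d if o > 0.0 else 0.0 for d, o in zip(dout, out)]

def forward(w, xs):
    return relu([w * x for x in xs])

def grad_w_stored(w, xs):
    """Keep the ReLU output in memory from the forward pass to the backward pass."""
    act = forward(w, xs)
    dout = [1.0] * len(xs)              # d(loss)/d(y) for loss = sum(y)
    dpre = relu_grad(dout, act)
    return sum(d * x for d, x in zip(dpre, xs))

def grad_w_recomputed(w, xs):
    """Drop the ReLU output after the forward pass and rebuild it before backward."""
    act = forward(w, xs)
    del act                             # the memory could be reused from here on
    dout = [1.0] * len(xs)
    act_again = forward(w, xs)          # recomputation: extra compute, same values
    dpre = relu_grad(dout, act_again)
    return sum(d * x for d, x in zip(dpre, xs))

xs = [1.0, -2.0, 3.0]
stored = grad_w_stored(0.5, xs)
recomputed = grad_w_recomputed(0.5, xs)
```

Both functions return the same gradient; only the lifetime of the activation differs.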
The recompute function is implemented as follows: for each forward operator that the user marks for recomputation, MindSpore copies the operator, feeds the output of the copy to the reverse operator, and deletes the edge between the original forward operator and the reverse operator. The copied operator must not be evaluated until the corresponding reverse part is computed, so control dependencies are inserted to enforce the order in which the operators are executed, as shown in the following figure:
Figure: Forward and reverse diagram before and after the recompute function is enabled
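The graph rewrite described above can be sketched on a tiny hand-built graph. This is an illustrative model only, not MindSpore's internal implementation: the Node class, the apply_recompute function, and the node names are all hypothetical, and the control dependency is modeled simply as "the copy waits for the incoming gradient".

```python
from dataclasses import dataclass, field

@dataclass(eq=False)
class Node:
    name: str
    inputs: list = field(default_factory=list)      # data edges
    depends_on: list = field(default_factory=list)  # control edges (ordering only)
    recompute: bool = False

def apply_recompute(backward_node, forward_node):
    """Duplicate a forward node marked for recompute and rewire its backward user."""
    if not forward_node.recompute:
        return None
    # 1. Copy the forward operator with the same inputs.
    dup = Node(forward_node.name + "_dup", inputs=list(forward_node.inputs))
    # 2. Control dependency: the copy may only run once the incoming gradient
    #    is available, so it is not executed early in the forward phase.
    dup.depends_on = [n for n in backward_node.inputs if n is not forward_node]
    # 3. Cut the edge forward -> backward and connect copy -> backward instead.
    backward_node.inputs = [dup if n is forward_node else n
                            for n in backward_node.inputs]
    return dup

dense = Node("Dense")
relu = Node("ReLU", inputs=[dense], recompute=True)
upstream_grad = Node("MatMulGrad")
relu_grad = Node("ReluGrad", inputs=[upstream_grad, relu])

dup = apply_recompute(relu_grad, relu)
```

After the rewrite, ReluGrad no longer references the original ReLU node, so the original output can be released as soon as its forward consumers are done.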
For user convenience, MindSpore currently provides a recompute interface not only for individual operators but also for Cell. When the user calls the Cell's recompute interface, all forward operators in the Cell are set to recompute.
Taking the GPT-3 model as an example, the policy is to set the Cell corresponding to each layer to recompute, while the output operator of each layer is set to non-recompute. The effect of recompute on the 72-layer GPT-3 network is shown in the following figure:
Figure: Comparison of GPT-3 memory usage before and after the recompute function is enabled
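A rough way to see why this per-layer policy helps is to count how many activations are alive at the worst moment of a training step. The sketch below is a deliberately simplified model, not a measurement: it assumes every activation has the same size, ignores weights, gradients and optimizer state, and uses a hypothetical value of 10 intermediate results per layer. Only the layer count of 72 comes from the text above.

```python
def peak_activations(num_layers, internal_per_layer, recompute):
    """Count activations alive at the peak of one training step.

    Each layer produces `internal_per_layer` intermediate results plus one output.
    """
    if not recompute:
        # Everything produced in the forward pass is kept for the backward pass.
        return num_layers * (internal_per_layer + 1)
    # With per-layer recompute, only the layer outputs are kept. During the
    # backward pass, the internals of one layer at a time are rebuilt.
    return num_layers + internal_per_layer

# 72 layers as in the GPT-3 example; 10 internals per layer is an assumption.
without_recompute = peak_activations(72, 10, recompute=False)
with_recompute = peak_activations(72, 10, recompute=True)
```

In this model the peak drops from a product of the two quantities to their sum, which is why the saving grows with network depth.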
Operation Practice
The following illustrates recomputation using a single Ascend or GPU machine with 8 cards as an example:
Sample Code Description
Download the complete sample code: recompute.
The directory structure is as follows:
└─ sample_code
    ├─ recompute
       ├── train.py
       └── run.sh
    ...
train.py is the script that defines the network structure and the training process. run.sh is the execution script.
Configuring a Distributed Environment
Specify the run mode, run device, and card number through the context interface. The parallel mode is data parallel, and HCCL or NCCL communication is initialized by init. Setting the environment variable MS_DEV_SAVE_GRAPHS to 2 prints out the computational graph structure for comparison. device_target is automatically set to the backend hardware device corresponding to the installed MindSpore package.
import os
import mindspore as ms
from mindspore.communication import init
os.environ['MS_DEV_SAVE_GRAPHS'] = '2'
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
init()
ms.set_seed(1)
Loading the Dataset
Here the dataset is loaded in data parallel mode by specifying the num_shards and shard_id parameters, which correspond to the number of cards and the logical serial number of the current card, respectively. The code is as follows:
import os
import mindspore as ms
import mindspore.dataset as ds
from mindspore.communication import get_rank, get_group_size

def create_dataset(batch_size):
    # DATA_PATH points to the MNIST dataset directory
    dataset_path = os.getenv("DATA_PATH")
    # Each card reads only its own shard of the dataset
    rank_id = get_rank()
    rank_size = get_group_size()
    dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
    image_transforms = [
        ds.vision.Rescale(1.0 / 255.0, 0),
        ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ds.vision.HWC2CHW()
    ]
    label_transform = ds.transforms.TypeCast(ms.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

data_set = create_dataset(32)
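The role of num_shards and shard_id can be pictured with a small plain-Python sketch. It assumes a simple strided split in which shard r takes samples r, r + num_shards, and so on. The function name shard_indices is hypothetical, and the sampler MindSpore actually uses may shuffle or assign samples differently; the sketch only shows that the shards are disjoint and together cover the dataset.

```python
def shard_indices(num_samples, num_shards, shard_id):
    """Return the sample indices one card would read under a strided split."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    return list(range(shard_id, num_samples, num_shards))

# 8 cards splitting a toy dataset of 20 samples.
shards = [shard_indices(20, 8, rank) for rank in range(8)]
covered = sorted(index for shard in shards for index in shard)
```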
Network Definition
The network is the same as the single-card model, except that recomputation is configured for the activation function operators to reduce the memory footprint:
from mindspore import nn, ops

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layer1 = nn.Dense(28*28, 512)
        self.relu1 = ops.ReLU()
        self.layer2 = nn.Dense(512, 512)
        self.relu2 = ops.ReLU()
        self.layer3 = nn.Dense(512, 10)

    def construct(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.relu1(x)
        x = self.layer2(x)
        x = self.relu2(x)
        logits = self.layer3(x)
        return logits

net = Network()
# Configure the recompute of relu operator
net.relu1.recompute()
net.relu2.recompute()
Training the Network
In this step, define the loss function, the optimizer, and the training process. This part is the same as for the data-parallel model:
from mindspore import nn
import mindspore as ms

optimizer = nn.SGD(net.trainable_params(), 1e-2)
loss_fn = nn.CrossEntropyLoss()

def forward_fn(data, target):
    logits = net(data)
    loss = loss_fn(logits, target)
    return loss, logits

grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
grad_reducer = nn.DistributedGradReducer(optimizer.parameters)

for epoch in range(1):
    i = 0
    for image, label in data_set:
        (loss_value, _), grads = grad_fn(image, label)
        grads = grad_reducer(grads)
        optimizer(grads)
        if i % 10 == 0:
            print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_value))
        i += 1
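Conceptually, the gradient reduction step in a data-parallel loop with gradients_mean=True averages the gradients computed on each card before the optimizer update. The plain-Python sketch below illustrates that idea only. The names allreduce_mean and sgd_step are hypothetical, the values are made up, and the real reduction is performed by the HCCL or NCCL communication library rather than in Python.

```python
def allreduce_mean(per_rank_grads):
    """Average gradients across ranks; per_rank_grads[r][k] is gradient k on rank r."""
    num_ranks = len(per_rank_grads)
    return [sum(values) / num_ranks for values in zip(*per_rank_grads)]

def sgd_step(params, grads, lr):
    """Plain SGD update applied identically on every rank."""
    return [p - lr * g for p, g in zip(params, grads)]

# Two ranks, two parameters each (illustrative numbers).
mean_grads = allreduce_mean([[1.0, 2.0], [3.0, 6.0]])
new_params = sgd_step([0.5, 0.5], mean_grads, lr=0.5)
```

Recomputation changes how each card obtains its local gradients, not how they are combined, so this step is unaffected by it.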
Running Stand-alone 8-card Script
Next, launch the script from the command line. The following uses the mpirun startup method to run the 8-card distributed training script:
bash run.sh
After training, the log files are saved to the log_output directory. Because the environment variable MS_DEV_SAVE_GRAPHS is set to 2, the IR graphs of the compilation process are also printed out. Part of the resulting directory structure is as follows:
├─ log_output
|   └─ 1
|       ├─ rank.0
|       |   └─ stdout
|       ├─ rank.1
|       |   └─ stdout
|       ...
├─ rank_0
|   ├─ xx_validate_xxx.ir
|   ...
├─ rank_1
|   ├─ xx_validate_xxx.ir
|   ...
...
The loss results are saved in log_output/1/rank.*/stdout, for example:
epoch: 0, step: 0, loss is 2.2929618
epoch: 0, step: 10, loss is 2.2396836
epoch: 0, step: 20, loss is 2.2097976
epoch: 0, step: 30, loss is 2.1942225
epoch: 0, step: 40, loss is 2.0986974
epoch: 0, step: 50, loss is 2.0612597
...
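To compare a run with recomputation against a run without it, the loss lines can be extracted from the stdout files. The following is a minimal sketch that assumes the log format shown above. The function name parse_losses is hypothetical, and the sample text is copied from the example output rather than read from a file.

```python
import re

# Matches lines such as "epoch: 0, step: 10, loss is 2.2396836".
LOSS_LINE = re.compile(r"epoch: (\d+), step: (\d+), loss is ([0-9.eE+-]+)")

def parse_losses(text):
    """Return a list of (epoch, step, loss) tuples found in the log text."""
    records = []
    for line in text.splitlines():
        match = LOSS_LINE.search(line)
        if match:
            epoch, step, loss = match.groups()
            records.append((int(epoch), int(step), float(loss)))
    return records

sample_log = """epoch: 0, step: 0, loss is 2.2929618
epoch: 0, step: 10, loss is 2.2396836
..."""
records = parse_losses(sample_log)
```

Since recomputation repeats the same forward operators, the two runs would be expected to produce closely matching loss values when the seed and data order are the same.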
Before recomputation is configured, the computational graph in xx_validate_xxx.ir is as follows:
...
%81(1285) = MatMul(%80, %11) primitive_attrs: {output_names: (output), transpose_a: Bool(0), input_names: (x1, x2), transpose_x2: Bool(1), transpose_x1: Bool(0), transpose_b: Bool(1)} cnode_primal_attrs: {forward_node_name: "MatMul_24422", forward_unique_id: "24422"}
: (<Tensor[Float32], (32, 10)>, <Tensor[Float32], (512, 10)>) -> (<Tensor[Float32], (32, 512)>)
...
%82(1286) = ReluGrad(%81, %10) primitive_attrs: {output_names: (output), input_names: (x)} cnode_primal_attrs: {forward_node_name: "ReLU_24405", forward_unique_id: "24405"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%83(1285) = MatMul(%82, %6) primitive_attrs: {output_names: (output), transpose_a: Bool(0), input_names: (x1, x2), transpose_x2: Bool(1), transpose_x1: Bool(0), transpose_b: Bool(1)} cnode_primal_attrs: {forward_node_name: "MatMul_24434", forward_unique_id: "24434"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (512, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%84(1286) = ReluGrad(%83, %5) primitive_attrs: {output_names: (output), input_names: (x)} cnode_primal_attrs: {forward_node_name: "ReLU_24408", forward_unique_id: "24408"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%85(1285) = MatMul(%0, %84) primitive_attrs: {output_names: (output), transpose_a: Bool(1), input_names: (x1, x2), transpose_x2: Bool(0), transpose_x1: Bool(1), transpose_b: Bool(0)} cnode_primal_attrs: {forward_node_name: "MatMul_24446", forward_unique_id: "24446"}
: (<Tensor[Float32], (32, 784)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (784, 512)>)
...
After recomputation is configured:
...
%81(1285) = MatMul(%80, %11) primitive_attrs: {output_names: (output), transpose_a: Bool(0), input_names: (x1, x2), transpose_x2: Bool(1), transpose_x1: Bool(0), transpose_b: Bool(1)} cnode_primal_attrs: {forward_node_name: "MatMul_24422", forward_unique_id: "24422"}
: (<Tensor[Float32], (32, 10)>, <Tensor[Float32], (512, 10)>) -> (<Tensor[Float32], (32, 512)>)
...
%84([CNode]1292) = ReLU(%83) {instance name: relu2} primitive_attrs: {output_names: [output], input_names: [x], recompute: Bool(1)} cnode_attrs: {recompute_sub_graph: U64(1), recompute_id: I64(2), duplicated: Bool(1), need_cse_after_recompute: Bool(1)}
: (<Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
# Scope: (Default)
%85([CNode]1293) = ReluGrad(%81, %84) primitive_attrs: {output_names: (output), input_names: (x)} cnode_attrs: {recompute_sub_graph: U64(1), target_grad: Bool(1)} cnode_primal_attrs: {forward_node_name: "ReLU_24405", forward_unique_id: "24405"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%86(1285) = MatMul(%85, %6) primitive_attrs: {output_names: (output), transpose_a: Bool(0), input_names: (x1, x2), transpose_x2: Bool(1), transpose_x1: Bool(0), transpose_b: Bool(1)} cnode_primal_attrs: {forward_node_name: "MatMul_24434", forward_unique_id: "24434"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (512, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%89([CNode]1296) = ReLU(%88) {instance name: relu2} primitive_attrs: {output_names: [output], input_names: [x], recompute: Bool(1)} cnode_attrs: {recompute_sub_graph: U64(0), recompute_id: I64(1), duplicated: Bool(1), need_cse_after_recompute: Bool(1)}
: (<Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
# Scope: (Default)
%90([CNode]1297) = ReluGrad(%86, %89) primitive_attrs: {output_names: (output), input_names: (x)} cnode_attrs: {recompute_sub_graph: U64(0), target_grad: Bool(1)} cnode_primal_attrs: {forward_node_name: "ReLU_24408", forward_unique_id: "24408"}
: (<Tensor[Float32], (32, 512)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (32, 512)>)
...
%91(1285) = MatMul(%0, %90) primitive_attrs: {output_names: (output), transpose_a: Bool(1), input_names: (x1, x2), transpose_x2: Bool(0), transpose_x1: Bool(1), transpose_b: Bool(0)} cnode_primal_attrs: {forward_node_name: "MatMul_24446", forward_unique_id: "24446"}
: (<Tensor[Float32], (32, 784)>, <Tensor[Float32], (32, 512)>) -> (<Tensor[Float32], (784, 512)>)
...
It can be seen that each ReLU operator is duplicated, and the copy serves as an input to the reverse operator ReluGrad.
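One way to check this without reading the whole IR file is to scan it for nodes carrying the duplicated: Bool(1) attribute seen in the dump above. The sketch below assumes the line layout of that dump. The function name find_duplicated_nodes is hypothetical, and the sample lines are abbreviated versions of the lines shown above, not a complete IR file.

```python
import re

# Captures the node id and operator name, e.g. "%84" and "ReLU".
NODE_LINE = re.compile(r"^(%\d+)\(.*?\) = (\w+)\(")

def find_duplicated_nodes(ir_text):
    """Return (node_id, op_name) for every node marked as a recompute copy."""
    found = []
    for line in ir_text.splitlines():
        if "duplicated: Bool(1)" not in line:
            continue
        match = NODE_LINE.match(line.strip())
        if match:
            found.append((match.group(1), match.group(2)))
    return found

# Abbreviated sample lines based on the IR dump above.
sample_ir = """%81(1285) = MatMul(%80, %11) primitive_attrs: {transpose_b: Bool(1)}
%84([CNode]1292) = ReLU(%83) {instance name: relu2} primitive_attrs: {recompute: Bool(1)} cnode_attrs: {recompute_id: I64(2), duplicated: Bool(1)}
%85([CNode]1293) = ReluGrad(%81, %84) cnode_attrs: {target_grad: Bool(1)}"""
duplicated = find_duplicated_nodes(sample_ir)
```

An empty result on a real xx_validate_xxx.ir file would suggest that no operator was duplicated for recomputation in that graph.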