Operator-level Parallelism
Overview
With the development of deep learning, network models are becoming larger and larger; in the field of NLP, for example, models with trillions of parameters have emerged, and their capacity far exceeds the memory of a single device, making training on a single card or with data parallelism alone impossible. Operator-level parallelism slices the tensors involved in each operator of the network model and distributes the operator across multiple devices, reducing the memory consumption of a single device and thus enabling the training of large models.
MindSpore provides two levels of granularity: operator-level parallelism and higher-order operator-level parallelism. Operator-level parallelism describes how tensor dimensions are distributed through a simple slicing strategy, which meets the requirements of most scenarios. Higher-order operator-level parallelism supports complex slicing scenarios through an open device-arrangement description.
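As a preview of the difference between the two granularities, the snippet below (a minimal sketch, assuming MindSpore is installed and that Layout can be imported from the top-level mindspore package) configures the same MatMul first with a plain tuple strategy and then with a Layout-based strategy; both forms are explained in detail in the practices that follow.
from mindspore import ops
from mindspore import Layout  # assumed import path for Layout

# Operator-level parallelism: one slicing number per tensor dimension.
# The first input is sliced 2-way on rows and 4-way on columns,
# the second input 4-way on rows, using 2*4 = 8 devices in total.
matmul_basic = ops.MatMul().shard(((2, 4), (4, 1)))

# Higher-order operator-level parallelism: slicing described by a device layout.
layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
matmul_advanced = ops.MatMul().shard((layout("mp", ("sp", "dp")), layout(("sp", "dp"), "None")))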
Operator-Level Parallel Practice
The following illustrates operator-level parallelism, using a single machine with 8 Ascend devices as an example:
Sample Code Description
Download the complete sample code here: distributed_operator_parallel.
The directory structure is as follows:
└─ sample_code
├─ distributed_operator_parallel
├── distributed_operator_parallel.py
├── run.sh
└── ...
...
Here, distributed_operator_parallel.py is the script that defines the network structure and the training process, and run.sh is the execution script.
Configuring the Distributed Environment
Unlike the single-card script, the parallel script also needs to initialize the communication domain through the init interface. In addition, limiting the maximum device memory available to the model via the max_size argument of the set_memory interface leaves enough device memory for communication on the Ascend hardware platform.
import mindspore as ms
from mindspore.communication import init
ms.set_context(mode=ms.GRAPH_MODE)
ms.runtime.set_memory(max_size="28GB")
init()
ms.set_seed(1)
Loading the Dataset
In the operator-level parallel scenario, the dataset is loaded in the same way as on a single card, with the following code:
import os
import mindspore as ms
import mindspore.dataset as ds
def create_dataset(batch_size):
dataset_path = os.getenv("DATA_PATH")
dataset = ds.MnistDataset(dataset_path)
image_transforms = [
ds.vision.Rescale(1.0 / 255.0, 0),
ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
ds.vision.HWC2CHW()
]
label_transform = ds.transforms.TypeCast(ms.int32)
dataset = dataset.map(image_transforms, 'image')
dataset = dataset.map(label_transform, 'label')
dataset = dataset.batch(batch_size)
return dataset
data_set = create_dataset(32)
Defining the Network
In the current semi-automatic parallel mode, the network needs to be defined with ops operators (Primitive). Starting from the single-card network, users can manually configure the slicing strategy of individual operators; for example, the network structure after configuring the strategies is:
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import initializer
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = ops.Flatten()
self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
self.matmul1 = ops.MatMul().shard(((2, 4), (4, 1)))
self.relu1 = ops.ReLU().shard(((4, 1),))
self.matmul2 = ops.MatMul().shard(((1, 8), (8, 1)))
self.relu2 = ops.ReLU().shard(((8, 1),))
self.matmul3 = ops.MatMul()
def construct(self, x):
x = self.flatten(x)
x = self.matmul1(x, self.fc1_weight)
x = self.relu1(x)
x = self.matmul2(x, self.fc2_weight)
x = self.relu2(x)
logits = self.matmul3(x, self.fc3_weight)
return logits
The ops.MatMul() and ops.ReLU() operators of the above network are configured with slicing strategies. Take ops.MatMul().shard(((2, 4), (4, 1))) as an example: the rows of its first input are sliced into 2 parts and its columns into 4 parts, while the rows of its second input are sliced into 4 parts and its columns are not sliced. For ops.ReLU().shard(((8, 1),)), the rows of its first input are sliced into 8 parts. Note that because the two ops.ReLU() operators here use different slicing strategies, i.e., ops.ReLU().shard(((4, 1),)) and ops.ReLU().shard(((8, 1),)), they have to be defined separately.
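To make these strategies concrete, the slice shape held by each device can be worked out from the tensor shapes used in this network (batch size 32, flattened input 28*28 = 784, hidden size 512). The following is a small illustrative calculation in plain Python; it is not part of the sample code.
# Illustrative only: per-device slice shapes implied by a tuple slicing strategy.
def slice_shape(shape, strategy):
    """Divide each tensor dimension by the corresponding number of slices."""
    assert all(dim % cuts == 0 for dim, cuts in zip(shape, strategy))
    return tuple(dim // cuts for dim, cuts in zip(shape, strategy))

# matmul1 with strategy ((2, 4), (4, 1)): x is (32, 784), fc1_weight is (784, 512)
print(slice_shape((32, 784), (2, 4)))    # (16, 196) -- each device holds 1/8 of x
print(slice_shape((784, 512), (4, 1)))   # (196, 512) -- each device holds 1/4 of fc1_weight
# relu2 with strategy ((8, 1),): its (32, 512) input is sliced into 8 row blocks
print(slice_shape((32, 512), (8, 1)))    # (4, 512)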
Training Network Definition
In this step, we need to define the loss function, the optimizer, and the training process. Note that, due to the huge number of parameters of a large model, device memory would be far from sufficient if the parameters were initialized on a single card when defining the network. Therefore, the network should be defined under the no_init_parameters interface, which delays parameter initialization until the parallel multi-card phase. Both the network and the optimizer need to be defined with delayed initialization here.
from mindspore.nn.utils import no_init_parameters
with no_init_parameters():
net = Network()
optimizer = nn.SGD(net.trainable_params(), 1e-2)
loss_fn = nn.CrossEntropyLoss()
def forward_fn(data, target):
logits = net(data)
loss = loss_fn(logits, target)
return loss, logits
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
def train_step(inputs, targets):
(loss_value, _), grads = grad_fn(inputs, targets)
optimizer(grads)
return loss_value
Parallel Configuration
We further need to set up the parallel configuration by wrapping the training step with AutoParallel and specifying parallel_mode as "semi_auto", i.e., semi-automatic parallel mode.
from mindspore.parallel.auto_parallel import AutoParallel
parallel_net = AutoParallel(train_step, parallel_mode="semi_auto")
Training Loop
This step runs the training loop: the outer loop iterates over the training epochs, and the inner loop traverses the dataset, calling parallel_net to train and obtain the loss value.
for epoch in range(10):
i = 0
for image, label in data_set:
loss_output = parallel_net(image, label)
if i % 10 == 0:
print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_output))
i += 1
Running the Single-machine Eight-card Script
Next, the corresponding script is invoked from the command line. The following uses the msrun startup method and the 8-card distributed training script as an example of distributed training:
bash run.sh
After training, the log files are saved to the log_output
directory, where part of the file directory structure is as follows:
└─ log_output
├─ scheduler.log
├─ worker_0.log
├─ worker_1.log
...
The loss results are saved in log_output/worker_*.log, for example:
epoch: 0 step: 0, loss is 2.3016002
epoch: 0 step: 10, loss is 2.2889402
epoch: 0 step: 20, loss is 2.2843816
epoch: 0 step: 30, loss is 2.248126
epoch: 0 step: 40, loss is 2.1581488
epoch: 0 step: 50, loss is 1.8051043
...
Other startup methods such as mpirun
and rank table
startup can be found in startup methods.
Higher-Order Operator-Level Parallel Practice
The following illustrates higher-order operator-level parallelism, again using a single machine with 8 Ascend devices as an example:
Sample Code Description
Download the complete sample code here: distributed_operator_parallel.
The directory structure is as follows:
└─ sample_code
├─ distributed_operator_parallel
├── advanced_distributed_operator_parallel.py
├── run_advanced.sh
└── ...
...
Here, advanced_distributed_operator_parallel.py is the script that defines the network structure and the training process, and run_advanced.sh is the execution script.
Environment Configuration
Before applying higher-order operator-level parallelism, the environment is configured first; the process is the same as for operator-level parallelism and is described in Configuring the Distributed Environment and Loading the Dataset.
Defining the Network
Higher-order operator-level parallelism extends the functionality of the shard interface: its in_strategy and out_strategy arguments additionally accept values of type tuple(Layout).
A Layout is initialized with a device matrix, and each axis of the device matrix is given an alias. For example, layout = Layout((2, 2, 2), ("dp", "sp", "mp")) describes a total of 8 devices arranged in the shape (2, 2, 2), with the axes aliased "dp", "sp", and "mp".
Calling the layout with these axis names then assigns, for each dimension of a tensor, the device axis it is mapped to, which also determines the number of slices: here "dp" means 2 slices across the 2 devices of the highest axis of the device layout, "sp" means 2 slices across the 2 devices of the middle axis, and "mp" means 2 slices across the 2 devices of the lowest axis. In particular, one dimension of the tensor may be mapped to several device axes to express multiple slicings of that dimension.
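To make the device matrix concrete, the correspondence between a linear device rank and its (dp, sp, mp) coordinates for the (2, 2, 2) arrangement can be enumerated in plain Python. This is an illustrative sketch assuming the usual row-major ordering, with "dp" as the highest (slowest-varying) axis.
# Illustrative only: enumerate (dp, sp, mp) coordinates for a (2, 2, 2) device matrix,
# with "dp" as the highest (slowest-varying) axis and "mp" as the lowest (fastest-varying).
import itertools

device_matrix = (2, 2, 2)   # aliased as ("dp", "sp", "mp")
for rank, (dp, sp, mp) in enumerate(itertools.product(*(range(n) for n in device_matrix))):
    print(f"rank {rank} -> (dp, sp, mp) = ({dp}, {sp}, {mp})")
# rank 0 -> (0, 0, 0), rank 1 -> (0, 0, 1), ..., rank 7 -> (1, 1, 1)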
import mindspore as ms
from mindspore import nn, ops
from mindspore import Layout  # assumed import path; Layout may also be exposed under mindspore.parallel
from mindspore.common.initializer import initializer
class Network(nn.Cell):
"""Network"""
def __init__(self):
super().__init__()
self.flatten = ops.Flatten()
self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
layout2 = Layout((8,), ("tp",))
self.matmul1 = ops.MatMul().shard((layout("mp", ("sp", "dp")), layout(("sp", "dp"), "None")))
self.relu1 = ops.ReLU().shard(((4, 1),))
self.matmul2 = ops.MatMul().shard((layout2("None", "tp"), layout2("tp", "None")))
self.relu2 = ops.ReLU().shard(((8, 1),))
self.matmul3 = ops.MatMul()
def construct(self, x):
x = self.flatten(x)
x = self.matmul1(x, self.fc1_weight)
x = self.relu1(x)
x = self.matmul2(x, self.fc2_weight)
x = self.relu2(x)
logits = self.matmul3(x, self.fc3_weight)
return logits
In the network defined above, self.matmul1 = ops.MatMul().shard((layout("mp", ("sp", "dp")), layout(("sp", "dp"), "None"))) slices the input tensor x with layout("mp", ("sp", "dp")), i.e., its first dimension is sliced into 2 parts by "mp", and its second dimension combines "sp" and "dp" and is sliced into 2*2 = 4 parts. The weight self.fc1_weight is sliced with layout(("sp", "dp"), "None"), i.e., its first dimension merges "sp" and "dp" and is sliced into 4 parts, while its second dimension is not sliced.
Similarly, in self.matmul2 = ops.MatMul().shard((layout2("None", "tp"), layout2("tp", "None"))), the rows of the input tensor are not sliced and its columns are sliced into 8 parts by "tp", while the rows of the weight self.fc2_weight are sliced into 8 parts by "tp" and its columns are not sliced.
Taking self.matmul1 as an example, with x of shape (32, 784) after flattening and fc1_weight of shape (784, 512), and treating the first axis listed in the merged group ("sp", "dp") as the major one, the slicing produces the following mapping of devices to data slices (a short sketch after the table shows how these ranges can be derived):
| device coordinates (dp, sp, mp) | input x slice | weight fc1_weight slice |
|---|---|---|
| (0, 0, 0) | x[0:16, 0:196] | fc1_weight[0:196, 0:512] |
| (0, 0, 1) | x[16:32, 0:196] | fc1_weight[0:196, 0:512] |
| (0, 1, 0) | x[0:16, 392:588] | fc1_weight[392:588, 0:512] |
| (0, 1, 1) | x[16:32, 392:588] | fc1_weight[392:588, 0:512] |
| (1, 0, 0) | x[0:16, 196:392] | fc1_weight[196:392, 0:512] |
| (1, 0, 1) | x[16:32, 196:392] | fc1_weight[196:392, 0:512] |
| (1, 1, 0) | x[0:16, 588:784] | fc1_weight[588:784, 0:512] |
| (1, 1, 1) | x[16:32, 588:784] | fc1_weight[588:784, 0:512] |
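The slice ranges in the table above can be reproduced with a short calculation. The sketch below is illustrative only; it assumes that when a tensor dimension is mapped to a merged axis group such as ("sp", "dp"), the first-listed axis is the major (slowest-varying) one.
# Illustrative only: derive the x and fc1_weight slices held by each device under
# layout("mp", ("sp", "dp")) for x and layout(("sp", "dp"), "None") for fc1_weight.
# Assumption: in the merged group ("sp", "dp"), "sp" is the major (slowest-varying) index.
import itertools

ROWS, COLS, OUT = 32, 28 * 28, 512        # x: (32, 784), fc1_weight: (784, 512)

def block(index, num_blocks, total):
    size = total // num_blocks
    return index * size, (index + 1) * size

for dp, sp, mp in itertools.product(range(2), range(2), range(2)):
    merged = sp * 2 + dp                  # index on the merged ("sp", "dp") axis
    r0, r1 = block(mp, 2, ROWS)           # rows of x are sliced by "mp"
    c0, c1 = block(merged, 4, COLS)       # columns of x / rows of fc1_weight by ("sp", "dp")
    print(f"(dp, sp, mp)=({dp}, {sp}, {mp}): "
          f"x[{r0}:{r1}, {c0}:{c1}], fc1_weight[{c0}:{c1}, 0:{OUT}]")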
Training Process
The training flow of higher-order operator-level parallelism is identical to that of operator-level parallelism; see Training Network Definition, Parallel Configuration, and Training Loop.
Running the Single-machine Eight-card Script
Next, the corresponding script is invoked from the command line. The following uses the msrun startup method and the 8-card distributed training script as an example of distributed training:
bash run_advanced.sh
After training, the log files are saved to the advanced_log_output
directory, where part of the file directory structure is as follows:
└─ advanced_log_output
├─ scheduler.log
├─ worker_0.log
├─ worker_1.log
...
The loss results are saved in advanced_log_output/worker_*.log, for example:
epoch: 0 step: 0, loss is 2.3016002
epoch: 0 step: 10, loss is 2.2889402
epoch: 0 step: 20, loss is 2.2843816
epoch: 0 step: 30, loss is 2.248126
epoch: 0 step: 40, loss is 2.1581488
epoch: 0 step: 50, loss is 1.8051043
...
Other startup methods such as mpirun
and rank table
startup can be found in startup methods.