数据集切分
概述
在进行分布式训练时,以图片数据为例,当单张图片的大小过大时,如遥感卫星等大幅面图片,当单张图片过大时,需要对图片进行切分,每张卡读取一部分图片,进行分布式训练。处理数据集切分的场景,需要配合模型并行一起才能达到预期的降低显存的效果,因此,基于自动并行提供了该项功能。本教程使用的样例不是大幅面的网络,仅作示例。真实应用到大幅面的网络时,往往需要详细设计并行策略。
数据集切分仅支持全自动模式和半自动模式,在数据并行模式下不涉及。
相关接口:
mindspore.dataset.vision.SlicePatches(num_height=1, num_width=1)
:在水平和垂直方向上将Tensor切片为多个块。适合于Tensor高宽较大的使用场景。其中num_height
为垂直方向的切块数量,num_width
为水平方向的切块数量。更多参数可以参考SlicePatches。mindspore.set_auto_parallel_context(dataset_strategy=((1, 1, 1, 8), (8,)))
:表示数据集分片策略。dataset_strategy
接口有以下几点限制:每个输入至多允许在一维进行切分。如支持
set_auto_parallel_context(dataset_strategy=((1, 1, 1, 8), (8,)))
或者dataset_strategy=((1, 1, 1, 8), (1,))
,每个输入至多切分了一维;但是不支持dataset_strategy=((1, 1, 4, 2), (1,))
,其第一个输入切分了两维。维度最高的一个输入,切分的数目,一定要比其他维度的多。如支持
dataset_strategy=((1, 1, 1, 8), (8,))
或者dataset_strategy=((1, 1, 1, 1), (1,))
,其维度最多的输入为第一个输入,切分份数为8,其余输入切分均不超过8;但是不支持dataset_strategy=((1, 1, 1, 1), (8,))
,其维度最多的输入为第一维,切分份数为1,但是其第二个输入切分份数却为8,超过了第一个输入的切分份数。
操作实践
样例代码说明
下载完整的样例代码:dataset_slice。
目录结构如下:
└─ sample_code
├─ dataset_slice
├── train.py
└── run.sh
...
其中,train.py
是定义网络结构和训练过程的脚本。run.sh
是执行脚本。
配置分布式环境
首先通过context接口指定运行模式、运行设备、运行卡号等,并行模式为半自动并行模式,并通过init初始化HCCL或NCCL通信。此外,配置数据集dataset_strategy
切分策略为((1, 1, 1, 4), (1,)),代表垂直方向不切分,水平方向切分4块。
import mindspore as ms
from mindspore.communication import init
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
init()
slice_h_num = 1
slice_w_num = 4
ms.set_auto_parallel_context(dataset_strategy=((1, 1, slice_h_num, slice_w_num), (1,)))
数据集加载
使用数据集切分时,需要同时调用数据集的SlicePatches
接口去构造数据集,并且,为了保证各卡读入数据一致,需要对数据集固定随机数种子。
import os
import mindspore.dataset as ds
from mindspore import nn
ds.config.set_seed(1000) # set dataset seed to make sure that all cards read the same data
def create_dataset(batch_size):
dataset_path = os.getenv("DATA_PATH")
dataset = ds.MnistDataset(dataset_path)
image_transforms = [
ds.vision.Rescale(1.0 / 255.0, 0),
ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
ds.vision.HWC2CHW()
]
label_transform = ds.transforms.TypeCast(ms.int32)
dataset = dataset.map(image_transforms, 'image')
dataset = dataset.map(label_transform, 'label')
# slice image
slice_patchs_img_op = ds.vision.SlicePatches(slice_h_num, slice_w_num)
img_cols = ['img' + str(x) for x in range(slice_h_num * slice_w_num)]
dataset = dataset.map(operations=slice_patchs_img_op, input_columns="image", output_columns=img_cols)
dataset = dataset.project([img_cols[get_rank() % (slice_h_num * slice_w_num)], "label"])
dataset = dataset.batch(batch_size)
return dataset
data_set = create_dataset(32)
网络定义
此处网络定义与单卡模型一致:
from mindspore import nn
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.layer1 = nn.Dense(28*28, 512)
self.relu1 = nn.ReLU()
self.layer2 = nn.Dense(512, 512)
self.relu2 = nn.ReLU()
self.layer3 = nn.Dense(512, 10)
def construct(self, x):
x = self.flatten(x)
x = self.layer1(x)
x = self.relu1(x)
x = self.layer2(x)
x = self.relu2(x)
logits = self.layer3(x)
return logits
net = Network()
训练网络
在这一步,需要定义损失函数、优化器以及训练过程,本例采用函数式方式编写,这部分与单卡模型一致:
from mindspore import nn
import mindspore as ms
optimizer = nn.SGD(net.trainable_params(), 1e-2)
loss_fn = nn.CrossEntropyLoss()
def forward_fn(data, target):
logits = net(data)
loss = loss_fn(logits, target)
return loss, logits
grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
for epoch in range(1):
i = 0
for image, label in data_set:
(loss_value, _), grads = grad_fn(image, label)
optimizer(grads)
if i % 10 == 0:
print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_value))
i += 1
运行单机8卡脚本
接下来通过命令调用对应的脚本,以mpirun
启动方式,8卡的分布式训练脚本为例,进行分布式训练:
bash run.sh
训练完后,日志文件保存到log_output
目录下,关于Loss部分结果保存在log_output/1/rank.*/stdout
中,示例如下:
epoch: 0, step: 0, loss is 2.290361
epoch: 0, step: 10, loss is 1.9397954
epoch: 0, step: 20, loss is 1.4175975
epoch: 0, step: 30, loss is 1.0338318
epoch: 0, step: 40, loss is 0.64065826
epoch: 0, step: 50, loss is 0.8479157
...