msrun启动

查看源文件

概述

msrun动态组网启动方式的封装,用户可使用msrun以单个命令行指令的方式在各节点拉起多进程分布式任务,并且无需手动设置动态组网环境变量msrun同时支持AscendGPUCPU后端。与动态组网启动方式一样,msrun无需依赖第三方库以及配置文件。

  • msrun在用户安装MindSpore后即可使用,可使用指令msrun --help查看支持参数。

  • msrun支持图模式以及PyNative模式

命令行参数列表:

参数 功能 类型 取值 说明
--worker_num 参与分布式任务的Worker进程总数。 Integer 大于0的整数。默认值为8。 每个节点上启动的Worker总数应当等于此参数:
若总数大于此参数,多余的Worker进程会注册失败;
若总数小于此参数,集群会在等待一段超时时间后,
提示任务拉起失败并退出,
超时时间窗大小可通过参数cluster_time_out配置。
--local_worker_num 当前节点上拉起的Worker进程数。 Integer 大于0的整数。默认值为8。 当此参数与worker_num保持一致时,代表所有Worker进程在本地执行,
此场景下node_rank值会被忽略。
--master_addr 指定Scheduler的IP地址。 String 合法的IP地址。默认为127.0.0.1。 msrun会自动检测在哪个节点拉起Scheduler进程,用户无需关心。
若无法查找到对应的地址,训练任务会拉起失败。
当前版本暂不支持IPv6地址。
当前版本msrun使用ip -j addr指令查询当前节点地址,
需要用户环境支持此指令。
--master_port 指定Scheduler绑定端口号。 Integer 1024~65535范围内的端口号。默认为8118。
--node_rank 当前节点的索引。 Integer 大于0的整数。默认值为-1。 单机多卡场景下,此参数会被忽略。
多机多卡场景下,
若不设置此参数,Worker进程的rank_id会被自动分配;
若设置,则会按照索引为各节点上的Worker进程分配rank_id。
若每个节点Worker进程数量不同,建议不配置此参数,
以自动分配rank_id。
--log_dir Worker以及Scheduler日志输出路径。 String 文件夹路径。默认为当前目录。 若路径不存在,msrun会递归创建文件夹。
日志格式如下:对于Scheduler进程,日志名为scheduler.log
对于Worker进程,日志名为worker_[rank].log
其中rank后缀与分配给Worker的rank_id一致,
但在未设置node_rank且多机多卡场景下,它们可能不一致。
建议执行grep -rn "Global rank id"指令查看各Worker的rank_id
--join msrun是否等待Worker以及Scheduler退出。 Bool True或者False。默认为False。 若设置为False,msrun在拉起进程后会立刻退出,
查看日志确认分布式任务是否正常执行。
若设置为True,msrun会等待所有进程退出后,收集异常日志并退出。
--cluster_time_out 集群组网超时时间,单位为秒。 Integer 默认为600秒。 此参数代表在集群组网的等待时间。
若超出此时间窗口依然没有worker_num数量的Worker注册成功,则任务拉起失败。
--bind_core 开启进程绑核。 Bool True或者False。默认为False。 若用户配置此参数,msrun会平均分配CPU核,将其绑定到拉起的分布式进程上。
--sim_level 设置单卡模拟编译等级。 Integer 默认为-1,即关闭单卡模拟编译功能。 若用户配置此参数,msrun只会拉起单进程模拟编译,不做算子执行。此功能通常用于调试大规模分布式训练并行策略,在编译阶段提前发现内存和策略问题。
若设置为0,只做前端图编译;若设置为1,进一步执行后端图编译,在执行图阶段退出。
--sim_rank_id 单卡模拟编译的rank_id。 Integer 默认为0。 设置单卡模拟编译进程的rank_id。
--rank_table_file rank_table配置文件,只在昇腾平台下有效。 String rank_table配置文件路径,默认为空。 此参数代表昇腾平台下的rank_table配置文件,描述当前分布式集群。
task_script 用户Python脚本。 String 合法的脚本路径。 一般情况下,此参数为python脚本路径,msrun会默认以python task_script task_script_args方式拉起进程。
msrun还支持此参数为pytest,此场景下任务脚本及任务参数在参数task_script_args传递。
task_script_args 用户Python脚本的参数。 参数列表。 例如:msrun --worker_num=8 --local_worker_num=8 train.py --device_target=Ascend --dataset_path=/path/to/dataset

环境变量

下表是用户脚本中能够使用的环境变量,它们由msrun设置:

环境变量 功能 取值
MS_ROLE 本进程角色。 当前版本msrun导出下面两个值:
  • MS_SCHED:代表Scheduler进程。
  • MS_WORKER:代表Worker进程。
MS_SCHED_HOST 用户指定的Scheduler的IP地址。 与参数--master_addr相同。
MS_SCHED_PORT 用户指定的Scheduler绑定端口号。 与参数--master_port相同。
MS_WORKER_NUM 用户指定的Worker进程总数。 与参数--worker_num相同。
MS_TOPO_TIMEOUT 集群组网超时时间。 与参数--cluster_time_out相同。
RANK_SIZE 用户指定的Worker进程总数。 与参数--worker_num相同。
RANK_ID 为Worker进程分配的rank_id。 多机多卡场景下,若没有设置--node_rank参数,RANK_ID只会在集群初始化后被导出。
因此要使用此环境变量,建议正确设置--node_rank参数。

msrun作为动态组网启动方式的封装,所有用户可自定义配置的环境变量可参考动态组网环境变量

操作实践

启动脚本在各硬件平台下一致,下面以Ascend为例演示如何编写启动脚本:

您可以在这里下载完整的样例代码:startup_method

目录结构如下:

└─ sample_code
    ├─ startup_method
       ├── msrun_1.sh
       ├── msrun_2.sh
       ├── msrun_single.sh
       ├── net.py
    ...

其中,net.py是定义网络结构和训练过程,msrun_single.sh是以msrun启动的单机多卡执行脚本;msrun_1.shmsrun_2.sh是以msrun启动的多机多卡执行脚本,分别在不同节点上执行。

1. 准备Python训练脚本

这里以数据并行为例,训练一个MNIST数据集的识别网络。

首先指定运行模式、硬件设备等,与单卡脚本不同,并行脚本还需指定并行模式等配置项,并通过init()初始化HCCL、NCCL或MCCL通信域。此处不设置device_target会自动指定为MindSpore包对应的后端硬件设备。

import mindspore as ms
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
init()
ms.set_seed(1)

然后构建如下网络:

from mindspore import nn

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.flatten(x)
        logits = self.relu(self.fc(x))
        return logits
net = Network()

最后是数据集处理和定义训练过程:

import os
from mindspore import nn
import mindspore as ms
import mindspore.dataset as ds
from mindspore.communication import get_rank, get_group_size

def create_dataset(batch_size):
    dataset_path = os.getenv("DATA_PATH")
    rank_id = get_rank()
    rank_size = get_group_size()
    dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
    image_transforms = [
        ds.vision.Rescale(1.0 / 255.0, 0),
        ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ds.vision.HWC2CHW()
    ]
    label_transform = ds.transforms.TypeCast(ms.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

data_set = create_dataset(32)
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(net.trainable_params(), 1e-2)

def forward_fn(data, label):
    logits = net(data)
    loss = loss_fn(logits, label)
    return loss, logits

grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
grad_reducer = nn.DistributedGradReducer(optimizer.parameters)

for epoch in range(10):
    i = 0
    for data, label in data_set:
        (loss, _), grads = grad_fn(data, label)
        grads = grad_reducer(grads)
        optimizer(grads)
        if i % 10 == 0:
            print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss))
        i += 1

2. 准备启动脚本

对于msrun来说单机多卡和多机多卡执行指令类似,单机多卡只需将参数worker_numlocal_worker_num保持相同即可,且单机多卡场景下无需设置master_addr,默认为127.0.0.1

单机多卡

下面以执行单机8卡训练为例:

脚本msrun_single.sh使用msrun指令在当前节点拉起1个Scheduler进程以及8个Worker进程(无需设置master_addr,默认为127.0.0.1;单机无需设置node_rank):

EXEC_PATH=$(pwd)
if [ ! -d "${EXEC_PATH}/MNIST_Data" ]; then
    if [ ! -f "${EXEC_PATH}/MNIST_Data.zip" ]; then
        wget http://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip
    fi
    unzip MNIST_Data.zip
fi
export DATA_PATH=${EXEC_PATH}/MNIST_Data/train/

rm -rf msrun_log
mkdir msrun_log
echo "start training"

msrun --worker_num=8 --local_worker_num=8 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 net.py

执行指令:

bash msrun_single.sh

即可执行单机8卡分布式训练任务,日志文件会保存到./msrun_log目录下,结果保存在./msrun_log/worker_*.log中,Loss结果如下:

epoch: 0, step: 0, loss is 2.3499548
epoch: 0, step: 10, loss is 1.6682479
epoch: 0, step: 20, loss is 1.4237018
epoch: 0, step: 30, loss is 1.0437132
epoch: 0, step: 40, loss is 1.0643986
epoch: 0, step: 50, loss is 1.1021575
epoch: 0, step: 60, loss is 0.8510884
epoch: 0, step: 70, loss is 1.0581372
epoch: 0, step: 80, loss is 1.0076828
epoch: 0, step: 90, loss is 0.88950706
...

多机多卡

下面以执行2机8卡训练,每台机器执行启动4个Worker为例:

脚本msrun_1.sh在节点1上执行,使用msrun指令拉起1个Scheduler进程以及4个Worker进程,配置master_addr为节点1的IP地址(msrun会自动检测到当前节点IP与master_addr匹配而拉起Scheduler进程),通过node_rank设置当前节点为0号节点:

EXEC_PATH=$(pwd)
if [ ! -d "${EXEC_PATH}/MNIST_Data" ]; then
    if [ ! -f "${EXEC_PATH}/MNIST_Data.zip" ]; then
        wget http://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip
    fi
    unzip MNIST_Data.zip
fi
export DATA_PATH=${EXEC_PATH}/MNIST_Data/train/

rm -rf msrun_log
mkdir msrun_log
echo "start training"

msrun --worker_num=8 --local_worker_num=4 --master_addr=<node_1 ip address> --master_port=8118 --node_rank=0 --log_dir=msrun_log --join=True --cluster_time_out=300 net.py

脚本msrun_2.sh在节点2上执行,使用msrun指令拉起4个Worker进程,配置master_addr为节点1的IP地址,通过node_rank设置当前节点为1号节点:

EXEC_PATH=$(pwd)
if [ ! -d "${EXEC_PATH}/MNIST_Data" ]; then
    if [ ! -f "${EXEC_PATH}/MNIST_Data.zip" ]; then
        wget http://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip
    fi
    unzip MNIST_Data.zip
fi
export DATA_PATH=${EXEC_PATH}/MNIST_Data/train/

rm -rf msrun_log
mkdir msrun_log
echo "start training"

msrun --worker_num=8 --local_worker_num=4 --master_addr=<node_1 ip address> --master_port=8118 --node_rank=1 --log_dir=msrun_log --join=True --cluster_time_out=300 net.py

节点2和节点1的指令差别在于node_rank不同。

在节点1执行:

bash msrun_1.sh

在节点2执行:

bash msrun_2.sh

即可执行2机8卡分布式训练任务,日志文件会保存到./msrun_log目录下,结果保存在./msrun_log/worker_*.log中,Loss结果如下:

epoch: 0, step: 0, loss is 2.3499548
epoch: 0, step: 10, loss is 1.6682479
epoch: 0, step: 20, loss is 1.4237018
epoch: 0, step: 30, loss is 1.0437132
epoch: 0, step: 40, loss is 1.0643986
epoch: 0, step: 50, loss is 1.1021575
epoch: 0, step: 60, loss is 0.8510884
epoch: 0, step: 70, loss is 1.0581372
epoch: 0, step: 80, loss is 1.0076828
epoch: 0, step: 90, loss is 0.88950706
...