Multi-device Model Weight Sharding
After model training is complete, the trained weights can be loaded for inference. Because inference requires significantly less device memory than training, it typically runs on fewer devices or with a different parallel layout, so the model weights need to be re-sharded before they are loaded.
Weight Sharding
MindSpore uses a strategy file to manage distributed weights. After model development, the distribution attributes of the weights in a frontend parallel network are determined by the basic parallel modules. Therefore, the weight sharding strategy information, sharded_state_dict, is first added to each distributed basic module.
class ColumnParallelLinear(nn.Cell):
    ...
    def sharded_state_dict(self):
        """Provide the sharded state dict based on the config."""
        # The output dimension is split across the tensor-parallel group.
        # With transpose_b the weight is stored as (output_size, input_size), so dim 0 is sharded.
        w_shard = (self.tensor_parallel_group_size, 1) if self.transpose_b else (1, self.tensor_parallel_group_size)
        state_dict = {}
        if not self.skip_weight_param_allocation:
            state_dict[self.weight.name] = {'shape': self.weight.shape,
                                            'shard': w_shard}
        if self.has_bias:
            state_dict[self.bias.name] = {'shape': self.bias.shape,
                                          'shard': (self.tensor_parallel_group_size,)}
        return state_dict
class RowParallelLinear(nn.Cell):
    ...
    def sharded_state_dict(self):
        """Provide the sharded state dict based on the config."""
        # The input dimension is split across the tensor-parallel group,
        # so the bias (which lives on the output dimension) is not sharded.
        w_shard = (1, self.tensor_parallel_group_size) if self.transpose_b else (self.tensor_parallel_group_size, 1)
        state_dict = {}
        state_dict[self.weight.name] = {'shape': self.weight.shape,
                                        'shard': w_shard}
        if self.has_bias:
            state_dict[self.bias.name] = {'shape': self.bias.shape,
                                          'shard': (1,)}
        return state_dict
class VocabParallelEmbedding(nn.Cell):
    ...
    def sharded_state_dict(self):
        """Provide the sharded state dict based on the config."""
        # The vocabulary dimension of the embedding table is split across the group.
        w_shard = (self.tensor_model_parallel_size, 1)
        state_dict = {}
        state_dict[self.weight.name] = {'shape': self.weight.shape,
                                        'shard': w_shard}
        return state_dict
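For illustration, assuming a hypothetical tensor_parallel_group_size of 2 and transpose_b=True, the sharded_state_dict returned by a ColumnParallelLinear layer maps each parameter name to its shape and sharding tuple. The parameter name and shape below are made up for the example:

# Hypothetical entry produced by ColumnParallelLinear.sharded_state_dict()
# with tensor_parallel_group_size = 2 and transpose_b = True.
{
    'layers.0.attention.qkv_proj.weight': {   # assumed parameter name
        'shape': (2048, 1024),                # self.weight.shape reported by this module
        'shard': (2, 1)                       # dim 0 split across 2 ranks, dim 1 not split
    }
}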
The sharded_state_dict of the entire network can then be generated from the sharded_state_dict of each basic parallel module in the network. The strategy information of the network is obtained by calling generate_state_dict and saved as a strategy file by calling save_strategy_file.
from mindspore.communication import get_group_size

def _update_sharded_state_dict(network: nn.Cell, dict_: dict):
    """Recursively collect the sharded_state_dict of every basic parallel module."""
    cells = network.name_cells()
    for _, subcell in cells.items():
        if subcell == network:
            continue
        if isinstance(subcell, (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)):
            dict_.update(subcell.sharded_state_dict())
        else:
            _update_sharded_state_dict(subcell, dict_)

def generate_state_dict(network):
    """Generate the sharding strategy information of the whole network."""
    state_dict = {
        "total_rank": get_group_size(),
        "stage_rank_size": get_group_size(),
        "stage": 0
    }
    model_state_dict = {}
    _update_sharded_state_dict(network=network, dict_=model_state_dict)
    state_dict['model'] = model_state_dict
    return state_dict
import os
import stat
import logging

from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

logger = logging.getLogger(__name__)  # a framework logger can be used here instead

def save_strategy_file(state_dict, strategy_file_name):
    """Save the sharding strategy information as a strategy file."""
    stra = ckpt_strategy()
    total_rank = state_dict["total_rank"]
    stage_rank_size = state_dict["stage_rank_size"]
    stage = state_dict["stage"]
    model_param = state_dict["model"]
    # Optimizer sharding information is optional and is usually absent for inference.
    optimizer_param = state_dict.get("optimizer", {})
    stra.current_stage = 0
    model_param.update(optimizer_param)
    for name, item in model_param.items():
        if "shard" not in item or "shape" not in item:
            continue
        opt_weight_shard_step = item.get("opt_weight_shard_step", 0)
        opt_weight_shard_size = item.get("opt_weight_shard_size", 0)
        strategy_item = stra.parallel_strategy_item.add()
        strategy_item.node_name = name
        parallel_strategys = strategy_item.parallel_strategys
        parallel_strategys.stage = stage
        shard = item["shard"]
        shape = item["shape"]
        parallel_strategy = parallel_strategys.parallel_strategy.add()
        shard_mul = 1
        for ele in shard:
            parallel_strategy.dim.append(ele)
            shard_mul = shard_mul * ele
        layout_item = stra.parallel_layout_item.add()
        layout_item.param_name = name
        parallel_layouts = layout_item.parallel_layouts
        parallel_layouts.field = 0
        parallel_layouts.opt_weight_shard_step = opt_weight_shard_step
        parallel_layouts.opt_weight_shard_size = opt_weight_shard_size
        dev_matrix = parallel_layouts.dev_matrix.add()
        repeat_calc_num = 1
        if stage_rank_size == shard_mul:
            repeat_calc_num = 1
        elif stage_rank_size % shard_mul == 0:
            repeat_calc_num = stage_rank_size // shard_mul
        else:
            raise ValueError(f"For {name}, the shard {shard} requires {shard_mul} devices, "
                             f"but the device number of this stage is {stage_rank_size}, "
                             f"which is not divisible by {shard_mul}.")
        if repeat_calc_num != 1:
            dev_matrix.dim.append(repeat_calc_num)
        for ele in shard:
            dev_matrix.dim.append(ele)
        tensor_map = parallel_layouts.tensor_map.add()
        shape_len = len(shape)
        index = shape_len - 1
        for _ in range(shape_len):
            tensor_map.dim.append(index)
            index = index - 1
        param_split_shape = parallel_layouts.param_split_shape.add()
        for ele in shape:
            param_split_shape.dim.append(ele)

    try:
        if os.path.exists(strategy_file_name):
            os.chmod(strategy_file_name, stat.S_IWUSR)
        if "/" in strategy_file_name:
            real_path = os.path.abspath(strategy_file_name[:strategy_file_name.rfind("/")])
            os.makedirs(real_path, exist_ok=True)
        flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        with os.fdopen(os.open(strategy_file_name, flags_, 0o750), 'wb') as f:
            f.write(stra.SerializeToString())
        os.chmod(strategy_file_name, stat.S_IRUSR)
    except BaseException as e:
        logger.critical(f"Failed to save the strategy file {strategy_file_name}. Check the file "
                        "write permission and the available disk space.")
        raise e
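As a minimal usage sketch, the two functions above can be combined to generate and save the strategy file of the inference network. The script must be launched on every inference device (for example with msrun or mpirun) so that get_group_size() reflects the inference device count; ParallelTransformer, config, and the output path below are illustrative assumptions:

# Minimal usage sketch; ParallelTransformer and config stand in for the user's
# own parallel inference network and its configuration.
import mindspore as ms
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
init()                                        # initialize the communication groups
network = ParallelTransformer(config)         # inference network built from the basic parallel modules
state_dict = generate_state_dict(network)     # collect the sharding strategy of the whole network
save_strategy_file(state_dict, "./output/strategy/ckpt_strategy.ckpt")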
After the parallel strategy file of the inference network is obtained, the training weights can be converted into the weights required for inference by following the procedure in Executing Distributed Checkpoint Transformation.
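A hedged sketch of this conversion step, using the mindspore.transform_checkpoints interface, is shown below; the directory layout and strategy file paths are assumptions for illustration:

# Illustrative checkpoint transformation; paths are assumptions.
import mindspore as ms

ms.transform_checkpoints(
    src_checkpoints_dir="./train_checkpoints",   # training checkpoints, organized as rank_x subdirectories
    dst_checkpoints_dir="./infer_checkpoints",   # output directory for the converted checkpoints
    ckpt_prefix="checkpoint_",                   # prefix of the generated checkpoint files
    src_strategy_file="./train_strategy.ckpt",   # strategy file saved during training
    dst_strategy_file="./output/strategy/ckpt_strategy.ckpt"  # strategy file generated above
)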
For the complete end-to-end weight sharding code, see Weight Sharding.
Weight Loading
For details about distributed weight loading, see Loading the Transformed Checkpoint Files.
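As a minimal loading sketch, each rank reads its own converted checkpoint file and loads it into the inference network; the path pattern below is an assumption matching the transformation output sketched above, and network is the inference network built earlier:

# Illustrative per-rank loading of the converted checkpoints; paths are assumptions.
import mindspore as ms
from mindspore.communication import get_rank

rank_id = get_rank()
ckpt_file = f"./infer_checkpoints/rank_{rank_id}/checkpoint_{rank_id}.ckpt"
param_dict = ms.load_checkpoint(ckpt_file)      # read the sharded weights of this rank
ms.load_param_into_net(network, param_dict)     # load the rank-local shard into the inference network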