# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed safetensors"""
from __future__ import absolute_import
import os
import glob
import math
import json
from collections import defaultdict
import time
import multiprocessing as mp
import numpy as np
from safetensors.numpy import save_file, load_file
from safetensors import safe_open
import mindspore as ms
from mindspore import log as logger
from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
_insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
_get_needed_rank_transform_operator_map_by_layouts, \
_generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
_extract_layout_item, _load_tensor_shape, _apply_operator
from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
_convert_to_list
def _progress_bar(iterable, total=None):
"""
Decorate an iterable object, returning an iterator which acts exactly
like the original iterable, but prints a dynamically updating
progressbar every time a value is requested.
"""
if total is None:
total = len(iterable)
start_time = time.time()
def print_progress_bar(iteration):
percent = f"{100 * (iteration / float(total)):.1f}"
bar_length = 40
filled_length = int(bar_length * iteration // total)
bar = '█' * filled_length + '-' * (bar_length - filled_length)
elapsed_time = time.time() - start_time
estimated_total_time = elapsed_time / iteration * total
remaining_time = estimated_total_time - elapsed_time
elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
if iteration == total:
print()
for i, item in enumerate(iterable, start=1):
yield item
print_progress_bar(i)
def _load_and_transform(path, name_map, load_func, transform_func):
if load_func is not None:
param_dict = load_func(path)
else:
param_dict = path
transform_dict = {}
for k, v in param_dict.items():
new_name = name_map.get(k, k) if name_map is not None else k
transform_dict[new_name] = transform_func(v, new_name)
return transform_dict
def _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file):
"""check _transform_safetensors input"""
if not isinstance(ckpt_prefix, str):
raise TypeError("The ckpt_prefix should be a str.")
if src_strategy_file and os.path.dirname(src_strategy_file) and not os.path.exists(
os.path.dirname(src_strategy_file)):
raise ValueError("The director of src_strategy_file: {} is not exists.".
format(os.path.dirname(src_strategy_file)))
if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
os.path.dirname(dst_strategy_file)):
raise ValueError("The director of dst_strategy_file: {} is not exists.".
format(os.path.dirname(dst_strategy_file)))
def _check_output_format(output_format):
if output_format not in ["safetensors", "ckpt"]:
raise ValueError(f"For 'transform_safetensors', the output_format must be "
f"'safetensors' or 'ckpt', but got {output_format}.")
def _split_protobuf_strategy(merged_strategy_file):
    """
    Group the layout items of a merged strategy file by pipeline stage.

    Every layout item's `param_name` must look like ``"<stage>-<name>"``; the integer in front of
    the first ``-`` selects the stage bucket.

    Args:
        merged_strategy_file (str): Path of the merged strategy file.

    Returns:
        dict, ``{stage_id: {param_name: parallel_layouts}}`` where `param_name` keeps its
        stage prefix.

    Raises:
        ValueError: If the loaded strategy has no strategy items or no layout items.
    """
    strategy_map = _load_protobuf_strategy(merged_strategy_file)
    if not (strategy_map.parallel_strategy_item and strategy_map.parallel_layout_item):
        raise ValueError(f"The merged strategy file {merged_strategy_file} is empty")

    layouts_by_stage = {}
    for layout_item in strategy_map.parallel_layout_item:
        stage_text, _ = layout_item.param_name.split('-', 1)
        stage_id = int(stage_text)
        stage_layouts = layouts_by_stage.setdefault(stage_id, {})
        stage_layouts[layout_item.param_name] = layout_item.parallel_layouts
    return layouts_by_stage
def _transform_safetensors(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix, src_strategy_file=None,
                           dst_strategy_file=None, process_num=1, output_format="safetensors"):
    """
    Transform distributed safetensors from the source sharding strategy to the destination one.

    When the source strategy has more than one pipeline stage and the destination has exactly one,
    each source stage is converted in its own process and the per-stage results are merged and
    written by this process. Otherwise the conversion is a single call on the whole strategy.

    Args:
        src_safetensors_dir (str): Directory holding the source ``rank_x`` folders.
        dst_safetensors_dir (str): Directory receiving the converted files.
        ckpt_prefix (str): Prefix of the output file names.
        src_strategy_file (str, optional): Source strategy file. Default: ``None``.
        dst_strategy_file (str, optional): Destination strategy file. Default: ``None``.
        process_num (int): Number of worker processes used per stage. Default: 1.
        output_format (str): ``"safetensors"`` or ``"ckpt"``. Default: ``"safetensors"``.
    """
    _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file)
    _check_output_format(output_format)
    _make_dir(dst_safetensors_dir, "path")
    all_safetensor_files_map = _collect_safetensor_files(src_safetensors_dir)

    dst_strategy_dict = _build_searched_strategy(dst_strategy_file)
    pipeline_stage_num = _extract_pipeline_stage_num(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)

    merge_pipeline_stages = pipeline_stage_num > 1 and dst_stage_num == 1
    if not merge_pipeline_stages:
        src_strategy_dict = _build_searched_strategy(src_strategy_file)
        _transform_stage_safetensors(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                     dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                     _transform_param_list=None)
        return

    # One process per source pipeline stage; results come back through a manager-backed list.
    stage_dict = _split_protobuf_strategy(src_strategy_file)
    manager = mp.Manager()
    _transform_param_list = manager.list()
    stage_workers = []
    for src_strategy_dict in stage_dict.values():
        worker = mp.Process(target=_transform_stage_safetensors,
                            args=(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                  dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                  _transform_param_list))
        worker.start()
        stage_workers.append(worker)
    for worker in stage_workers:
        worker.join()

    _save_final_safetensors(_transform_param_list, output_format)
def _transform_stage_safetensors(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                 dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                 _transform_param_list):
    """
    Transform the safetensors described by one source strategy into the destination strategy.

    Args:
        src_strategy_dict: Source strategy.
        dst_strategy_dict: Destination strategy.
        ckpt_prefix (str): Prefix of the output file names.
        dst_safetensors_dir (str): Directory receiving the converted files.
        output_format (str): ``"safetensors"`` or ``"ckpt"``.
        all_safetensor_files_map (dict): Source rank id -> safetensors file path.
        process_num (int): Requested number of worker processes.
        _transform_param_list: Shared list collecting results, or ``None`` to write files directly.

    Raises:
        ValueError: If a source rank needed by the conversion has no file in
            `all_safetensor_files_map`.
    """
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_dict)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
    origin_src_strategy_list = _extract_layout_map(src_strategy_dict)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_dict)

    needed_rank_list_map = _find_needed_ranks(src_strategy_dict, dst_strategy_dict)
    # Keys are "-"-joined source rank ids; values are the destination ranks that need them.
    for joined_src_ranks, rank in needed_rank_list_map.items():
        for needed_rank in joined_src_ranks.split("-"):
            if int(needed_rank) in all_safetensor_files_map:
                continue
            raise ValueError("The safetensor file of rank{} is needed for converting rank{}'s safetensor, "
                             "but it is missing.".format(needed_rank, rank))

    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
    group_count = len(needed_rank_list_map)
    # The clamp is skipped only for a single source group feeding several destination stages.
    if process_num > group_count and (group_count != 1 or dst_stage_num <= 1):
        ms.log.warning("The value of process_num cannot be greater than that of needed_rank_list_map.")
        process_num = group_count

    _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                         dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
                                         origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
                                         dst_safetensors_dir, process_num, output_format,
                                         _transform_param_list)
def _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num):
"""
Distributes files across multiple processes based on file size to balance the processing load.
"""
if process_num == 1:
return [needed_rank_list_map]
# Calculate the size of each file.
# if src==1, dst pp>1, split for pp number.
if len(needed_rank_list_map) == 1:
src_rank = next(iter(needed_rank_list_map.keys()))
dst_list = next(iter(needed_rank_list_map.values()))
size = len(dst_list) // process_num
split_list = [dst_list[i:i + size] for i in range(0, len(dst_list), size)]
part_list_dict = [dict() for _ in range(process_num)]
for index in range(process_num):
part_list_dict[index][src_rank] = split_list[index]
return part_list_dict
rank_size = dict()
for rank_id, file_name in all_safetensor_files_map.items():
tmp_size = os.path.getsize(file_name) / 1024 / 1024
rank_size[rank_id] = tmp_size
# Obtain the rank and size required by all parts.
part_total = []
for index, (k, v) in enumerate(needed_rank_list_map.items()):
tmp_part = []
key_ele = k.split("-")
tmp_size = 0
for ele in key_ele:
tmp_size += rank_size[int(ele)]
tmp_part.append(index)
tmp_part.append(tmp_size)
part_total.append(tmp_part)
# Sort each part by size.
part_total = sorted(part_total, key=lambda x: x[1], reverse=True)
part_list = [[] for _ in range(process_num)]
part_size = [[] for _ in range(process_num)]
for [index, size] in part_total:
min_sum = float('inf')
min_idx = -1
for ele in range(process_num):
if sum(part_size[ele]) < min_sum:
min_sum = sum(part_size[ele])
min_idx = ele
part_list[min_idx].append(index)
part_size[min_idx].append(size)
part_list_dict = [dict() for _ in range(process_num)]
for index, (k, v) in enumerate(needed_rank_list_map.items()):
for idd, ele in enumerate(part_list):
if index in ele:
part_list_dict[idd][k] = v
break
return part_list_dict
def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                         dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
                                         origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
                                         dst_safetensors_dir, process_num, output_format,
                                         _transform_param_list):
    """
    Run `_transform_safetensors_single` in `process_num` worker processes and wait for all of them.

    The work is partitioned by `_distribute_files_by_size`. When there is exactly one source group
    and the destination has several pipeline stages, one process per destination stage is used and
    each process only receives the parameter names recorded for its stage.
    """
    pipe_num = _extract_pipeline_stage_num(dst_strategy_dict)
    # Default: no per-process parameter filter.
    pipe_param_list = [None] * max(pipe_num, process_num)

    single_source_to_pipeline = len(needed_rank_list_map) == 1 and pipe_num > 1
    if single_source_to_pipeline:
        process_num = pipe_num
        pipe_param_list = [[] for _ in range(pipe_num)]
        # layout[6][0] is used as the index of the stage bucket the parameter belongs to.
        for name, layout in _convert_to_list(dst_strategy_dict).items():
            pipe_param_list[layout[6][0]].append(name)

    part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)

    workers = []
    for worker_id in range(process_num):
        worker = mp.Process(target=_transform_safetensors_single, args=(
            part_list_dict[worker_id], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
            src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
            ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[worker_id]))
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()
def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
"""Obtain the specified redundant group."""
redundancy_tuple = redundancy_dict.get(param_name)
for rank_list in redundancy_tuple:
for rank in rank_list:
if rank_num % device_num == rank % device_num:
return set(rank_list)
return set()
def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
                                    needed_rank, device_num):
    """
    Load each requested parameter from a file whose rank belongs to the matching redundant group.

    Args:
        pipe_param_list (list): Parameter names to load.
        single_param_dict (dict): Parameter name -> ids of the files that contain it.
        file_dict (dict): File id -> opened safetensors handle exposing `get_tensor`.
        saftensor_dict (dict): Output mapping, filled in place with parameter name -> tensor.
        redundancy_dict (dict): Parameter name -> rank groups, see `_count_redundancy_list`.
        needed_rank: Rank whose redundant group is looked up (converted with `int`).
        device_num (int): Modulus applied to ranks before comparison.

    Raises:
        ValueError: If a parameter is stored in some file but in none whose rank matches the group.
    """
    for param_name in pipe_param_list:
        redundancy_ranks = _count_redundancy_list(int(needed_rank), param_name, redundancy_dict, device_num)
        holders = single_param_dict.get(param_name)
        if holders is None:
            # The parameter is in no source file at all; nothing to load.
            continue

        group_residues = {rank % device_num for rank in redundancy_ranks}
        open_file_id = None
        for real_rank in holders:
            # No early exit: the last matching holder in iteration order is the one used.
            if real_rank % device_num in group_residues:
                open_file_id = real_rank

        if open_file_id is None:
            raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                             f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
        saftensor_dict[param_name] = file_dict[open_file_id].get_tensor(param_name)
def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                  dst_stage_device_num,
                                  src_strategy_dict, dst_strategy_dict, origin_src_strategy_list,
                                  origin_dst_strategy_list,
                                  ckpt_prefix, dst_safetensors_dir, output_format,
                                  _transform_param_list, pipe_param_list=None, file_index=None, unified_flag=False,
                                  src_strategy_file=None):
    """
    Transforms safetensors files to a specified format without using parallel processing.

    For every entry of `needed_rank_list_map` the tensors of the listed source ranks are loaded,
    then each destination rank of that entry is converted with `_transform_parallel_safetensor`
    and either written to disk or appended to `_transform_param_list`.

    Args:
        needed_rank_list_map (dict): "-"-joined source rank ids -> list of destination ranks.
        all_safetensor_files_map (dict): Source rank id -> safetensors file path.
        src_stage_device_num (int): Modulus mapping a source rank to its rank inside the stage.
        dst_stage_device_num (int): Modulus mapping a destination rank to its rank inside the stage.
        src_strategy_dict: Source strategy.
        dst_strategy_dict: Destination strategy (may be ``None``).
        origin_src_strategy_list: Full source layout map, used for the pipeline-stage filter.
        origin_dst_strategy_list: Full destination layout map, used for the pipeline-stage filter.
        ckpt_prefix (str): Prefix of the per-rank output file name.
        dst_safetensors_dir (str): Output directory.
        output_format (str): ``"safetensors"`` or anything else for ``ms.save_checkpoint``.
        _transform_param_list: When not ``None``, results are appended here as
            ``{save_file_name: transform_param_dict}`` instead of being written.
        pipe_param_list (list, optional): Parameter names to load; a falsy value loads whole files.
        file_index (int, optional): When given, the output is ``part{file_index}.{output_format}``
            directly under `dst_safetensors_dir`.
        unified_flag (bool): When ``True``, `pipe_param_list` is not extended with extra names.
        src_strategy_file (str, optional): When given, parameters are read through the
            redundancy lookup of `_find_remove_redundancy_rank_id`.
    """
    if src_strategy_file is not None:
        # Redundancy mode: build {parameter name without stage prefix: rank groups}.
        from mindspore.train._utils import get_parameter_redundancy
        redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
        redundancy_dict = {}
        device_num = 0
        for param_name, redundancy in redundancy_dict_tmp.items():
            if device_num == 0:
                # Derived once, from the first entry: largest rank of its largest group, plus one.
                device_num = max(max(redundancy)) + 1
            origin_param_name = param_name
            pipeline_stage = 0
            if "-" in param_name:
                # NOTE(review): split("-") without a limit fails to unpack if the name itself
                # contains another "-" — confirm names never do.
                pipeline_stage, origin_param_name = param_name.split("-")
                pipeline_stage = int(pipeline_stage)
            # Shift every rank by the stage offset.
            redundancy_new = tuple(
                (tuple(x + pipeline_stage * device_num for x in subtuple)) for subtuple in redundancy)
            redundancy_dict[origin_param_name] = redundancy_new
        file_dict = {}
        single_param_dict = {}
        # Open every source file and record which files contain each parameter.
        # NOTE(review): these handles are kept in file_dict and are not closed in this function.
        for file_id, _ in all_safetensor_files_map.items():
            f = safe_open(all_safetensor_files_map.get(file_id), framework="np")
            file_dict[file_id] = f
            for param_name in f.keys():
                if param_name not in single_param_dict.keys():
                    single_param_dict[param_name] = {file_id}
                else:
                    single_param_dict[param_name].add(file_id)
    src_strategy_list_keys = _convert_to_list(src_strategy_dict).keys() if src_strategy_dict else []
    dst_strategy_list_keys = _convert_to_list(dst_strategy_dict).keys() if dst_strategy_dict else []
    for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
        # Both dicts map parameter name -> {rank inside the source stage: value}.
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        needed_rank_list = needed_rank_list_key.split("-")
        for needed_rank in needed_rank_list:
            if pipe_param_list:
                saftensor_dict = dict()
                if src_strategy_file is not None:
                    _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict,
                                                    redundancy_dict, needed_rank, device_num)
                else:
                    with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                        if not unified_flag:
                            # Also load names in the file that are not in both strategies.
                            # NOTE(review): this extends the caller's list in place and runs once
                            # per needed rank, so names can be appended repeatedly.
                            all_param_name_set = set(f.keys())
                            src_param_name_set = set(src_strategy_list_keys)
                            dst_param_name_set = set(dst_strategy_list_keys)
                            hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                            pipe_param_list.extend(list(hyper_param_set))
                        for param_name in pipe_param_list:
                            if param_name not in f.keys():
                                # param not in ckpt file, check reason
                                continue
                            output = f.get_tensor(param_name)
                            saftensor_dict[param_name] = output
            else:
                # No filter: read the whole file of this source rank.
                saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
            for param_name, param in saftensor_dict.items():
                src_rank = int(needed_rank) % src_stage_device_num
                param_total_dict[param_name][src_rank] = param
                param_attr_dict[param_name][src_rank] = (True, False)
        for transform_rank in transform_rank_list:
            param_total_dict_keys = list(param_total_dict.keys())
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_dict,
                                                                               dst_strategy_dict)
            # cut the parameter not in the pipeline stage.
            for param in list(param_total_dict.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_keys.remove(param)
            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                                  param_attr_dict, src_strategy_list, dst_strategy_list,
                                                                  param_total_dict_keys, src_strategy_file)
            # Output location: a flat "part" file, or "<prefix><rank>" inside rank_<rank>/.
            if file_index is not None:
                save_safetensor_file = f"part{file_index}.{output_format}"
                save_safetensor_file_dir = dst_safetensors_dir
            else:
                save_safetensor_file = f"{ckpt_prefix}{transform_rank}.{output_format}"
                save_safetensor_file_dir = os.path.join(dst_safetensors_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_safetensor_file_dir):
                _make_dir(save_safetensor_file_dir, "path")
            save_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
            if _transform_param_list is not None:
                # Hand the result to the collector instead of writing it here.
                _transform_param_list.append({save_file_name: transform_param_dict})
            else:
                if output_format == "safetensors":
                    save_file(transform_param_dict, save_file_name)
                else:
                    # Wrap every array in ms.Parameter before saving as a checkpoint.
                    transform_param_dict = _load_and_transform(transform_param_dict, None, None,
                                                               transform_func=lambda v, name: ms.Parameter(v,
                                                                                                           name=name))
                    ms.save_checkpoint(transform_param_dict, save_file_name)
            del param_total_dict_keys
        # Drop the loaded tensors before the next group is read.
        del param_total_dict
def _save_final_safetensors(_transform_param_list, output_format):
"""save file with list"""
new_transform_dict = {}
for transform_dict in _transform_param_list:
for save_file_name, transform_param_dict in transform_dict.items():
if save_file_name not in new_transform_dict:
new_transform_dict[save_file_name] = transform_param_dict
else:
new_transform_dict[save_file_name].update(transform_param_dict)
for save_file_name, transform_param_dict in new_transform_dict.items():
if output_format == "safetensors":
save_file(transform_param_dict, save_file_name)
else:
transform_param_dict = _load_and_transform(transform_param_dict, None, None,
transform_func=lambda v, name: ms.Parameter(v, name=name))
ms.save_checkpoint(transform_param_dict, save_file_name)
def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix,
                                   src_strategy_file,
                                   dst_strategy_file=None):
    """
    Transform the safetensors of the pipeline stage described by `src_strategy_file`.

    Args:
        src_safetensors_dir (str): Directory holding the source ``rank_x`` folders.
        dst_safetensors_dir (str): Directory receiving ``rank_x/<prefix><rank>_part<stage>.safetensors``.
        ckpt_prefix (str): Prefix of the output file names.
        src_strategy_file (str): Source strategy file of the stage.
        dst_strategy_file (str, optional): Destination strategy file. Default: ``None``.

    Raises:
        ValueError: If a collected source file no longer exists.
    """
    src_strategy_list, dst_strategy_list, stage_id = _extract_src_dst_layout_map_by_src(
        src_strategy_file, dst_strategy_file)

    # Device count of a stage = product of the first layout field of the first parameter.
    src_stage_device_num = 1
    if src_strategy_list is not None:
        first_src_param = list(src_strategy_list.keys())[0]
        src_stage_device_num = np.prod(src_strategy_list.get(first_src_param)[0])
    dst_stage_device_num = 1
    if dst_strategy_list is not None:
        first_dst_param = list(dst_strategy_list.keys())[0]
        dst_stage_device_num = np.prod(dst_strategy_list.get(first_dst_param)[0])

    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)

    # For every rank of this stage keep the alphabetically last regular *.safetensors file.
    safetensor_files_map = {}
    first_rank_of_stage = stage_id * src_stage_device_num
    for local_rank in range(src_stage_device_num):
        rank_id = first_rank_of_stage + local_rank
        pattern = os.path.join(src_safetensors_dir, "rank_{}".format(rank_id), "*.safetensors")
        regular_files = [path for path in sorted(glob.glob(pattern)) if os.path.isfile(path)]
        if regular_files:
            safetensor_files_map[rank_id] = regular_files[-1]

    for rank, local_file in safetensor_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))

    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    for rank, file_name in safetensor_files_map.items():
        src_rank = rank % src_stage_device_num
        for param_name, param in load_file(file_name).items():
            # cut the parameter not in the pipeline stage.
            outside_src = _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list)
            if outside_src and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list,
                                                             dst_strategy_list):
                continue
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            param_total_dict[param_name][src_rank] = param
            param_attr_dict[param_name][src_rank] = (True, False)

    for local_rank_id in range(dst_stage_device_num):
        # param_type_dict is passed as the sixth positional argument, exactly as before.
        transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                              param_attr_dict, src_strategy_list, dst_strategy_list,
                                                              param_type_dict)
        rank_dir = os.path.join(dst_safetensors_dir, "rank_{}".format(local_rank_id))
        if not os.path.exists(rank_dir):
            _make_dir(rank_dir, "path")
        out_file = "{}{}_part{}.safetensors".format(ckpt_prefix, local_rank_id, stage_id)
        save_file(transform_param_dict, os.path.join(rank_dir, out_file))
def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
                                  src_strategy_file=None, dst_strategy_file=None):
    """
    Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank.

    Args:
        rank_id (int): Destination rank to produce.
        safetensor_files_map (dict): Source rank id -> safetensors file path.
        save_safetensor_file_name (str): Output path; must end with ``.safetensors``.
        src_strategy_file (str, optional): Source strategy file. Default: ``None``.
        dst_strategy_file (str, optional): Destination strategy file. Default: ``None``.

    Raises:
        TypeError: If `safetensor_files_map` is not a dict, `rank_id` is not an int, or
            `save_safetensor_file_name` is not a str.
        ValueError: If the output name has the wrong suffix, the directory of `dst_strategy_file`
            does not exist, or a listed source file does not exist.
    """
    # Argument validation, in the order callers may rely on.
    if not isinstance(safetensor_files_map, dict):
        raise TypeError("The safetensor_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if not isinstance(save_safetensor_file_name, str):
        raise TypeError("The save_safetensor_file_name should be a str.")
    if not save_safetensor_file_name.endswith(".safetensors"):
        raise ValueError(
            "The save_safetensor_file_name {} should end with .safetensors".format(save_safetensor_file_name))
    dst_parent = os.path.dirname(dst_strategy_file) if dst_strategy_file else ""
    if dst_parent and not os.path.exists(dst_parent):
        raise ValueError("The director of dst_strategy_file: {} is not exists.".
                         format(dst_parent))
    for rank, local_file in safetensor_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))

    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)

    # src rank => local rank inside pipeline stage
    src_stage_device_num = 1
    if src_strategy_list is not None:
        first_src_param = list(src_strategy_list.keys())[0]
        src_stage_device_num = np.prod(src_strategy_list.get(first_src_param)[0])
    dst_stage_device_num = 1
    if dst_strategy_list is not None:
        first_dst_param = list(dst_strategy_list.keys())[0]
        dst_stage_device_num = np.prod(dst_strategy_list.get(first_dst_param)[0])

    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)

    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    for rank, file_name in safetensor_files_map.items():
        src_rank = rank % src_stage_device_num
        for param_name, param in load_file(file_name).items():
            # cut the parameter not in the pipeline stage.
            outside_src = _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list)
            if outside_src and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list,
                                                             dst_strategy_list):
                continue
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            param_total_dict[param_name][src_rank] = param
            param_attr_dict[param_name][src_rank] = (True, False)

    local_rank_id = rank_id % dst_stage_device_num
    # param_type_dict is passed as the sixth positional argument, exactly as before.
    transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
                                                          param_type_dict)
    save_file(transform_param_dict, save_safetensor_file_name)
def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
"""
Collects all safetensors files from the specified directory and its subdirectories.
"""
if os.path.isfile(src_safetensors_dir) and format == 'safetensors' and src_safetensors_dir.endswith('safetensors'):
return {0: src_safetensors_dir}
safetensors_rank_dir_list = os.path.join(src_safetensors_dir, "rank_[0-9]*")
all_safetensor_files_map = {}
for safetensor_dir in glob.glob(safetensors_rank_dir_list):
if not os.path.isdir(safetensor_dir):
ms.log.warning("{} is not a directory.".format(safetensor_dir))
continue
rank_id_str = safetensor_dir.split('rank_')[-1]
if not rank_id_str.isdigit():
ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....".
format(safetensor_dir))
continue
rank_id = int(rank_id_str)
if file_suffix is None:
safetensor_file_name = os.path.join(safetensor_dir, f"*.{format}")
else:
safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
rank_ckpts = glob.glob(safetensor_file_name)
rank_ckpts.sort(key=os.path.getctime)
if rank_ckpts:
all_safetensor_files_map[rank_id] = rank_ckpts[-1]
return all_safetensor_files_map
def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
    """
    Group destination ranks by the set of source ranks their conversion needs.

    Args:
        src_strategy_dict: Source strategy.
        dst_strategy_dict: Destination strategy.

    Returns:
        defaultdict(list), ``"-"``-joined source rank ids -> destination ranks that need exactly
        those source ranks. A progress bar is printed while the ranks are scanned.
    """
    stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
    stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
    dst_device_num = stage_device_num * stage_num

    needed_rank_list_map = defaultdict(list)
    for dst_rank in _progress_bar(range(dst_device_num)):
        src_ranks = ms.rank_list_for_transform(dst_rank, src_strategy_dict, dst_strategy_dict)
        group_key = "-".join(map(str, src_ranks))
        needed_rank_list_map[group_key].append(dst_rank)
    return needed_rank_list_map
def load_file_by_param_name(filename, parme_name_list):
    """
    Read only the listed tensors from a safetensors file.

    Args:
        filename (str): Path of the safetensors file.
        parme_name_list (Iterable[str]): Names of the tensors to read.

    Returns:
        dict, tensor name -> tensor as returned by `get_tensor` with the ``"np"`` framework.
    """
    with safe_open(filename, framework="np") as reader:
        return {name: reader.get_tensor(name) for name in parme_name_list}
def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
                                   dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None):
    """
    Transform model parallel dimension for distributed safetensor files.

    Args:
        rank_id (int): Destination rank (inside its stage) whose slices are produced.
        param_total_dict (dict): Parameter name -> {source rank: tensor slice}.
        param_attr_dict (dict): Parameter name -> {source rank: attribute tuple}; only the
            presence of a source rank is checked here.
        src_strategy_list (dict, optional): Parameter name -> source layout. ``None`` means every
            parameter is treated as unsharded on the source side.
        dst_strategy_list (dict, optional): Parameter name -> destination layout. ``None`` means
            every parameter is treated as unsharded on the destination side.
        param_total_dict_keys (Iterable[str], optional): Names to process. Default: ``None``,
            meaning all keys of `param_total_dict`.
        src_strategy_file (str, optional): When not ``None`` the "rank is missing" check is skipped.

    Returns:
        dict, parameter name -> tensor for `rank_id`.

    Raises:
        ValueError: If the source device count of a parameter is below 1, or if the slice of
            ``rank_id % device_num`` is absent while `src_strategy_file` is ``None``.
    """
    transform_param_dict = {}
    # Stays -1 only if no parameter reaches the layout computation below.
    device_num = -1
    param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
    for param_name in param_total_dict_keys:
        # Shape of one stored slice (the first one recorded for this parameter).
        tensor_shape = list(param_total_dict[param_name].values())[0].shape
        # Defaults describe an unsharded tensor on a single device.
        from_dev_matrix = [1]
        from_tensor_map = [-1] * len(tensor_shape)
        from_opt_shard_step = 0
        from_opt_shard_size = 0
        if src_strategy_list is not None:
            if param_name not in src_strategy_list:
                # Not described by the source strategy: handled by the pass-through loop below.
                continue
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
                src_strategy_list.get(param_name))
        to_dev_matrix_origin = [1]
        to_tensor_map_origin = [-1] * len(tensor_shape)
        to_opt_shard_step = 0
        to_opt_shard_size = 0
        if dst_strategy_list is not None:
            if param_name not in dst_strategy_list:
                # Not described by the destination strategy: handled by the pass-through loop below.
                continue
            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                dst_strategy_list.get(param_name))
        # Add optimizer sharding dim for tensor layout
        device_num = np.prod(from_dev_matrix)
        if device_num < 1:
            raise ValueError("None of the parameters in safetensor file are in either src strategy or "
                             "dst strategy. Please check correctness of strategy files. "
                             "Param name is: {}, rank_id is {}.".format(param_name, rank_id))
        param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
        # Rebuild the full shape: each slice dim times its split factor; dim 0 is additionally
        # multiplied by the optimizer shard size when that is positive.
        origin_tensor_shape = ()
        for i, item in enumerate(tensor_shape):
            if i == 0 and from_opt_shard_size > 0:
                origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
                continue
            origin_tensor_shape += (item * param_strategy[i],)
        from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
        to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
        # Convert tensor layout to same device num
        from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
                                                                                from_tensor_map, to_full_tensor_shape,
                                                                                to_dev_matrix, to_tensor_map)
        # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
        device_list = list(range(0, np.prod(from_tensor_layout[0])))
        if rank_id % device_num not in param_attr_dict[param_name] and src_strategy_file is None:
            raise ValueError("The safetensor of rank {} is missing.".format(rank_id % device_num))
        param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                            device_list, rank_id)
        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
        transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
        # Operate on a shallow copy of the per-rank dict; presumably the operators write their
        # results back into it, since the result is read from it below — TODO confirm.
        param_total_dict_copy = param_total_dict[param_name].copy()
        _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
        transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
    # Handle those parameter like learning_rate, global_step which not in strategy_file.
    # NOTE(review): device_num here is whatever the last processed parameter left behind
    # (or -1 if none was processed, which makes the index 0) — confirm this is intended.
    for param_name in param_total_dict_keys:
        if param_name not in transform_param_dict:
            transform_para = param_total_dict[param_name][rank_id % device_num]
            transform_param_dict[param_name] = transform_para
    return transform_param_dict
def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
                        max_process_num=64):
    """
    Merge multiple safetensor files into a unified safetensor file.

    Args:
        src_dir (str): Source weight saving directory.
        src_strategy_file (str): Source weight segmentation strategy file.
        dst_dir (str): Target save directory.
        merge_with_redundancy (bool, optional): Whether the merged source weight files are de-duplicated and
            saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
        file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
            meaning all safetensors files in the source weight directory will be merged.
        max_process_num (int): Maximum number of processes. Default: 64.

    Raises:
        ValueError: If the safetensors file of rank is missing.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
        >>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
        >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
        >>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
    """
    _check_transform_safetensors(src_dir, "", src_strategy_file, None)
    _make_dir(dst_dir, "path")
    if os.path.isfile(src_dir):
        raise ValueError("For 'unified_safetensors', the 'src_dir' can not be a file.")
    all_safetensor_files_map = _collect_safetensor_files(src_dir, format="safetensors", file_suffix=file_suffix)
    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt", file_suffix=file_suffix)
    if all_safetensor_files_map and all_ckpt_files_map:
        raise ValueError("For 'unified_safetensors', the 'src_dir' cannot contain "
                         "both ckpt file and safetensors file simultaneously")
    src_strategy_dict = _build_searched_strategy(src_strategy_file)
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_dict)
    # The destination is a single, unsharded device: no destination strategy.
    dst_stage_device_num = 1
    origin_src_strategy_list = _extract_layout_map(src_strategy_dict)
    origin_dst_strategy_list = None

    needed_rank_list_map = _find_needed_ranks(src_strategy_dict, dst_strategy_dict=None)
    for needed_rank_list, rank in needed_rank_list_map.items():
        for needed_rank in needed_rank_list.split("-"):
            if int(needed_rank) not in all_safetensor_files_map:
                raise ValueError("The safetensor file of rank{} is needed for converting rank{}'s safetensor, "
                                 "but it is missing.".format(needed_rank, rank))

    layout_map = _convert_to_list(src_strategy_dict)

    # Total size of the source files in GB, and every parameter name stored in any of them.
    total_size = 0
    actual_params = set()
    for _, file_name in all_safetensor_files_map.items():
        total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
        with safe_open(file_name, framework="np") as f:
            actual_params.update(f.keys())
    # One output part per 3 GB of source files (rounded up).
    split_num = math.ceil(total_size / 3)

    # Parameters to merge: stored in a file and described by the strategy, except "accu_grads*".
    params_to_store = actual_params & set(layout_map.keys())
    name_list = []
    for name in list(params_to_store):
        if name.startswith("accu_grads"):
            continue
        name_list.append(name)
    split_list = _split_list(name_list, split_num)

    # Everything else found in the rank-0 file is copied unchanged into hyper_param.safetensors.
    with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
        all_key = f.keys()
        hyper_parameter = set(all_key) - set(name_list)
        if hyper_parameter:
            hyper_dict = {}
            for key in hyper_parameter:
                hyper_dict[key] = f.get_tensor(key)
            save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))

    # save parameter map json
    param_name_dict = dict()
    for index, part_list in enumerate(split_list):
        for name in part_list:
            param_name_dict[name] = f"part{index}.safetensors"
    json_str = json.dumps(param_name_dict, indent=4)
    map_file = os.path.join(dst_dir, "param_name_map.json")
    with open(map_file, 'w') as f:
        f.write(json_str)

    # Spread the part indices over at most max_process_num worker processes.
    max_process = min(split_num, max_process_num)
    res = [i for i in range(split_num)]
    res = _split_list(res, max_process)
    processes = []
    # The strategy file is forwarded only when the sources were saved without redundancy.
    src_strategy_name = None
    if not merge_with_redundancy:
        src_strategy_name = src_strategy_file
    for i in range(max_process):
        p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
            needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
            src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
            "", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map,
                                            src_stage_device_num,
                                            dst_stage_device_num,
                                            src_strategy_dict, dst_strategy_dict, origin_src_strategy_list,
                                            origin_dst_strategy_list,
                                            ckpt_prefix, dst_safetensors_dir, output_format,
                                            _transform_param_list, pipe_param_list=None, file_index=None,
                                            unified_flag=False, src_strategy_file=None):
    """
    Call `_transform_safetensors_single` once for every index in `file_index`.

    For each index, `pipe_param_list[index]` and the index itself are forwarded; every other
    argument is passed through unchanged. Nothing is returned.
    """
    # Leading positional arguments that are identical for every call.
    shared_args = (needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                   dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
                   origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
                   dst_safetensors_dir, output_format, _transform_param_list)
    for part_index in file_index:
        _transform_safetensors_single(*shared_args, pipe_param_list[part_index], part_index,
                                      unified_flag, src_strategy_file)
def _split_list(split_list, split_num):
    """
    Split a sequence into `split_num` consecutive parts of nearly equal length.

    The partition sizes follow `numpy.array_split`: with ``q, r = divmod(len(split_list), split_num)``
    the first `r` parts hold ``q + 1`` items and the remaining parts hold ``q`` items, so trailing
    parts are empty when `split_num` exceeds the number of items. The items are sliced directly
    rather than being converted to a numpy array and back, so they keep their original Python
    type (numpy would coerce a mixed list such as ``[1, "a"]`` to all strings).

    Args:
        split_list (Iterable): The items to split.
        split_num (int): The number of parts to produce. Must be greater than 0.

    Returns:
        list[list], `split_num` lists that together hold every item in the original order.

    Raises:
        ValueError: If `split_num` is not greater than 0.
    """
    part_count = int(split_num)
    if part_count <= 0:
        raise ValueError("For '_split_list', 'split_num' must be greater than 0, "
                         "but got {}.".format(split_num))
    items = list(split_list)
    base_size, extra = divmod(len(items), part_count)
    parts = []
    start = 0
    for part_index in range(part_count):
        # The first `extra` parts absorb the remainder, one additional item each.
        stop = start + base_size + (1 if part_index < extra else 0)
        parts.append(items[start:stop])
        start = stop
    return parts
def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
    """
    Apply the operators in `transform_operator_stack` to `sf_obj`, one level at a time.

    Entries are popped from the end of `transform_operator_stack` (so the caller's list is consumed)
    and collected while they share the same level. A collected group is applied as soon as the next
    entry on the stack has a different level or the stack is exhausted.

    Args:
        transform_operator_stack (list): Entries indexed as `entry[0]` (rank id), `entry[1]` (level)
            and `entry[2]` (operator). `operator[0]` is the operator name resolved through
            `_apply_operator`; for "AllConcat", `operator[1][:-1]` is iterated as the required ranks.
        sf_obj: The object to transform. It is indexed with `[:]` and tested with `in`.
            Presumably a slice obtained from `safe_open(...).get_slice(...)` -- confirm against callers.
        device_num (int): Modulus applied to rank ids before they are used.

    Returns:
        `sf_obj[:]` if the stack is empty, otherwise `sf_obj` after all operators have been applied.

    Raises:
        ValueError: If the operators collected for one level do not all share the same name, or if
            the `rank % device_num in sf_obj` check fails for a rank required by "AllConcat".
    """
    if not transform_operator_stack:
        # Nothing to transform: return the object indexed with a full slice.
        return sf_obj[:]
    level = transform_operator_stack[-1][1]
    level_operators = []
    while True:
        # Flush the collected group when the stack is empty or its top entry belongs to another level.
        if not transform_operator_stack or (level != transform_operator_stack[-1][1]):
            tmp_tensor_dict = {}
            # NOTE(review): this guard looks unreachable, because an entry is always appended at the
            # end of the previous iteration before this branch can be entered. If it were reached,
            # nothing would change before the next iteration -- confirm and consider removing.
            if not level_operators:
                continue
            op_name = level_operators[0][2][0]
            for operator_pair in level_operators:
                rank_id = operator_pair[0]
                cur_level = operator_pair[1]
                operator = operator_pair[2]
                if operator[0] != op_name:
                    raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                                     "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level))
                if operator[0] != "AllConcat":
                    # Every operator other than "AllConcat" is applied directly to the current object.
                    sf_obj = _apply_operator(operator[0])(sf_obj, operator)
                    continue
                # NOTE(review): membership is tested on `sf_obj` itself, not on a per-rank mapping --
                # confirm this check is meaningful for the object type passed in.
                for rank in operator[1][:-1]:
                    if rank % device_num not in sf_obj:
                        raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
                # The same object is repeated once per required rank as the "AllConcat" input list.
                allgather_list = [sf_obj for _ in operator[1][:-1]]
                tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
            if op_name == "AllConcat":
                # `rank` is unused; when several results exist, the last one iterated becomes `sf_obj`.
                for rank, value in tmp_tensor_dict.items():
                    sf_obj = value
            level_operators.clear()
        if not transform_operator_stack:
            break
        # Take the next entry and start (or continue) collecting its level.
        operator_pair = transform_operator_stack.pop()
        level = operator_pair[1]
        level_operators.append(operator_pair)
    return sf_obj
def _check_name_map_value_is_str(value):
    """
    Validate that `value` is a str.

    Raises:
        ValueError: If `value` is not an instance of str.
    """
    if isinstance(value, str):
        return
    raise ValueError(
        f"For 'load_distributed_checkpoint', the value of name_map must be str, but got {type(value)}.")
def _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param):
    """
    Add the tensors stored in "hyper_param.safetensors" to `total_param`.

    If `file_list` does not contain "hyper_param.safetensors", `total_param` is returned untouched.
    Otherwise every key of that file (located in `total_safetensors_dir`) is stored in `total_param`
    as an `ms.Parameter`, under `name_map[key]` when `name_map` contains the key and under the key
    itself otherwise. The chosen name must be a str.

    Returns:
        dict, the same `total_param` object that was passed in.
    """
    if 'hyper_param.safetensors' not in file_list:
        return total_param
    hyper_file_path = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
    with safe_open(hyper_file_path, framework="np") as reader:
        for key in reader.keys():
            if name_map is not None and key in name_map:
                target_name = name_map.get(key)
            else:
                target_name = key
            _check_name_map_value_is_str(target_name)
            total_param[target_name] = ms.Parameter(ms.Tensor.from_numpy(reader.get_tensor(key)))
    return total_param
def _load_parallel_checkpoint(file_info):
    """
    Load parameters from a directory of merged safetensors files and slice them for one rank.

    Args:
        file_info (tuple): A 7-item tuple unpacked as (total_safetensors_dir, dst_strategy_file, net,
            dst_safetensors_dir, rank_id, output_format, name_map).

            - total_safetensors_dir: Directory holding the weight file(s) and, unless it contains
              exactly one entry, one json file that maps parameter names to file names.
            - dst_strategy_file: Strategy file handed to `_extract_src_dst_layout_map`. If None, every
              parameter is loaded in full.
            - net: If not None, the loaded parameters are put into it with `ms.load_param_into_net`.
            - dst_safetensors_dir, output_format: Only used when `net` is None, to save the result.
            - rank_id: The rank the parameters are sliced for.
            - name_map: Optional mapping from stored parameter names to new names; values must be str.

    Returns:
        The `(param_not_load, ckpt_not_load)` result of `ms.load_param_into_net` when `net` is not None.
        Otherwise None, after saving the parameters to
        `<dst_safetensors_dir>/rank_<rank_id>/net.<output_format>`.

    Raises:
        ValueError: If the only entry of the directory is not a file, if the directory has several
            entries but not exactly one json file, or if a renamed parameter name is not a str.
    """
    total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, rank_id, output_format, name_map = file_info
    file_list = os.listdir(total_safetensors_dir)
    json_files = [file for file in file_list if file.endswith('.json')]
    # Build `param_name_map`: parameter name -> name of the file inside `total_safetensors_dir`.
    if len(file_list) == 1:
        # A single directory entry is treated as the only weight file; all of its keys map to it.
        logger.info("There is only one weight file in the directory, which will be automatically mapped.")
        file_name = os.path.join(total_safetensors_dir, file_list[0])
        is_file = os.path.isfile(file_name)
        if not is_file:
            raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
                             f"in the `unified_safetensors_dir`.")
        with safe_open(file_name, framework="np") as f:
            keys = f.keys()
            values = len(keys) * [file_list[0]]
            param_name_map = dict(zip(keys, values))
    else:
        # Otherwise the mapping is read from the one json file that must be present.
        if len(json_files) != 1:
            raise ValueError(f"For 'load_parallel_checkpoint', the number of json files in 'total_safetensors_dir' "
                             f"must be 1, but got {len(json_files)}.")
        param_name_json = os.path.join(total_safetensors_dir, json_files[0])
        with open(param_name_json, 'r') as f:
            param_name_map = json.load(f)
    # With a destination strategy only its parameters are visited; without one, every mapped parameter.
    if dst_strategy_file is not None:
        _, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
        param_list = dst_strategy_list.keys()
    else:
        dst_strategy_list = None
        param_list = param_name_map.keys()
    total_param = dict()
    # The device count is the product of element [0] of the first strategy entry (indexed like a
    # device matrix). NOTE(review): an empty `dst_strategy_list` would fail on the `[0]` index here.
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
        is not None else 1
    local_rank_id = rank_id % dst_stage_device_num
    for param_name in param_list:
        # Parameters without a file mapping, or missing from their mapped file, are skipped silently.
        if param_name not in param_name_map:
            continue
        file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
        with safe_open(file_name, framework="np") as f:
            if param_name not in f.keys():
                continue
            sf_obj = f.get_slice(param_name)
            tensor_shape = sf_obj.get_shape()
            # Source layout is fixed: device matrix [1], tensor map of all -1, optimizer shard
            # step/size 0 (presumably describing a full, unsharded tensor -- confirm).
            from_dev_matrix = [1]
            from_tensor_map = [-1] * len(tensor_shape)
            from_opt_shard_step = 0
            from_opt_shard_size = 0
            if dst_strategy_list is not None:
                # `param_list` comes from `dst_strategy_list.keys()` in this branch, so this check
                # looks redundant; it is kept as a safeguard.
                if param_name not in dst_strategy_list:
                    continue
                to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                    dst_strategy_list.get(param_name))
                # Computed from the initial `[1]` device matrix, before `from_dev_matrix` is reassigned below.
                device_num = np.prod(from_dev_matrix)
                param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
                # Scale each stored dimension by the source strategy to get the original shape.
                # `from_opt_shard_size` is always 0 here, so the first branch is never taken.
                origin_tensor_shape = ()
                for i, item in enumerate(tensor_shape):
                    if i == 0 and from_opt_shard_size > 0:
                        origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
                        continue
                    origin_tensor_shape += (item * param_strategy[i],)
                # Extend the source and destination layouts with the optimizer sharding information.
                from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
                    from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
                to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
                    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
                # Convert tensor layout to same device num
                from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
                                                                                        from_dev_matrix,
                                                                                        from_tensor_map,
                                                                                        to_full_tensor_shape,
                                                                                        to_dev_matrix, to_tensor_map)
                # When the from_layout has fewer devices, the safetensor_map entry for map[device_num]
                # should use map[0].
                device_list = list(range(0, np.prod(from_tensor_layout[0])))
                param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                                    device_list, local_rank_id)
                from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
                to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
                _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
                # Build the operator stack for this rank and apply it to the stored slice object.
                transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
                slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
            else:
                # No destination strategy: take the whole stored tensor.
                slice_param = sf_obj[:]
            # Store the result under the renamed key when `name_map` provides one.
            cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
            _check_name_map_value_is_str(cur_param_name)
            total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
    # Entries from "hyper_param.safetensors", if that file is listed, are added as well.
    total_param = _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param)
    if net is not None:
        param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
        return param_not_load, ckpt_not_load
    # Without a network, write the parameters to a per-rank checkpoint file instead.
    _make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
    ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.{output_format}"),
                       format=output_format)
    return None
def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
    """
    Compute the slice operation, and optionally the sliced shape, of one parameter for `rank_id`.

    Returns:
        tuple, `(slice_op, shape)`. `slice_op` is the result of `_load_tensor_shape` for the
        destination layout of `param_name`. `shape` is None unless the destination optimizer shard
        size is greater than 0; in that case it is a list in which every dimension of the full shape
        is floor-divided by its tensor strategy, and dimension 0 additionally by the shard size.
    """
    full_shape = sf_obj.get_shape()
    layout_item = dst_strategy_list.get(param_name)
    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = \
        _extract_layout_item(layout_item)
    # Add optimizer sharding dim for tensor layout
    to_dev_matrix, to_tensor_map, _ = _construct_tensor_layout_for_opt_shard(
        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, full_shape)
    slice_op = _load_tensor_shape(to_dev_matrix, to_tensor_map, full_shape=full_shape, rank_id=rank_id)
    if not (to_opt_shard_size > 0):
        return slice_op, None
    strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
    # Only the first dimension is additionally divided by the optimizer shard size.
    sliced_shape = [
        dim // (strategy[axis] * to_opt_shard_size) if axis == 0 else dim // strategy[axis]
        for axis, dim in enumerate(full_shape)
    ]
    return slice_op, sliced_shape
# Names exported by `from <module> import *`.
__all__ = [
    "_transform_safetensors",
    "transform_safetensors_by_stage",
    "transform_safetensors_by_rank",
    "unified_safetensors",
]