Source code for mindspore_federated.data_join.io

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export and load data in vfl."""

import os
from joblib import Parallel, delayed
from mindspore.mindrecord import FileWriter
from mindspore import dataset as ds


def _split_keys(keys, shard_num):
    """
    Split keys.
    """
    keys_list = list()
    keys_len = len(keys)
    before_len = keys_len // shard_num
    after_len = before_len + 1
    change_index = shard_num - keys_len % shard_num - 1
    start_index = 0
    length = before_len
    for shard_index in range(shard_num):
        keys_list.append(keys[start_index:(start_index+length)])
        start_index += length
        length = before_len if shard_index < change_index else after_len
    return keys_list


def export_mindrecord(file_name, raw_data, keys, shard_num=1, overwrite=True):
    """
    Export mindrecord.
    """
    schema = raw_data.schema()
    desc = raw_data.desc()
    if shard_num == 1:
        export_single_mindrecord(file_name, raw_data, keys, schema, desc, overwrite)
    else:
        export_multi_mindrecord(file_name, raw_data, keys, schema, desc, shard_num, overwrite)


def export_single_mindrecord(file_name, raw_data, keys, schema, desc, overwrite=True):
    """
    Export single mindrecord.
    """
    if not keys:
        return
    writer = FileWriter(file_name=file_name, shard_num=1, overwrite=overwrite)
    writer.add_schema(schema, desc)
    writer.open_and_set_header()
    for value in raw_data.values(keys=keys):
        writer.write_raw_data([value])
    writer.commit()


def export_multi_mindrecord(file_name, raw_data, keys, schema, desc, shard_num, overwrite=True):
    """
    Export multi mindrecord.
    """
    keys_list = _split_keys(keys, shard_num)
    Parallel(n_jobs=shard_num)(
        delayed(export_single_mindrecord)(file_name + str(shard_id), raw_data, keys, schema, desc, overwrite)
        for shard_id, keys in enumerate(keys_list)
    )


[docs]def load_mindrecord(input_dir, seed=0, **kwargs): """ Load MindRecord files. Args: input_dir (str): Input directory for storing MindRecord-related files. seed (int): The random seed. Default: 0. Returns: MindDataset, Order-preserving datasets. Note: This API transparently transfers the `kwargs` to MindDataset. For details about more hyper parameters in `kwargs`, refer to `mindspore.dataset.MindDataset` . Examples: >>> dataset = load_mindrecord(input_dir="input_dir", seed=0, shuffle=True) >>> for batch in dataset.create_tuple_iterator(): ... print(batch) """ ds.config.set_seed(seed) dataset_files = [os.path.join(input_dir, _) for _ in os.listdir(input_dir) if "db" not in _] dataset_files.sort() dataset = ds.MindDataset(dataset_files, **kwargs) return dataset