Source code for mindspore.dataset.engine.serializer_deserializer

# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Functions to support dataset serialize and deserialize.
"""
import json
import os
import sys

import mindspore.common.dtype as mstype
from mindspore import log as logger
from . import datasets as de
from ..vision.utils import Inter, Border, ImageBatchFormat


def serialize(dataset, json_filepath=""):
    """
    Serialize dataset pipeline into a JSON file.

    Note:
        Some Python objects are currently not supported for serialization.
        For a Python function used in a map operator, de.serialize only records its function name.

    Args:
        dataset (Dataset): The starting node.
        json_filepath (str): The filepath where a serialized JSON file will be generated.

    Returns:
        Dict, the dictionary containing the serialized dataset graph.

    Raises:
        OSError: Cannot open a file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # serialize it to JSON file
        >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> serialized_data = ds.engine.serialize(dataset)  # serialize it to Python dict
    """
    return dataset.to_json(json_filepath)
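# A minimal sketch of inspecting the serialized graph, assuming `ds` is the imported
# `mindspore.dataset` package and `dataset` is any pipeline object:
#
#     >>> graph = ds.engine.serialize(dataset)   # plain Python dict
#     >>> print(json.dumps(graph, indent=2))     # pretty-print the pipeline
#
# Passing `json_filepath` writes the same dictionary to disk, so the two forms are interchangeable.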
def deserialize(input_dict=None, json_filepath=None):
    """
    Construct a de pipeline from a JSON file produced by de.serialize().

    Note:
        Deserialization of Python functions used in a map operator is currently not supported.

    Args:
        input_dict (dict): A Python dictionary containing a serialized dataset graph.
        json_filepath (str): A path to the JSON file.

    Returns:
        de.Dataset, or None if an error occurs.

    Raises:
        OSError: Cannot open the JSON file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # Use case 1: to/from JSON file
        >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> # Use case 2: to/from Python dictionary
        >>> serialized_data = ds.engine.serialize(dataset)
        >>> dataset = ds.engine.deserialize(input_dict=serialized_data)
    """
    data = None
    if input_dict:
        data = construct_pipeline(input_dict)

    if json_filepath:
        real_file_path = os.path.realpath(json_filepath)
        with open(real_file_path, 'r') as json_file:
            dict_pipeline = json.load(json_file)
        data = construct_pipeline(dict_pipeline)

    return data
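# A minimal round-trip sketch, assuming `dataset` is a pipeline built only from
# C++ (non-Python) transforms so that it survives serialization:
#
#     >>> serialized_data = ds.engine.serialize(dataset)
#     >>> restored = ds.engine.deserialize(input_dict=serialized_data)
#     >>> ds.compare(dataset, restored)   # True when both graphs serialize identically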
def expand_path(node_repr, key, val):
    """Convert relative to absolute path."""
    if isinstance(val, list):
        node_repr[key] = [os.path.abspath(file) for file in val]
    else:
        node_repr[key] = os.path.abspath(val)
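# A minimal sketch of what expand_path does to a node dictionary (the key and file
# names below are made up for illustration):
#
#     >>> node_repr = {'dataset_files': ['data/train.tfrecord', 'data/val.tfrecord']}
#     >>> expand_path(node_repr, 'dataset_files', node_repr['dataset_files'])
#     >>> node_repr['dataset_files']   # now a list of absolute paths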
def show(dataset, indentation=2):
    """
    Write the dataset pipeline graph to logger.info.

    Args:
        dataset (Dataset): The starting node.
        indentation (int, optional): The indentation used by the JSON print.
            Do not indent if indentation is None.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)
        >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> ds.show(dataset)
    """
    pipeline = dataset.to_json()
    logger.info(json.dumps(pipeline, indent=indentation))
def compare(pipeline1, pipeline2):
    """
    Compare whether two dataset pipelines are the same.

    Args:
        pipeline1 (Dataset): a dataset pipeline.
        pipeline2 (Dataset): a dataset pipeline.

    Returns:
        bool, whether pipeline1 is equal to pipeline2.

    Examples:
        >>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
        >>> ds.compare(pipeline1, pipeline2)
    """
    return pipeline1.to_json() == pipeline2.to_json()
def construct_pipeline(node):
    """Construct the Python Dataset objects by following the dictionary deserialized from JSON file."""
    op_type = node.get('op_type')
    if not op_type:
        raise ValueError("op_type field in the json file can't be None.")
    # Instantiate Python Dataset object based on the current dictionary element
    dataset = create_node(node)
    # Initially it is not connected to any other object.
    dataset.children = []

    # Construct the children too and add an edge between the children and the parent.
    for child in node['children']:
        dataset.children.append(construct_pipeline(child))

    return dataset


def create_node(node):
    """Parse the key, value in the node dictionary and instantiate the Python Dataset object"""
    logger.info('creating node: %s', node['op_type'])
    dataset_op = node['op_type']
    op_module = "mindspore.dataset"

    # Get the Python class to be instantiated.
    # Example:
    #  "op_type": "MapDataset",
    #  "op_module": "mindspore.dataset.datasets",
    if node.get("children"):
        pyclass = getattr(sys.modules[op_module], "Dataset")
    else:
        pyclass = getattr(sys.modules[op_module], dataset_op)

    # Find a matching Dataset class and call the constructor with the corresponding args.
    # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
    # Dataset Source Ops (in alphabetical order)
    pyobj = create_dataset_node(pyclass, node, dataset_op)
    if not pyobj:
        # Dataset Ops (in alphabetical order)
        pyobj = create_dataset_operation_node(node, dataset_op)

    return pyobj


def create_dataset_node(pyclass, node, dataset_op):
    """Parse the key, value in the dataset node dictionary and instantiate the Python Dataset object"""
    pyobj = None
    if dataset_op == 'CelebADataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
                        node.get('usage'), sampler, node.get('decode'), node.get('extensions'),
                        num_samples, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'Cifar10Dataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'),
                        sampler, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'Cifar100Dataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'),
                        sampler, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'ClueDataset':
        shuffle = to_shuffle_mode(node.get('shuffle'))
        if isinstance(shuffle, str):
            shuffle = de.Shuffle(shuffle)
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_files'], node.get('task'), node.get('usage'), num_samples,
                        node.get('num_parallel_workers'), shuffle,
                        node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'CocoDataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'),
                        sampler, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'CSVDataset':
        shuffle = to_shuffle_mode(node.get('shuffle'))
        if isinstance(shuffle, str):
            shuffle = de.Shuffle(shuffle)
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_files'], node.get('field_delim'), node.get('column_defaults'),
                        node.get('column_names'), num_samples, node.get('num_parallel_workers'),
                        shuffle, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'ImageFolderDataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], num_samples, node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('extensions'),
                        node.get('class_indexing'), node.get('decode'),
                        node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'ManifestDataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_file'], node['usage'], num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'), sampler,
                        node.get('class_indexing'), node.get('decode'),
                        node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'MnistDataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'),
                        sampler, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'TextFileDataset':
        shuffle = to_shuffle_mode(node.get('shuffle'))
        if isinstance(shuffle, str):
            shuffle = de.Shuffle(shuffle)
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_files'], num_samples, node.get('num_parallel_workers'),
                        shuffle, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'TFRecordDataset':
        shuffle = to_shuffle_mode(node.get('shuffle'))
        if isinstance(shuffle, str):
            shuffle = de.Shuffle(shuffle)
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
                        num_samples, node.get('num_parallel_workers'), shuffle,
                        node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'VOCDataset':
        sampler = construct_sampler(node.get('sampler'))
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('usage'),
                        node.get('class_indexing'), num_samples, node.get('num_parallel_workers'),
                        node.get('shuffle'), node.get('decode'), sampler,
                        node.get('num_shards'), node.get('shard_id'))

    return pyobj


def create_dataset_operation_node(node, dataset_op):
    """Parse the key, value in the dataset operation node dictionary and instantiate the Python Dataset object"""
    pyobj = None
    if dataset_op == 'Batch':
        pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
    elif dataset_op == 'Map':
        tensor_ops = construct_tensor_ops(node.get('operations'))
        pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'),
                                 node.get('column_order'), node.get('num_parallel_workers'),
                                 False, None, node.get('callbacks'))
    elif dataset_op == 'Project':
        pyobj = de.Dataset().project(node['columns'])
    elif dataset_op == 'Rename':
        pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
    elif dataset_op == 'Repeat':
        pyobj = de.Dataset().repeat(node.get('count'))
    elif dataset_op == 'Shuffle':
        pyobj = de.Dataset().shuffle(node.get('buffer_size'))
    elif dataset_op == 'Skip':
        pyobj = de.Dataset().skip(node.get('count'))
    elif dataset_op == 'Take':
        pyobj = de.Dataset().take(node.get('count'))
    elif dataset_op == 'Transfer':
        pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))
    elif dataset_op == 'Zip':
        # Create a ZipDataset instance with dummy input datasets that will be overridden in the caller.
        pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
    else:
        raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().")

    return pyobj


def construct_sampler(in_sampler):
    """Instantiate Sampler object based on the information from dictionary['sampler']"""
    sampler = None
    if in_sampler is not None:
        # num_samples may be absent from the serialized sampler; default it to None (no limit).
        num_samples = None
        if "num_samples" in in_sampler:
            num_samples = check_and_replace_input(in_sampler['num_samples'], 0, None)
        sampler_name = in_sampler['sampler_name']
        sampler_module = "mindspore.dataset"
        sampler_class = getattr(sys.modules[sampler_module], sampler_name)
        if sampler_name == 'DistributedSampler':
            sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
        elif sampler_name == 'PKSampler':
            sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
        elif sampler_name == 'RandomSampler':
            sampler = sampler_class(in_sampler.get('replacement'), num_samples)
        elif sampler_name == 'SequentialSampler':
            sampler = sampler_class(in_sampler.get('start_index'), num_samples)
        elif sampler_name == 'SubsetRandomSampler':
            sampler = sampler_class(in_sampler['indices'], num_samples)
        elif sampler_name == 'WeightedRandomSampler':
            sampler = sampler_class(in_sampler['weights'], num_samples, in_sampler.get('replacement'))
        else:
            raise ValueError("Sampler type is unknown: {}.".format(sampler_name))
        if in_sampler.get("child_sampler"):
            for child in in_sampler["child_sampler"]:
                sampler.add_child(construct_sampler(child))

    return sampler


def construct_tensor_ops(operations):
    """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
    result = []
    for op in operations:
        op_name = op.get('tensor_op_name')
        op_params = op.get('tensor_op_params')

        if op.get('is_python_front_end_op'):  # check if it's a py_transform op
            raise NotImplementedError("python function is not yet supported by de.deserialize().")

        if op_name == "HwcToChw":
            op_name = "HWC2CHW"
        if op_name == "UniformAug":
            op_name = "UniformAugment"
        op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
        op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]

        if hasattr(op_module_vis, op_name):
            op_class = getattr(op_module_vis, op_name, None)
        elif hasattr(op_module_trans, op_name):
            op_class = getattr(op_module_trans, op_name, None)
        else:
            raise RuntimeError(op_name + " is not yet supported by deserialize().")

        if op_params is None:  # If no parameter is specified, call it directly
            result.append(op_class())
        else:
            # Input parameter type cast
            for key, val in op_params.items():
                if key in ['center', 'fill_value']:
                    op_params[key] = tuple(val)
                elif key in ['interpolation', 'resample']:
                    op_params[key] = Inter(to_interpolation_mode(val))
                elif key in ['padding_mode']:
                    op_params[key] = Border(to_border_mode(val))
                elif key in ['data_type']:
                    op_params[key] = to_mstype(val)
                elif key in ['image_batch_format']:
                    op_params[key] = to_image_batch_format(val)
                elif key in ['policy']:
                    op_params[key] = to_policy(val)
                elif key in ['transform', 'transforms']:
                    op_params[key] = construct_tensor_ops(val)
            result.append(op_class(**op_params))

    return result


def to_policy(op_list):
    """ op_list to policy """
    policy_tensor_ops = []
    for policy_list in op_list:
        sub_policy_tensor_ops = []
        for policy_item in policy_list:
            sub_policy_tensor_ops.append(
                (construct_tensor_ops(policy_item.get('tensor_op')), policy_item.get('prob')))
        policy_tensor_ops.append(sub_policy_tensor_ops)
    return policy_tensor_ops


def to_shuffle_mode(shuffle):
    """ int to shuffle mode """
    ret_val = False
    if shuffle == 2:
        ret_val = "global"
    elif shuffle == 1:
        ret_val = "files"
    return ret_val


def to_interpolation_mode(inter):
    """ int to interpolation mode """
    return {
        0: Inter.LINEAR,
        1: Inter.NEAREST,
        2: Inter.CUBIC,
        3: Inter.AREA
    }[inter]


def to_border_mode(border):
    """ int to border mode """
    return {
        0: Border.CONSTANT,
        1: Border.EDGE,
        2: Border.REFLECT,
        3: Border.SYMMETRIC
    }[border]


def to_mstype(data_type):
    """ str to mstype """
    return {
        "bool": mstype.bool_,
        "int8": mstype.int8,
        "int16": mstype.int16,
        "int32": mstype.int32,
        "int64": mstype.int64,
        "uint8": mstype.uint8,
        "uint16": mstype.uint16,
        "uint32": mstype.uint32,
        "uint64": mstype.uint64,
        "float16": mstype.float16,
        "float32": mstype.float32,
        "float64": mstype.float64,
        "string": mstype.string
    }[data_type]


def to_image_batch_format(image_batch_format):
    """ int to image batch format """
    return {
        0: ImageBatchFormat.NHWC,
        1: ImageBatchFormat.NCHW
    }[image_batch_format]


def check_and_replace_input(input_value, expect, replace):
    """ check and replace input arg """
    return replace if input_value == expect else input_value
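# A minimal sketch of how the small conversion helpers above behave, using only
# values this module handles (run manually, e.g. in an interactive session):
#
#     >>> to_shuffle_mode(2)                       # 'global'
#     >>> to_shuffle_mode(0)                       # False
#     >>> to_mstype("float32")                     # mstype.float32
#     >>> check_and_replace_input(0, 0, None)      # None
#     >>> check_and_replace_input(100, 0, None)    # 100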