# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Functions to support dataset pipeline serialization and deserialization.
"""
import json
import os
import sys
import mindspore.common.dtype as mstype
from mindspore import log as logger
from . import datasets as de
from ..vision.utils import Inter, Border, ImageBatchFormat
def serialize(dataset, json_filepath=""):
    """
    Serialize dataset pipeline into a JSON file.

    Note:
        Currently some Python objects are not supported to be serialized.
        For Python function serialization of map operator, de.serialize will only return its function name.

    Args:
        dataset (Dataset): The starting node.
        json_filepath (str): The filepath where a serialized JSON file will be generated.
            It is forwarded unchanged to `dataset.to_json` (default="").

    Returns:
        Dict, The dictionary contains the serialized dataset graph.

    Raises:
        OSError: Can not open a file

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns=["label"])
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # serialize it to JSON file
        >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> serialized_data = ds.engine.serialize(dataset)  # serialize it to Python dict
    """
    # All of the work is delegated to the dataset object itself.
    return dataset.to_json(json_filepath)
def deserialize(input_dict=None, json_filepath=None):
    """
    Construct a de pipeline from a JSON file produced by de.serialize().

    Note:
        Currently Python function deserialization of map operator are not supported.
        If both `input_dict` and `json_filepath` are given, both are processed and the
        pipeline built from the JSON file is the one returned.

    Args:
        input_dict (dict): A Python dictionary containing a serialized dataset graph.
        json_filepath (str): A path to the JSON file.

    Returns:
        de.Dataset, or None when neither a non-empty `input_dict` nor a non-empty
        `json_filepath` is supplied.

    Raises:
        OSError: Can not open the JSON file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns=["label"])
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # Use case 1: to/from JSON file
        >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> # Use case 2: to/from Python dictionary
        >>> serialized_data = ds.engine.serialize(dataset)
        >>> dataset = ds.engine.deserialize(input_dict=serialized_data)
    """
    data = None
    if input_dict:
        data = construct_pipeline(input_dict)

    if json_filepath:
        # Resolve symlinks / relative segments before opening the file.
        real_file_path = os.path.realpath(json_filepath)
        with open(real_file_path, 'r') as json_file:
            dict_pipeline = json.load(json_file)
        data = construct_pipeline(dict_pipeline)
    return data
def expand_path(node_repr, key, val):
    """
    Store the absolute form of a path (or list of paths) under node_repr[key].

    Args:
        node_repr (dict): Dictionary that receives the converted value.
        key: Key under which the result is stored.
        val (Union[str, list]): A single path, or a list of paths.
    """
    if not isinstance(val, list):
        node_repr[key] = os.path.abspath(val)
        return
    node_repr[key] = list(map(os.path.abspath, val))
def show(dataset, indentation=2):
    """
    Log the dataset pipeline graph, formatted as JSON, through logger.info.

    Args:
        dataset (Dataset): The starting node.
        indentation (int, optional): The indentation used by the JSON print.
            Do not indent if indentation is None.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> one_hot_encode = c_transforms.OneHot(10)
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns=["label"])
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> ds.show(dataset)
    """
    pipeline = dataset.to_json()
    logger.info(json.dumps(pipeline, indent=indentation))
def compare(pipeline1, pipeline2):
    """
    Compare if two dataset pipelines are the same.

    The comparison is made on the serialized form returned by each pipeline's `to_json()`.

    Args:
        pipeline1 (Dataset): a dataset pipeline.
        pipeline2 (Dataset): a dataset pipeline.

    Returns:
        Whether pipeline1 is equal to pipeline2, i.e. the result of comparing
        `pipeline1.to_json()` and `pipeline2.to_json()` with `==`.

    Examples:
        >>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
        >>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
        >>> ds.compare(pipeline1, pipeline2)
    """
    return pipeline1.to_json() == pipeline2.to_json()
def construct_pipeline(node):
    """
    Recursively build Python Dataset objects from a dictionary deserialized from JSON.

    Args:
        node (dict): Serialized node; must carry a non-empty 'op_type' and a 'children' list.

    Returns:
        The object produced by create_node for `node`, with its `children` attribute
        set to the recursively constructed child objects.

    Raises:
        ValueError: If 'op_type' is missing or empty.
    """
    if not node.get('op_type'):
        raise ValueError("op_type field in the json file can't be None.")

    # Build the object for this dictionary element; it starts without any edges.
    dataset = create_node(node)
    dataset.children = []

    # Build every child subtree and attach it to the parent, in order.
    dataset.children.extend(construct_pipeline(child) for child in node['children'])
    return dataset
def create_node(node):
    """
    Instantiate the Python Dataset object described by one node dictionary.

    Args:
        node (dict): Serialized node; node['op_type'] names the operation.

    Returns:
        The object built by create_dataset_node, or, when that result is falsy,
        the object built by create_dataset_operation_node.
    """
    logger.info('creating node: %s', node['op_type'])
    dataset_op = node['op_type']
    op_module = "mindspore.dataset"

    # Pick the Python class to hand to the source-dataset factory:
    # a node that has children uses the generic "Dataset" class,
    # a node without children uses the class named by its op_type.
    class_name = "Dataset" if node.get("children") else dataset_op
    pyclass = getattr(sys.modules[op_module], class_name)

    # Try the dataset source ops first, then fall back to the dataset ops.
    # When a new Dataset class is introduced, parsing code needs to be added to one of them.
    return (create_dataset_node(pyclass, node, dataset_op)
            or create_dataset_operation_node(node, dataset_op))
def create_dataset_node(pyclass, node, dataset_op):
    """
    Instantiate a source dataset object from its serialized node dictionary.

    Args:
        pyclass: Callable invoked with the positional constructor arguments.
        node (dict): Serialized node holding the constructor arguments.
        dataset_op (str): Name of the dataset source op, e.g. 'MnistDataset'.

    Returns:
        The object returned by `pyclass(...)`, or None if `dataset_op` is not one of
        the source ops handled here.
    """
    # Source ops that take a sampler versus ops that take a shuffle mode.
    sampler_ops = ('CelebADataset', 'Cifar10Dataset', 'Cifar100Dataset', 'CocoDataset',
                   'ImageFolderDataset', 'ManifestDataset', 'MnistDataset', 'VOCDataset')
    shuffle_mode_ops = ('ClueDataset', 'CSVDataset', 'TextFileDataset', 'TFRecordDataset')

    sampler = None
    shuffle = None
    if dataset_op in sampler_ops:
        sampler = construct_sampler(node.get('sampler'))
    elif dataset_op in shuffle_mode_ops:
        shuffle = to_shuffle_mode(node.get('shuffle'))
        if isinstance(shuffle, str):
            shuffle = de.Shuffle(shuffle)
    else:
        return None

    # A serialized num_samples of 0 is replaced by None before calling the constructor.
    num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
    workers = node.get('num_parallel_workers')
    num_shards = node.get('num_shards')
    shard_id = node.get('shard_id')

    # Arguments are passed positionally, so their order must match each constructor.
    if dataset_op == 'CelebADataset':
        return pyclass(node['dataset_dir'], workers, node.get('shuffle'), node.get('usage'),
                       sampler, node.get('decode'), node.get('extensions'), num_samples, num_shards,
                       shard_id)
    if dataset_op in ('Cifar10Dataset', 'Cifar100Dataset', 'MnistDataset'):
        return pyclass(node['dataset_dir'], node['usage'], num_samples, workers,
                       node.get('shuffle'), sampler, num_shards, shard_id)
    if dataset_op == 'ClueDataset':
        return pyclass(node['dataset_files'], node.get('task'),
                       node.get('usage'), num_samples, workers, shuffle,
                       num_shards, shard_id)
    if dataset_op == 'CocoDataset':
        return pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), num_samples,
                       workers, node.get('shuffle'), node.get('decode'), sampler,
                       num_shards, shard_id)
    if dataset_op == 'CSVDataset':
        return pyclass(node['dataset_files'], node.get('field_delim'),
                       node.get('column_defaults'), node.get('column_names'), num_samples,
                       workers, shuffle,
                       num_shards, shard_id)
    if dataset_op == 'ImageFolderDataset':
        return pyclass(node['dataset_dir'], num_samples, workers,
                       node.get('shuffle'), sampler, node.get('extensions'),
                       node.get('class_indexing'), node.get('decode'), num_shards,
                       shard_id)
    if dataset_op == 'ManifestDataset':
        return pyclass(node['dataset_file'], node['usage'], num_samples,
                       workers, node.get('shuffle'), sampler,
                       node.get('class_indexing'), node.get('decode'), num_shards,
                       shard_id)
    if dataset_op == 'TextFileDataset':
        return pyclass(node['dataset_files'], num_samples,
                       workers, shuffle,
                       num_shards, shard_id)
    if dataset_op == 'TFRecordDataset':
        return pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
                       num_samples, workers,
                       shuffle, num_shards, shard_id)
    if dataset_op == 'VOCDataset':
        return pyclass(node['dataset_dir'], node.get('task'), node.get('usage'), node.get('class_indexing'),
                       num_samples, workers, node.get('shuffle'),
                       node.get('decode'), sampler, num_shards, shard_id)
    return None
def create_dataset_operation_node(node, dataset_op):
    """
    Instantiate a dataset operation object from its serialized node dictionary.

    Each supported operation is created by calling the matching method on a fresh
    `de.Dataset()` (or, for 'Zip', by building a `de.ZipDataset`).

    Args:
        node (dict): Serialized node holding the operation arguments.
        dataset_op (str): Name of the dataset operation, e.g. 'Batch'.

    Returns:
        The object created for the operation.

    Raises:
        RuntimeError: If `dataset_op` is not one of the operations handled here.
    """
    if dataset_op == 'Batch':
        return de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))

    if dataset_op == 'Map':
        tensor_ops = construct_tensor_ops(node.get('operations'))
        return de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'),
                                node.get('column_order'), node.get('num_parallel_workers'),
                                False, None, node.get('callbacks'))

    if dataset_op == 'Project':
        return de.Dataset().project(node['columns'])

    if dataset_op == 'Rename':
        return de.Dataset().rename(node['input_columns'], node['output_columns'])

    if dataset_op == 'Repeat':
        return de.Dataset().repeat(node.get('count'))

    if dataset_op == 'Shuffle':
        return de.Dataset().shuffle(node.get('buffer_size'))

    if dataset_op == 'Skip':
        return de.Dataset().skip(node.get('count'))

    if dataset_op == 'Take':
        return de.Dataset().take(node.get('count'))

    if dataset_op == 'Transfer':
        return de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))

    if dataset_op == 'Zip':
        # Create ZipDataset instance, giving dummy input datasets that will be overridden in the caller.
        return de.ZipDataset((de.Dataset(), de.Dataset()))

    raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().")
def construct_sampler(in_sampler):
    """
    Instantiate Sampler object based on the information from dictionary['sampler'].

    Args:
        in_sampler (dict): Serialized sampler description, or None. It must contain
            'sampler_name' plus the keys required by that sampler type. Child samplers
            listed under 'child_sampler' are built recursively and attached with `add_child`.

    Returns:
        The constructed sampler object, or None if `in_sampler` is None.

    Raises:
        ValueError: If the sampler class exists in mindspore.dataset but is not one of
            the sampler types handled here.
    """
    if in_sampler is None:
        return None

    # Default to None so that samplers serialized without 'num_samples' can still be built;
    # a serialized value of 0 is replaced by None as well.
    num_samples = None
    if "num_samples" in in_sampler:
        num_samples = check_and_replace_input(in_sampler['num_samples'], 0, None)

    sampler_name = in_sampler['sampler_name']
    sampler_module = "mindspore.dataset"
    sampler_class = getattr(sys.modules[sampler_module], sampler_name)

    if sampler_name == 'DistributedSampler':
        sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
    elif sampler_name == 'PKSampler':
        # 'shuffle' must be read with .get(); calling the dictionary itself raises TypeError.
        sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
    elif sampler_name == 'RandomSampler':
        sampler = sampler_class(in_sampler.get('replacement'), num_samples)
    elif sampler_name == 'SequentialSampler':
        sampler = sampler_class(in_sampler.get('start_index'), num_samples)
    elif sampler_name == 'SubsetRandomSampler':
        sampler = sampler_class(in_sampler['indices'], num_samples)
    elif sampler_name == 'WeightedRandomSampler':
        sampler = sampler_class(in_sampler['weights'], num_samples, in_sampler.get('replacement'))
    else:
        raise ValueError("Sampler type is unknown: {}.".format(sampler_name))

    child_samplers = in_sampler.get("child_sampler")
    if child_samplers:
        for child in child_samplers:
            sampler.add_child(construct_sampler(child))

    return sampler
def construct_tensor_ops(operations):
    """
    Instantiate tensor op object(s) based on the information from dictionary['operations'].

    Args:
        operations (list[dict]): Serialized tensor ops. Each entry provides 'tensor_op_name'
            and optionally 'tensor_op_params' and 'is_python_front_end_op'.

    Returns:
        list, the instantiated tensor op objects, in the same order as `operations`.

    Raises:
        NotImplementedError: If an entry is flagged as a Python front-end op.
        RuntimeError: If the op class is found in neither the vision nor the transforms
            c_transforms module.
    """
    # Serialized names that differ from the Python class name.
    renamed_ops = {"HwcToChw": "HWC2CHW", "UniformAug": "UniformAugment"}

    result = []
    for op in operations:
        op_name = op.get('tensor_op_name')
        op_params = op.get('tensor_op_params')

        if op.get('is_python_front_end_op'):  # check if it's a py_transform op
            raise NotImplementedError("python function is not yet supported by de.deserialize().")

        op_name = renamed_ops.get(op_name, op_name)

        # The vision module is searched first, then the generic transforms module.
        op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
        op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
        if hasattr(op_module_vis, op_name):
            op_class = getattr(op_module_vis, op_name, None)
        elif hasattr(op_module_trans, op_name):
            op_class = getattr(op_module_trans, op_name, None)
        else:
            raise RuntimeError(op_name + " is not yet supported by deserialize().")

        if op_params is None:  # If no parameter is specified, call it directly
            result.append(op_class())
            continue

        # Cast the serialized parameters into a fresh dict. The caller's dictionary is left
        # untouched so the same serialized graph can be deserialized more than once.
        kwargs = {}
        for key, val in op_params.items():
            if key in ('center', 'fill_value'):
                val = tuple(val)
            elif key in ('interpolation', 'resample'):
                val = Inter(to_interpolation_mode(val))
            elif key == 'padding_mode':
                val = Border(to_border_mode(val))
            elif key == 'data_type':
                val = to_mstype(val)
            elif key == 'image_batch_format':
                val = to_image_batch_format(val)
            elif key == 'policy':
                val = to_policy(val)
            elif key in ('transform', 'transforms'):
                val = construct_tensor_ops(val)
            kwargs[key] = val
        result.append(op_class(**kwargs))
    return result
def to_policy(op_list):
    """
    Convert a serialized policy into nested lists of (tensor ops, probability) tuples.

    Args:
        op_list (list[list[dict]]): One inner list per sub-policy; every item provides
            'tensor_op' (passed to construct_tensor_ops) and 'prob'.

    Returns:
        list[list[tuple]], with the same nesting as `op_list`.
    """
    return [
        [(construct_tensor_ops(item.get('tensor_op')), item.get('prob')) for item in sub_policy]
        for sub_policy in op_list
    ]
def to_shuffle_mode(shuffle):
    """
    Translate a serialized shuffle value into a shuffle mode.

    Returns:
        "global" when `shuffle` equals 2, "files" when it equals 1, otherwise False.
    """
    if shuffle == 2:
        return "global"
    if shuffle == 1:
        return "files"
    return False
def to_interpolation_mode(inter):
    """
    Translate a serialized integer into an Inter member.

    The mapping is 0 -> LINEAR, 1 -> NEAREST, 2 -> CUBIC, 3 -> AREA.

    Raises:
        KeyError: If `inter` is not one of the mapped integers.
    """
    modes = dict(enumerate((Inter.LINEAR, Inter.NEAREST, Inter.CUBIC, Inter.AREA)))
    return modes[inter]
def to_border_mode(border):
    """
    Translate a serialized integer into a Border member.

    The mapping is 0 -> CONSTANT, 1 -> EDGE, 2 -> REFLECT, 3 -> SYMMETRIC.

    Raises:
        KeyError: If `border` is not one of the mapped integers.
    """
    modes = dict(enumerate((Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC)))
    return modes[border]
def to_mstype(data_type):
    """
    Translate a serialized type name into the matching mstype object.

    Args:
        data_type (str): One of "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
            "uint32", "uint64", "float16", "float32", "float64" or "string".

    Raises:
        KeyError: If `data_type` is not one of the names above.
    """
    # "bool" is the only name whose mstype attribute is spelled differently (bool_).
    type_table = {"bool": mstype.bool_}
    same_named = ("int8", "int16", "int32", "int64",
                  "uint8", "uint16", "uint32", "uint64",
                  "float16", "float32", "float64", "string")
    type_table.update((name, getattr(mstype, name)) for name in same_named)
    return type_table[data_type]
def to_image_batch_format(image_batch_format):
    """
    Translate a serialized integer into an ImageBatchFormat member.

    The mapping is 0 -> NHWC, 1 -> NCHW.

    Raises:
        KeyError: If `image_batch_format` is not one of the mapped integers.
    """
    formats = dict(enumerate((ImageBatchFormat.NHWC, ImageBatchFormat.NCHW)))
    return formats[image_batch_format]
def check_and_replace_input(input_value, expect, replace):
    """
    Substitute `replace` for `input_value` when `input_value` equals `expect`.

    Returns:
        `replace` if `input_value == expect`, otherwise `input_value` unchanged.
    """
    if input_value == expect:
        return replace
    return input_value