# Source code for mindflow.data.dataset

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#pylint: disable=W0223
#pylint: disable=W0221
#pylint: disable=W0212
"""
Combine pde/ic/bc datasets together
"""
from __future__ import absolute_import
import copy

import mindspore.dataset as ds
from mindspore import log as logger

from .data_base import Data, ExistedDataConfig
from .existed_data import ExistedDataset
from .equation import Equation
from .boundary import BoundaryIC, BoundaryBC
from ..geometry import Geometry
from ..utils.check_func import check_param_type, check_dict_type_value

_geomdata_dict = {
    "domain": Equation,
    "IC": BoundaryIC,
    "BC": BoundaryBC,
}


class Dataset(Data):
    r"""
    Combine datasets together.

    Parameters:
        geometry_dict (dict, optional): specifies geometry datasets to be merged. The key is geometry instance and
            value is a list of type of geometry. For example, geometry_dict = {geom : ["domain", "BC", "IC"]}.
            Default: ``None``.
        existed_data_list (Union[list, tuple, ExistedDataConfig], optional): specifies existed datasets to be merged.
            For example, existed_data_list = [ExistedDataConfig_Instance1, ExistedDataConfig_Instance2].
            Default: ``None``.
        dataset_list (Union[list, tuple, Data], optional): specifies instances of data to be merged. For example,
            dataset_list=[BoundaryIC_Instance, Equation_Instance, BoundaryBC_Instance and ExistedData_Instance].
            Default: ``None``.

    Raises:
        ValueError: If `geometry_dict`, existed_data_list and dataset_list are all ``None``.
        TypeError: If the type of `geometry_dict` is not dict.
        TypeError: If the type of key of geometry_dict is not instance of class Geometry.
        TypeError: If the type of `existed_data_list` is not list, tuple or instance of ExistedDataConfig.
        TypeError: If the element of `existed_data_list` is not instance of ExistedDataConfig.
        TypeError: If the element of `dataset_list` is not instance of class Data.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindflow.geometry import Rectangle, generate_sampling_config
        >>> from mindflow.data import Dataset
        >>> rectangle_mesh = dict({'domain': dict({'random_sampling': False, 'size': [50, 25]})})
        >>> rect_space = Rectangle("rectangle", coord_min=[0, 0], coord_max=[5, 5],
        ...                        sampling_config=generate_sampling_config(rectangle_mesh))
        >>> geom_dict = {rect_space: ["domain"]}
        >>> dataset = Dataset(geometry_dict=geom_dict)
    """

    def __init__(self, geometry_dict=None, existed_data_list=None, dataset_list=None):
        super(Dataset, self).__init__()
        if all((geometry_dict is None, existed_data_list is None, dataset_list is None)):
            raise ValueError(
                "Dataset should have at least one sub-dataset, but got None")

        if geometry_dict is not None:
            check_param_type(geometry_dict, "geometry_dict", data_type=dict)
            check_dict_type_value(geometry_dict, "geometry_dict", key_type=Geometry, value_type=str,
                                  value_value=list(_geomdata_dict.keys()))

        if existed_data_list is not None:
            # A single config is accepted as shorthand for a one-element list.
            if isinstance(existed_data_list, ExistedDataConfig):
                existed_data_list = [existed_data_list]
            check_param_type(existed_data_list, "existed_data_list", (list, tuple))
            for data_config in existed_data_list:
                check_param_type(
                    data_config, "element in existed_data_list", ExistedDataConfig)

        if dataset_list is not None:
            # A single Data instance is accepted as shorthand for a one-element list.
            if isinstance(dataset_list, Data):
                dataset_list = [dataset_list]
            check_param_type(dataset_list, "dataset_list", (list, tuple))
            for dataset in dataset_list:
                check_param_type(dataset, "element in dataset_list", Data)

        self.existed_data_list = existed_data_list
        self.geometry_dict = geometry_dict
        self.dataset_list = dataset_list
        # Work on a private list: `_get_all_datasets` appends to it, which must neither
        # mutate the caller's container nor fail when a tuple was passed in.
        self.all_datasets = list(dataset_list) if dataset_list else []
        # Set once the geometry / existed-data sub-datasets have been appended, so that
        # repeated `create_dataset()` calls do not append them again.
        self._derived_datasets_ready = False
        self.columns_list = None
        self._iterable_datasets = None

        self.num_dataset = len(dataset_list) if dataset_list else 0
        if existed_data_list:
            self.num_dataset += len(existed_data_list)
        if geometry_dict:
            self.num_dataset += sum(len(geom_types) for geom_types in geometry_dict.values())
        logger.info("Total datasets number: {}".format(self.num_dataset))

        self.dataset_columns_map = {}
        self.column_index_map = {}
        self.dataset_constraint_map = {}

    def _create_dataset_from_geometry(self, geometry, geom_type="domain"):
        """Instantiate the sub-dataset class registered in `_geomdata_dict` for `geom_type` with `geometry`."""
        dataset_instance = _geomdata_dict.get(geom_type)(geometry)
        return dataset_instance

    def _get_all_datasets(self):
        """
        Append the sub-datasets derived from `geometry_dict` and `existed_data_list` to `all_datasets`.

        The work is done only on the first call; later calls return immediately so that
        `all_datasets` does not grow each time `create_dataset()` is invoked.
        """
        if self._derived_datasets_ready:
            return

        if self.geometry_dict:
            for geom, types in self.geometry_dict.items():
                for geom_type in types:
                    self.all_datasets.append(
                        self._create_dataset_from_geometry(geom, geom_type))

        if self.existed_data_list:
            for data_config in self.existed_data_list:
                self.all_datasets.append(ExistedDataset(data_config=data_config))

        self._derived_datasets_ready = True

    def create_dataset(self,
                       batch_size=1,
                       preprocess_fn=None,
                       input_output_columns_map=None,
                       shuffle=True,
                       drop_remainder=True,
                       prebatched_data=False,
                       num_parallel_workers=1,
                       num_shards=None,
                       shard_id=None,
                       python_multiprocessing=False,
                       sampler=None):
        """
        create the final mindspore type dataset to merge all the sub-datasets.

        Args:
            batch_size (int, optional): An int number of rows each batch is created with. Default: ``1``.
            preprocess_fn (Union[list[TensorOp], list[functions]], optional): List of operations to be applied on
                the dataset. Operations are applied in the order they appear in this list. Default: ``None``.
            input_output_columns_map (dict, optional): specifies which columns to replace and to what. The key is
                the column name to be replaced and the value is the name you want to replace with. The value may
                also be a list of names, in which case one column is replaced by several. Only used when
                `preprocess_fn` is given. There's no need to set this argument if all columns are not changed
                after mapping. Default: ``None``.
            shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
                required. Default: ``True``.
            drop_remainder (bool, optional): Determines whether or not to drop the last block whose data row
                number is less than batch size. If ``True``, and if there are less than batch_size rows available
                to make the last batch, then those rows will be dropped and not propagated to the child node.
                Default: ``True``.
            prebatched_data (bool, optional): Generate pre-batched data before create mindspore dataset. If
                ``True``, pre-batched data will be returned when get each sub-dataset data by index. Else, the
                batch operation will be done by mindspore dataset interface: dataset.batch. When batch_size is
                very large, it's recommended to set this option to be ``True`` in order to improve performance
                on host. Default: ``False``.
            num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel.
                Default: ``1``.
            num_shards (int, optional): Number of shards that the dataset will be divided into. Random accessible
                input is required. Default: ``None``.
            shard_id (int, optional): The shard ID within num_shards. This argument must be specified only when
                num_shards is also specified. Random accessible input is required. Default: ``None``.
            python_multiprocessing (bool, optional): Parallelize Python function per_batch_map with
                multi-processing. This option could be beneficial if the function is computational heavy.
                Default: ``False``.
            sampler (Sampler, optional): Dataset Sampler. Default: ``None``.

        Returns:
            BatchDataset, dataset batched.

        Raises:
            ValueError: If `prebatched_data` is ``True`` while `drop_remainder` is ``False``.
            ValueError: If there is no sub-dataset to merge.

        Examples:
            >>> data = dataset.create_dataset()
        """
        self._get_all_datasets()
        check_param_type(prebatched_data, "prebatched_data", data_type=bool)
        check_param_type(drop_remainder, "drop_remainder", data_type=bool)
        check_param_type(shuffle, "shuffle", data_type=bool)
        check_param_type(batch_size, "batch_size", data_type=int, exclude_type=bool)

        if prebatched_data and not drop_remainder:
            raise ValueError(
                "prebatched_data is not supported when drop_remainder is set to be False")
        if not self.all_datasets:
            # Fail early with a clear message instead of an opaque error from max() on an empty list.
            raise ValueError(
                "Dataset should have at least one sub-dataset, but got an empty collection")

        # In pre-batched mode every sub-dataset batches and shuffles itself;
        # otherwise both are left to the mindspore pipeline below.
        prebatch_size = batch_size if prebatched_data else 1
        prebatch_shuffle = shuffle if prebatched_data else False
        # Rebuild the column list from scratch so repeated calls do not accumulate duplicates.
        columns_list = []
        for sub_dataset in self.all_datasets:
            sub_dataset._initialization(batch_size=prebatch_size, shuffle=prebatch_shuffle)
            columns_list.extend(sub_dataset.columns_list)
            logger.info("Check initial all dataset, dataset: {}, columns_list: {}, data_size: {}".format(
                sub_dataset.name, sub_dataset.columns_list, len(sub_dataset)))
        self.columns_list = columns_list

        dataset = self._merge_all_datasets(shuffle=False if prebatched_data else shuffle,
                                           num_parallel_workers=num_parallel_workers,
                                           num_shards=num_shards,
                                           shard_id=shard_id,
                                           python_multiprocessing=python_multiprocessing)
        logger.info("Initial dataset size: {}".format(dataset.get_dataset_size()))
        logger.info("Get all dataset columns names: {}".format(self.columns_list))

        self.dataset_columns_map, self.dataset_constraint_map, self.column_index_map = self._create_trace_maps()
        logger.info("Dataset columns map: {}".format(self.dataset_columns_map))
        logger.info("Dataset column index map: {}".format(self.column_index_map))
        logger.info("Dataset constraints map: {}".format(self.dataset_constraint_map))

        if sampler:
            logger.info("Dataset uses sampler")
            dataset.use_sampler(sampler)

        if preprocess_fn:
            # Snapshot the pre-mapping column names; self.columns_list may be replaced below.
            input_columns = copy.deepcopy(self.columns_list)
            check_param_type(input_output_columns_map, "input_output_columns_map", (type(None), dict))
            if input_output_columns_map:
                new_columns_list, new_dataset_columns_map = self._update_columns_list(
                    input_output_columns_map)
                self.columns_list = new_columns_list
                self.dataset_columns_map = new_dataset_columns_map
                self.column_index_map = {
                    column: index for index, column in enumerate(self.columns_list)}
                logger.info("Dataset columns map after preprocess: {}".format(self.dataset_columns_map))
                logger.info("Dataset column index after preprocess: {}".format(self.column_index_map))
                logger.info("Dataset constraints after preprocess: {}".format(self.dataset_constraint_map))
            output_columns = self.columns_list

            dataset = dataset.map(operations=preprocess_fn,
                                  input_columns=input_columns,
                                  output_columns=output_columns,
                                  num_parallel_workers=num_parallel_workers,
                                  python_multiprocessing=python_multiprocessing)
            dataset = dataset.project(output_columns)
            logger.info("Get all dataset columns names after preprocess: {}".format(self.columns_list))

        if not prebatched_data:
            dataset = dataset.batch(batch_size=batch_size,
                                    drop_remainder=drop_remainder,
                                    num_parallel_workers=num_parallel_workers)
        logger.info("Final dataset size: {}".format(dataset.get_dataset_size()))
        return dataset

    def _merge_all_datasets(self, shuffle=True, num_parallel_workers=1, num_shards=1, shard_id=0,
                            python_multiprocessing=False):
        """Wrap all sub-datasets in an `_IterableDatasets` and build a `GeneratorDataset` on top of it."""
        self._iterable_datasets = _IterableDatasets(self.all_datasets)
        dataset = ds.GeneratorDataset(source=self._iterable_datasets,
                                      column_names=self.columns_list,
                                      shuffle=shuffle,
                                      num_parallel_workers=num_parallel_workers,
                                      num_shards=num_shards,
                                      shard_id=shard_id,
                                      python_multiprocessing=python_multiprocessing)
        return dataset

    def _update_columns_list(self, input_output_columns_map):
        """
        Apply the column renaming described by `input_output_columns_map`.

        Returns:
            tuple(list, dict). The flat list of new column names and a map from sub-dataset
            name to that sub-dataset's new column names.
        """
        new_dataset_columns_map = {}
        for dataset in self.all_datasets:
            renamed_columns = []
            for column in dataset.columns_list:
                if column not in input_output_columns_map:
                    renamed_columns.append(column)
                    continue
                new_column = input_output_columns_map[column]
                # One input column may map to several output columns.
                if isinstance(new_column, list):
                    renamed_columns.extend(new_column)
                else:
                    renamed_columns.append(new_column)
            new_dataset_columns_map[dataset.name] = renamed_columns

        new_columns_list = []
        for renamed_columns in new_dataset_columns_map.values():
            new_columns_list += renamed_columns
        return new_columns_list, new_dataset_columns_map

    def get_columns_list(self):
        """get columns list

        Returns:
            list[str]. column names list of the final unified dataset.

        Raises:
            ValueError: If the columns list has not been populated by `create_dataset()` yet.

        Examples:
            >>> columns_list = dataset.get_columns_list()
        """
        if not self.columns_list:
            raise ValueError("Please call create_dataset() first before get final columns list to avoid unexpected "
                             "error")
        return self.columns_list

    def _create_trace_maps(self):
        """
        Build the lookup maps describing the merged dataset.

        Returns:
            tuple(dict, dict, dict). Sub-dataset name to its columns list, sub-dataset name to
            its constraint type, and column name to its position in `self.columns_list`.
        """
        dataset_columns_map = {}
        dataset_constraint_map = {}
        for dataset in self.all_datasets:
            dataset_columns_map[dataset.name] = dataset.columns_list
            dataset_constraint_map[dataset.name] = dataset.constraint_type
        column_index_map = {column: index for index, column in enumerate(self.columns_list)}
        return dataset_columns_map, dataset_constraint_map, column_index_map

    def __getitem__(self, index):
        # Compare with None explicitly: the merged view defines __len__, so an existing
        # view of length 0 would be falsy and wrongly reported as "not created".
        if self._iterable_datasets is None:
            raise ValueError(
                "Call create_dataset() before getting item by index to avoid unexpected error")
        return self._iterable_datasets[index]

    def set_constraint_type(self, constraint_type="Equation"):
        """set constraint type of dataset

        Only the sub-datasets currently held in `all_datasets` are considered; sub-datasets
        derived from `geometry_dict` / `existed_data_list` are added by `create_dataset()`.

        Args:
            constraint_type (Union[str, dict]): The constraint type of specified dataset. If is string, the
                constraint type of all subdataset will be set to the same one. If is dict, the subdataset and
                it's constraint type is specified by the pair (key, value). Default: ``"Equation"``. It also
                supports ``"bc"``, ``"ic"``, ``"label"``, ``"function"``, and ``"custom"``.

        Raises:
            ValueError: If a key of the dict is not one of the sub-datasets.
            TypeError: If `constraint_type` is neither str nor dict.

        Examples:
            >>> dataset.set_constraint_type("Equation")
        """
        if isinstance(constraint_type, str):
            logger.warning("Argument constraint_type: {} is str, the same type will be set for all of the sub-datasets"
                           .format(constraint_type))
            for sub_dataset in self.all_datasets:
                sub_dataset.set_constraint_type(constraint_type)
        elif isinstance(constraint_type, dict):
            for sub_dataset, sub_constraint in constraint_type.items():
                if sub_dataset not in self.all_datasets:
                    raise ValueError("Unknown dataset: {}. All sub-dataset are: {}".format(
                        sub_dataset, [data.name for data in self.all_datasets]))
                sub_dataset.set_constraint_type(sub_constraint)
        else:
            raise TypeError("the type of constraint_type should be dict or str but got {}"
                            .format(type(constraint_type)))

    def __len__(self):
        # See __getitem__ for why this is an explicit None check.
        if self._iterable_datasets is None:
            raise ValueError(
                "Call create_dataset() before getting the dataset length to avoid unexpected error")
        return len(self._iterable_datasets)
class _IterableDatasets(): """get data iteratively""" def __init__(self, dataset_list): self.dataset_list = dataset_list dataset_size = [len(dataset) for dataset in dataset_list] logger.info("Get all dataset sizes: {}".format(dataset_size)) self.longest = max(dataset_size) def __getitem__(self, index): col_data = None for dataset_instance in self.dataset_list: item = dataset_instance[index] col_data = col_data + item if col_data else item return col_data def __len__(self): return self.longest