mindflow.data.existed_data 源代码

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#pylint: disable=W0223
#pylint: disable=W0221
"""
This dataset module supports npy type of datasets. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
"""
from __future__ import absolute_import

import numpy as np
from mindspore import log as logger

from .data_base import Data, ExistedDataConfig
from ..utils.check_func import check_param_type


[文档]class ExistedDataset(Data): r""" Creates a dataset with given data path. Note: The ``'npy'`` data format is supported now. Parameters: name (str, optional): specifies the name of dataset. Default: ``None``. If `data_config` is ``None``, the `name` should not be ``None``. data_dir (Union[str, list, tuple], optional): the path of existed data files. Default: ``None``. If `data_config` is ``None``, the `data_dir` should not be ``None``. columns_list (Union[str, list, tuple], optional): list of column names of the dataset. Default: ``None``. If `data_config` is ``None``, the `columns_list` should not be ``None``. data_format (str, optional): the format of existed data files. Default: ``'npy'``. constraint_type (str, optional): specifies the constraint type of the created dataset. Default: "Label". random_merge (bool, optional): specifies whether randomly merge the given datasets. Default: ``True``. data_config (ExistedDataConfig, optional): Instance of ExistedDataConfig which collect the info described above. Default: ``None``. If it's not ``None``, the dataset class will be create by using it for simplifying. If it's ``None``, the info of (name, data_dir, columns_list, data_format, constraint_type, random_merge) will be used for replacement. Raises: ValueError: Argument `name`/`data_dir`/`columns_list` is ``None`` when `data_config` is ``None``. TypeError: If data_config is not a instance of ExistedDataConfig. ValueError: If data_format is not ``'npy'``. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> from mindflow.data import ExistedDataConfig, ExistedDataset >>> data_config = ExistedDataConfig(name='exist', ... data_dir=['./data.npy'], ... columns_list=['input_data'], data_format="npy", constraint_type="Equation") >>> dataset = ExistedDataset(data_config=data_config) """ def __init__(self, name=None, data_dir=None, columns_list=None, data_format='npy', constraint_type="Label", random_merge=True, data_config=None): if not data_config: if not name or not data_dir or not columns_list: raise ValueError("If data_config is None, argument: name, data_dir and columns_list should not be" " None, but got name: {}, data_dir: {}, columns_list: {}" .format(name, data_dir, columns_list)) data_config = ExistedDataConfig(name, data_dir, columns_list, data_format, constraint_type, random_merge) check_param_type(data_config, "data_config", data_type=ExistedDataConfig) name = data_config.name columns_list = [name + "_" + column_name for column_name in data_config.columns_list] constraint_type = data_config.constraint_type self.data_dir = data_config.data_dir self._data_format = data_config.data_format self._random_merge = data_config.random_merge self.data = None self.data_size = None self.batch_size = 1 self.shuffle = False self.batched_data_size = None self._index = None self.load_data = {"npy": self._load_npy_data} super(ExistedDataset, self).__init__(name, columns_list, constraint_type) def _initialization(self, batch_size=1, shuffle=False): """load datasets from given paths""" data = self.load_data.get(self._data_format.lower())() if not isinstance(data, tuple): data = (data,) self.data = data self.data_size = len(data[0]) self.batch_size = batch_size if batch_size > self.data_size: raise ValueError("If prebatch data, batch_size: {} should not be larger than data size: {}.".format( batch_size, self.data_size )) self.batched_data_size = self.data_size // batch_size self.shuffle = shuffle self._index = np.arange(self.data_size) logger.info("Get existed dataset: {}, columns: {}, size: {}, batched_size: {}, shuffle: {}".format( self.name, self.columns_list, self.data_size, self.batched_data_size, self.shuffle)) return data def __getitem__(self, index): if not self.data: self._initialization() if self._random_merge: index = np.random.randint(0, self.batched_data_size) if index >= self.batched_data_size else index else: index = index % self.batched_data_size if index >= self.batched_data_size else index if self.shuffle and index % self.batched_data_size == 0: self._index = np.random.permutation(self.data_size) col_data = None for i in range(len(self.columns_list)): temp_data = self.data[i][self._index[index]] if self.batch_size == 1 else \ self.data[i][self._index[index * self.batch_size : (index + 1) * self.batch_size]] col_data = (temp_data,) if col_data is None else col_data + (temp_data,) return col_data def _load_npy_data(self): """ Load npy-type data from exited file. For every column the data shape should be 2D, i.e. (batch_size, dim) """ data = tuple() for path in self.data_dir: logger.info("Read data from file: {}".format(path)) data_tmp = np.load(path) data += (data_tmp.astype(np.float32),) logger.info("Load npy data size: {}".format(len(data[0]))) return data def __len__(self): if not self.data: self._initialization() return self.batched_data_size