Source code for mindspore.mindrecord.filereader

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is to read data from mindrecord.
"""
from .shardreader import ShardReader
from .shardheader import ShardHeader
from .shardutils import populate_data
from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
from .common.exceptions import ParamValueError, ParamTypeError

__all__ = ['FileReader']

[docs]class FileReader: """ Class to read MindRecord File series. Args: file_name (str): File name of MindRecord File. num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). It should not be smaller than 1 or larger than the number of CPU. columns (list[str], optional): List of fields which correspond data would be read (default=None). operator(int, optional): Reserved parameter for operators (default=None). Raises: ParamValueError: If file_name, num_consumer or columns is invalid. """ def __init__(self, file_name, num_consumer=4, columns=None, operator=None): check_filename(file_name) self._file_name = file_name if num_consumer is not None: if isinstance(num_consumer, int): if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT(): raise ParamValueError("Consumer number should between {} and {}." .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT())) else: raise ParamValueError("Consumer number is illegal.") else: raise ParamValueError("Consumer number is illegal.") if columns: if isinstance(columns, list): self._columns = columns else: raise ParamTypeError('columns', 'list') else: self._columns = None self._reader = ShardReader() self._reader.open(file_name, num_consumer, columns, operator) self._header = ShardHeader(self._reader.get_header()) self._reader.launch()
[docs] def get_next(self): """ Yield a batch of data according to columns at a time. Yields: dict: keys is the same as columns. Raises: MRMUnsupportedSchemaError: If schema is invalid. """ iterator = self._reader.get_next() while iterator: for blob, raw in iterator: yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) iterator = self._reader.get_next()
[docs] def finish(self): """ Stop reader worker. Raises: MRMFinishError: If failed to finish worker threads. """ return self._reader.finish()
[docs] def close(self): """Stop reader worker and close File.""" return self._reader.close()