mindspore.train.data_sink 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data sink help for minddata dataset"""
from functools import wraps
import mindspore.ops as ops
from mindspore import context
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import jit
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
import mindspore.dataset as ds
from mindspore._c_expression import _set_dataset_mode_config
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes, _get_pipeline_stages
from mindspore._checkparam import Validator


def _init_sink_dataset(dataset, sink_size, input_signature, create_info):
    """
    Initialize data sinking
    """
    if hasattr(dataset, '__transfer_dataset__'):
        raise ValueError(f"The dataset has been used with network.")

    dataset_size = dataset.get_dataset_size()
    dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
    dynamic_shape = _has_dynamic_shape(dataset_shapes) or ds.config.get_dynamic_shape()

    # create transfer_dataset
    is_info_queue = (create_info and sink_size == 1 and dataset_size != 1 and
                     input_signature is None and not dynamic_shape and
                     context.get_context('device_target') == 'Ascend' and not context.get_context("enable_ge"))
    transfer_dataset = _exec_datagraph(dataset, sink_size, create_data_info_queue=is_info_queue)
    dataset.__transfer_dataset__ = transfer_dataset

    # send data
    transfer_dataset.send(-1)

    # create GetNext op
    if input_signature is not None:
        _check_inputs(input_signature, dataset_shapes, dataset_types)

    queue_name = transfer_dataset.queue_name
    if _need_to_full():
        device_num = _get_device_num() // _get_pipeline_stages()
        dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
    next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    _set_dataset_mode_config('sink')

    dataset.__transfer_dataset__ = transfer_dataset

    return next_op, is_info_queue


class _DataSinkAux:
    @staticmethod
    def __deepcopy__(memodict):
        return


def _get_next_op(dataset, ori_next_op, is_info_queue):
    """
    get the next operation.
    """

    if not is_info_queue:
        return ori_next_op, ''

    if not hasattr(dataset, '__sink_aux__'):
        dataset.__sink_aux__ = _DataSinkAux()
        dataset.__sink_aux__.next_ops = {}
        dataset.__sink_aux__.sink_funcs = {}

    queue_name = dataset.__transfer_dataset__.queue_name
    dataset_types, dataset_shapes = dataset.__transfer_dataset__.get_data_info()
    dataset_types = [pytype_to_dtype(x) for x in dataset_types]
    key = str(dataset_types) + str(dataset_shapes)
    if key in dataset.__sink_aux__.next_ops:
        next_op = dataset.__sink_aux__.next_ops[key]
    else:
        if _need_to_full():
            device_num = _get_device_num() // _get_pipeline_stages()
            dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
        next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    return next_op, (key, dataset_shapes, dataset_types)


def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
    """
    get the sink function.
    """
    if not is_info_queue:
        if not hasattr(dataset, '__sink_fun__'):
            if jit_config is None:
                dst_sink_fun = sink_fun
            else:
                dst_sink_fun = jit(sink_fun, jit_config=jit_config)
            dataset.__sink_fun__ = dst_sink_fun

        return dataset.__sink_fun__

    key = key_info[0]
    if key in dataset.__sink_aux__.sink_funcs:
        dst_sink_fun = dataset.__sink_aux__.sink_funcs[key]
    else:
        if jit_config is None:
            dst_sink_fun = sink_fun
        else:
            dst_sink_fun = jit(sink_fun, jit_config=jit_config)
        dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun

    return dst_sink_fun


[文档]def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None): """ A wrapper function to generate a function for the input function. Args: fn (Function): The Python function that will be run with dataset. dataset (Dataset): The dataset iterator. The dataset can be generated by dataset generator API in :class:`mindspore.dataset`, such as :class:`mindspore.dataset.ImageFolderDataset`. sink_size (int): Control the amount of data in each sink. `sink_size` must be positive integer. Default: 1. jit_config (JitConfig): Controls the execution mode(Graph mode/PyNative mode) of the generated function, and Jit config for compile. Default: None, means running in PyNative mode. input_signature (Union[Tensor, List or Tuple of Tensors]): The Tensor which describes the input arguments. The shape and dtype of the Tensor will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise, TypeError will be raised. Default: None. Returns: Function, the generated function will be executed in data sinking mode. Raises: ValueError: If `sink_size` is not positive integer. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> import numpy as np >>> import mindspore as ms >>> from mindspore import dataset as ds >>> >>> data = {"x": np.ones((1,), dtype=np.int32), "y": np.ones((1,), dtype=np.int32)} >>> dataset = ds.NumpySlicesDataset(data=data) >>> >>> def func_net(x, y): ... out = x + y ... return out >>> >>> sink_process = ms.train.data_sink(func_net, dataset, sink_size=1) >>> for _ in range(2): ... out = sink_process() ... print(out) 2 2 """ Validator.check_value_type("sink_size", sink_size, int, "Data sink") if sink_size <= 0: raise ValueError( f"The 'sink_size' must be positive, but got sink_size {sink_size}.") if context.get_context('device_target') not in ('Ascend', 'GPU'): raise ValueError( f"Data sinking supports ascend or gpu device target, " f"but device target is {context.get_context('device_target')}.") loop = sink_size create_info = True if jit_config is None: create_info = (loop == 1) loop = 1 ori_next_op, is_info_queue = _init_sink_dataset(dataset, loop, input_signature, create_info) @wraps(fn) def sink_process(*args, **kwargs): next_op, key_info = _get_next_op(dataset, ori_next_op, is_info_queue) def sink_fun(): data = next_op() out = fn(*data) return out real_sink_fun = _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config) loop = sink_size if jit_config is not None: loop = 1 out = None for _ in range(loop): out = real_sink_fun() return out return sink_process