Source code for mindspore.train.data_sink

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data sinking helpers for MindData datasets.

Note:
    This is a beta feature, and we are still improving its functionality.
"""

import math
from functools import wraps
import mindspore.ops as ops
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.dataset_helper import _has_dynamic_shape
import mindspore.dataset as ds
from mindspore._c_expression import _set_dataset_mode_config


def _init_sink_dataset(dataset, steps, sink_size):
    """
    Initialize data sinking
    """
    if hasattr(dataset, '__transfer_dataset__'):
        raise ValueError(f"The dataset has been used with network.")

    dataset_size = dataset.get_dataset_size()
    if dataset_size == 0:
        raise ValueError(f"For data sink, dataset size must be greater than 0, but got 0.")
    dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
    dynamic_shape = _has_dynamic_shape(dataset_shapes) or ds.config.get_dynamic_shape()

    # create the transfer dataset; a data info queue is only created for per-step sinking
    # (sink_size == 1) of a multi-element dataset on Ascend with static-shape data
    create_data_info_queue = (sink_size == 1 and dataset_size != 1 and
                              context.get_context('device_target') == 'Ascend' and not dynamic_shape)
    transfer_dataset = _exec_datagraph(dataset, sink_size, create_data_info_queue=create_data_info_queue)

    # send data
    sink_count = math.ceil(steps / dataset_size)
    transfer_dataset.send(sink_count)

    # create the GetNext op that fetches data from the device queue of the transfer dataset
    queue_name = transfer_dataset.queue_name
    next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

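    # set the backend dataset mode to sink so that the graph reads input from the device queue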
    _set_dataset_mode_config('sink')

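    # keep a reference to the transfer dataset on the original dataset; the check at the
    # top of this function uses it to reject a dataset that is already sinking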
    dataset.__transfer_dataset__ = transfer_dataset

    return next_op, create_data_info_queue


def data_sink(fn, dataset, steps, sink_size=1, jit=False):
    """
    A wrapper function that generates a data-sinking function for the input function.
    The generated function is executed in data sinking mode.

    Args:
        fn (Function): The Python function that will be run with the dataset.
        dataset (Dataset): The dataset iterator. The dataset can be generated by the dataset generator API in
            :class:`mindspore.dataset`, such as :class:`mindspore.dataset.ImageFolderDataset`.
        steps (int): The total running steps. `steps` must be a positive integer.
        sink_size (int): Controls the amount of data in each sink. `sink_size` must be a positive integer.
            Default: 1.
        jit (bool): Controls the execution mode (Graph mode/PyNative mode) of the generated function.
            Default: False, which means running in PyNative mode.

    Returns:
        Function, the generated function that is executed in data sinking mode.

    Raises:
        ValueError: If `steps` or `sink_size` is not a positive integer.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.ones((1,), dtype=np.int32), "y": np.ones((1,), dtype=np.int32)}
        >>> dataset = ds.NumpySlicesDataset(data=data)
        >>>
        >>> def func_net(x, y):
        ...     out = x + y
        ...     return out
        >>>
        >>> sink_process = ms.train.data_sink(func_net, dataset, steps=2, sink_size=1)
        >>> for _ in range(2):
        ...     out = sink_process()
        ...     print(out)
        2
        2
    """
    if sink_size <= 0 or steps <= 0:
        raise ValueError(
            f"The 'steps' and 'sink_size' must be positive, but got 'steps': {steps} and 'sink_size': {sink_size}.")

    next_op, _ = _init_sink_dataset(dataset, steps, sink_size)

    @wraps(fn)
    def sink_process(*args, **kwargs):
        def sink_fun():
            data = next_op()
            out = fn(*data)
            return out

        real_sink_fun = sink_fun
        loop = sink_size
        if jit:
            # in Graph mode, compile the sink function once, cache it on the dataset,
            # and execute the compiled graph a single time per call
            loop = 1
            if not hasattr(dataset, '__sink_fun__'):
                dataset.__sink_fun__ = ms_function(sink_fun)
            real_sink_fun = dataset.__sink_fun__

        # run `loop` sink steps: sink_size steps in PyNative mode, one compiled step in Graph mode
        out = None
        for _ in range(loop):
            out = real_sink_fun()
        return out

    return sink_process
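

Usage note: the docstring example above runs the generated function in PyNative mode (jit=False).
As a minimal sketch, the variant below reuses the same hypothetical dataset and func_net to show the
jit=True path, where each call executes a single compiled sink step; it assumes an Ascend or GPU
device, as required by the Supported Platforms section.

import numpy as np
import mindspore as ms
from mindspore import dataset as ds

data = {"x": np.ones((1,), dtype=np.int32), "y": np.ones((1,), dtype=np.int32)}
dataset = ds.NumpySlicesDataset(data=data)

def func_net(x, y):
    return x + y

# jit=True compiles the per-step sink function with ms_function, so each call below
# runs one step of the compiled graph instead of looping sink_size times in PyNative mode
sink_process = ms.train.data_sink(func_net, dataset, steps=2, sink_size=1, jit=True)
for _ in range(2):
    print(sink_process())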