# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data sink helpers for MindData datasets.
Note:
This feature is a beta feature, and we are still improving its functionality.
"""
import math
from functools import wraps
import mindspore.ops as ops
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.dataset_helper import _has_dynamic_shape
import mindspore.dataset as ds
from mindspore._c_expression import _set_dataset_mode_config
def _init_sink_dataset(dataset, steps, sink_size):
    """
    Initialize data sinking for ``dataset`` and build the op that reads from its device queue.

    Side effects: sends the dataset through a transfer dataset, switches the dataset mode
    config to ``'sink'`` and attaches the transfer dataset to ``dataset`` as
    ``__transfer_dataset__`` so the same dataset cannot be sunk twice.

    Args:
        dataset (Dataset): Dataset to sink. Must not already carry ``__transfer_dataset__``.
        steps (int): Total running steps. The count handed to ``send`` is
            ``ceil(steps / dataset_size)``.
        sink_size (int): Amount of data per sink, forwarded to ``_exec_datagraph``.

    Returns:
        tuple: ``(next_op, create_data_info_queue)`` -- a ``GetNext`` op bound to the transfer
        dataset's queue, and the flag that was passed to ``_exec_datagraph``.

    Raises:
        ValueError: If ``dataset`` has already been used for sinking, or its size is 0.
    """
    if hasattr(dataset, '__transfer_dataset__'):
        raise ValueError("The dataset has been used with network.")

    size = dataset.get_dataset_size()
    if size == 0:
        raise ValueError("For data sink, dataset size must be greater than 0, but got 0.")

    types, shapes = _get_types_and_shapes(dataset)
    is_dynamic = _has_dynamic_shape(shapes) or ds.config.get_dynamic_shape()

    # A data-info queue is requested only for sink_size == 1 on a multi-element,
    # static-shape dataset running on Ascend (the short-circuit keeps get_context lazy).
    use_info_queue = (sink_size == 1
                      and size != 1
                      and context.get_context('device_target') == 'Ascend'
                      and not is_dynamic)
    transfer = _exec_datagraph(dataset, sink_size, create_data_info_queue=use_info_queue)

    # Presumably the number of passes over the dataset needed to cover `steps` -- TODO confirm
    # against TransferDataset.send.
    transfer.send(math.ceil(steps / size))

    next_op = ops.GetNext(types, shapes, len(types), transfer.queue_name)

    _set_dataset_mode_config('sink')
    # Mark the dataset as consumed; the guard at the top rejects a second call.
    dataset.__transfer_dataset__ = transfer
    return next_op, use_info_queue
def data_sink(fn, dataset, steps, sink_size=1, jit=False):
    """
    A wrapper function to generate a function for the input function.

    The generated function will be executed in data sinking mode: its inputs are fetched from
    the device queue fed by `dataset` rather than passed in by the caller.

    Args:
        fn (Function): The Python function that will be run with dataset. It is called with the
            tensors returned by one ``GetNext`` call unpacked as positional arguments.
        dataset (Dataset): The dataset iterator. The dataset can be generated by dataset generator API in
            :class:`mindspore.dataset`, such as :class:`mindspore.dataset.ImageFolderDataset`.
            A dataset can only be used for data sinking once.
        steps (int): The total running steps. `steps` must be positive integer.
        sink_size (int): Control the amount of data in each sink. `sink_size` must be positive integer. Default: 1.
        jit (bool): Controls the execution mode(Graph mode/PyNative mode) of the generated function. Default: False,
            means running in PyNative mode.

    Returns:
        Function, the generated function will be executed in data sinking mode. Any arguments passed to
        it are ignored. With `jit=False` each call runs `fn` `sink_size` times and returns the last
        result; with `jit=True` each call runs the compiled function once.

    Raises:
        ValueError: If `steps` or `sink_size` is not a positive integer, or if `dataset` has already
            been used for data sinking or is empty.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.ones((1,), dtype=np.int32), "y": np.ones((1,), dtype=np.int32)}
        >>> dataset = ds.NumpySlicesDataset(data=data)
        >>>
        >>> def func_net(x, y):
        ...     out = x + y
        ...     return out
        >>>
        >>> sink_process = ms.train.data_sink(func_net, dataset, steps=2, sink_size=1)
        >>> for _ in range(2):
        ...     out = sink_process()
        ...     print(out)
        2
        2
    """
    import numbers  # local: only needed for argument validation

    # Reject non-integers up front: _init_sink_dataset sends data and marks the dataset as
    # used, so a bad value must not get that far (e.g. sink_size=2.5 would otherwise only
    # fail later in range()). numbers.Integral also accepts numpy integer scalars.
    for arg_name, arg_value in (("steps", steps), ("sink_size", sink_size)):
        if not isinstance(arg_value, numbers.Integral):
            raise ValueError(f"The '{arg_name}' must be an integer, but got {arg_value!r}.")
    if sink_size <= 0 or steps <= 0:
        raise ValueError(
            f"The 'steps' and 'sink_size' must be positive, but got steps {steps} sink_size {sink_size}.")
    next_op, _ = _init_sink_dataset(dataset, steps, sink_size)

    def sink_fun():
        # Pull one set of tensors from the device queue and feed them to fn.
        data = next_op()
        out = fn(*data)
        return out

    # Both inputs are fixed for the lifetime of the wrapper, so compute the loop count once.
    loop = 1 if jit else sink_size

    @wraps(fn)
    def sink_process(*args, **kwargs):
        # args/kwargs are accepted for call-compatibility but unused: inputs come from the queue.
        real_sink_fun = sink_fun
        if jit:
            # Wrap with ms_function lazily on the first call and cache it on the dataset.
            if not hasattr(dataset, '__sink_fun__'):
                dataset.__sink_fun__ = ms_function(sink_fun)
            real_sink_fun = dataset.__sink_fun__
        out = None
        for _ in range(loop):
            out = real_sink_fun()
        return out

    return sink_process