mindspore.data_sink

mindspore.data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None)

A wrapper function that generates, from the input function fn, a new function that runs fn on dataset in data sinking mode.

Note

When using data sinking, the dataset is automatically sent to the device in a loop. The device side can cache up to 100 batches of data, occupying no more than 2 GB of memory. In this case, only sink_size, the number of steps per sink, needs to be considered. sink_size defaults to 1, meaning that each epoch takes only one batch of data from the cache for training and outputs one loss. If sink_size is greater than 1, each epoch takes sink_size batches of data from the cache for training and outputs one loss.
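The following pure-Python sketch models the sink_size behavior described in the note. It is an illustration under stated assumptions, not MindSpore's implementation: simulate_data_sink is a hypothetical helper, the dataset is modeled as a list of argument tuples that loops forever, each call to the returned function stands for one sink (one "epoch" in the note's wording) and returns only the output of its last step, and the device-side cache limits (100 batches, 2 GB) are not modeled.

```python
import itertools
from typing import Any, Callable, Iterable


def simulate_data_sink(fn: Callable[..., Any], batches: Iterable[tuple],
                       sink_size: int = 1) -> Callable[[], Any]:
    """Hypothetical pure-Python model of data sinking (not MindSpore code)."""
    # Mirrors the documented check: sink_size must be a positive integer.
    if not isinstance(sink_size, int) or sink_size <= 0:
        raise ValueError(f"sink_size must be a positive integer, got {sink_size!r}")

    # Assumption: the dataset is looped automatically, so model it as an endless cycle.
    source = itertools.cycle(list(batches))

    def sink_process():
        out = None
        # Each call consumes sink_size batches and keeps only the last output.
        for _ in range(sink_size):
            out = fn(*next(source))
        return out

    return sink_process


# Usage: four batches, two steps per sink.
step = simulate_data_sink(lambda x, y: x + y, [(1, 1), (2, 3), (5, 8), (13, 21)], sink_size=2)
print(step())  # consumes (1, 1) and (2, 3), prints 5
print(step())  # consumes (5, 8) and (13, 21), prints 34
print(step())  # wraps around to the first two batches, prints 5
```

With sink_size=1 and a single batch (1, 1), every call returns 2, which matches the output of the Examples section.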

Parameters
  • fn (Function) – The Python function that will be run with dataset.

  • dataset (Dataset) – The dataset iterator. The dataset can be created with a dataset API in mindspore.dataset, such as mindspore.dataset.ImageFolderDataset.

  • sink_size (int) – Controls the amount of data in each sink, that is, the number of batches taken from the cache per sink. sink_size must be a positive integer. Default: 1.

  • jit_config (JitConfig) – Controls the execution mode (Graph mode or PyNative mode) of the generated function and the JIT configuration used for compilation. Default: None, which means running in PyNative mode.

  • input_signature (Union[Tensor, List or Tuple of Tensors]) – The Tensor(s) describing the input arguments. The shape and dtype of these Tensors are supplied to this function. If input_signature is specified, each input to fn must be a Tensor, and fn cannot accept **kwargs. The shape and dtype of the actual inputs must match input_signature; otherwise, a TypeError is raised. Default: None.
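To make the input_signature rule concrete, here is a minimal sketch of the kind of check it implies. TensorSpec and check_against_signature are hypothetical names, TensorSpec is only a stand-in that carries a shape and a dtype, and MindSpore performs its own validation internally. The sketch accepts only a list of specs (not a single Tensor) and positional inputs, mirroring the rule that fn cannot take **kwargs; the argument-count check is an added assumption.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a Tensor: only shape and dtype matter here."""
    shape: Tuple[int, ...]
    dtype: str


def check_against_signature(signature: Sequence[TensorSpec], args: Sequence[object]) -> None:
    """Raise TypeError if positional args do not match the signature's shapes and dtypes."""
    # Assumption: a different number of inputs is also treated as a mismatch.
    if len(args) != len(signature):
        raise TypeError(f"expected {len(signature)} inputs, got {len(args)}")
    for i, (expected, actual) in enumerate(zip(signature, args)):
        # With input_signature set, every input must be tensor-like.
        if not isinstance(actual, TensorSpec):
            raise TypeError(f"input {i} must be a Tensor, got {type(actual).__name__}")
        # Shape and dtype must both match the declared signature.
        if actual.shape != expected.shape or actual.dtype != expected.dtype:
            raise TypeError(
                f"input {i}: expected shape={expected.shape}, dtype={expected.dtype}; "
                f"got shape={actual.shape}, dtype={actual.dtype}"
            )


# Usage: matches the (1,)-shaped int32 columns used in the Examples section.
sig = [TensorSpec((1,), "int32"), TensorSpec((1,), "int32")]
check_against_signature(sig, [TensorSpec((1,), "int32"), TensorSpec((1,), "int32")])  # passes silently
```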

Returns

Function, the generated function, which executes in data sinking mode.

Raises

ValueError – If sink_size is not a positive integer.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import dataset as ds
>>>
>>> data = {"x": np.ones((1,), dtype=np.int32), "y": np.ones((1,), dtype=np.int32)}
>>> dataset = ds.NumpySlicesDataset(data=data)
>>>
>>> def func_net(x, y):
...     out = x + y
...     return out
>>>
>>> sink_process = ms.data_sink(func_net, dataset, sink_size=1)
>>> for _ in range(2):
...     out = sink_process()
...     print(out)
2
2