mindspore.dataset.Dataset.sync_wait

mindspore.dataset.Dataset.sync_wait(condition_name, num_batch=1, callback=None)[源代码]

为同步操作在数据集对象上添加阻塞条件。

参数:
  • condition_name (str) - 用于触发发送下一行数据的条件名称。

  • num_batch (int) - 每个epoch开始时无阻塞的batch数。默认值:1。

  • callback (function) - sync_update 操作中将调用的回调函数。默认值:None。

返回:

SyncWaitDataset,添加了阻塞条件的数据集对象。

异常:
  • RuntimeError - 条件名称已存在。

样例:

>>> import numpy as np
>>> def gen():
...     for i in range(100):
...         yield (np.array(i),)
>>>
>>> class Augment:
...     def __init__(self, loss):
...         self.loss = loss
...
...     def preprocess(self, input_):
...         return input_
...
...     def update(self, data):
...         self.loss = data["loss"]
>>>
>>> batch_size = 4
>>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
>>>
>>> aug = Augment(0)
>>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
>>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
>>> dataset = dataset.batch(batch_size)
>>> count = 0
>>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     assert data["input"][0] == count
...     count += batch_size
...     data = {"loss": count}
...     dataset.sync_update(condition_name="policy", data=data)