mindspore.dataset.Dataset.sync_update

mindspore.dataset.Dataset.sync_update(condition_name, num_batch=None, data=None)[源代码]

释放阻塞条件并使用给定数据触发回调函数。

参数:
  • condition_name (str) - 用于触发发送下一个数据行的条件名称。

  • num_batch (Union[int, None]) - 释放的batch(row)数。当 num_batch 为None时,将默认为 sync_wait 操作指定的值。默认值:None。

  • data (Any) - 用户自定义传递给回调函数的数据。默认值:None。

样例:

>>> import numpy as np
>>>
>>>
>>> def gen():
...     for i in range(100):
...         yield (np.array(i),)
>>>
>>>
>>> class Augment:
...     def __init__(self, loss):
...         self.loss = loss
...
...     def preprocess(self, input_):
...         return input_
...
...     def update(self, data):
...         self.loss = data["loss"]
>>>
>>>
>>> batch_size = 10
>>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
>>> aug = Augment(0)
>>> dataset = dataset.sync_wait(condition_name='', num_batch=1)
>>> dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
>>> dataset = dataset.batch(batch_size)
>>>
>>> count = 0
>>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     count += 1
...     data = {"loss": count}
...     dataset.sync_update(condition_name="", data=data)