mindspore.dataset.Dataset.sync_update
- mindspore.dataset.Dataset.sync_update(condition_name, num_batch=None, data=None)[源代码]
释放阻塞条件并使用给定数据触发回调函数。
- 参数:
condition_name (str) - 用于触发发送下一个数据行的条件名称。
num_batch (Union[int, None]) - 释放的batch(row)数。当 num_batch 为None时,将默认为 sync_wait 操作指定的值。默认值:None。
data (Any) - 用户自定义传递给回调函数的数据。默认值:None。
样例:
>>> import numpy as np >>> >>> >>> def gen(): ... for i in range(100): ... yield (np.array(i),) >>> >>> >>> class Augment: ... def __init__(self, loss): ... self.loss = loss ... ... def preprocess(self, input_): ... return input_ ... ... def update(self, data): ... self.loss = data["loss"] >>> >>> >>> batch_size = 10 >>> dataset = ds.GeneratorDataset(gen, column_names=["input"]) >>> aug = Augment(0) >>> dataset = dataset.sync_wait(condition_name='', num_batch=1) >>> dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) >>> dataset = dataset.batch(batch_size) >>> >>> count = 0 >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): ... count += 1 ... data = {"loss": count} ... dataset.sync_update(condition_name="", data=data)