mindspore.dataset.Dataset.sync_wait
- mindspore.dataset.Dataset.sync_wait(condition_name, num_batch=1, callback=None)
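Adds a blocking condition to the dataset object so that data flow can be synchronized with an external operation.
- Parameters:
condition_name (str) - Name of the condition that triggers sending the next row of data.
num_batch (int) - Number of batches that pass without blocking at the start of each epoch. Default: 1.
callback (function) - Callback function that is invoked during the sync_update operation. Default: None.
- Returns:
SyncWaitDataset, the dataset object with the blocking condition added.
- Raises:
RuntimeError - The condition name already exists.
Examples: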
>>> import numpy as np
>>> import mindspore.dataset as ds
>>>
>>> def gen():
...     for i in range(100):
...         yield (np.array(i),)
>>>
>>> class Augment:
...     def __init__(self, loss):
...         self.loss = loss
...
...     def preprocess(self, input_):
...         return input_
...
...     def update(self, data):
...         # Called through sync_update with the data passed by the consumer.
...         self.loss = data["loss"]
>>>
>>> batch_size = 4
>>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
>>>
>>> aug = Augment(0)
>>> # Block the pipeline on the "policy" condition; aug.update receives the sync_update data.
>>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
>>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
>>> dataset = dataset.batch(batch_size)
>>> count = 0
>>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     assert data["input"][0] == count
...     count += batch_size
...     data = {"loss": count}
...     # Release the next batch and hand the new loss to the callback.
...     dataset.sync_update(condition_name="policy", data=data)
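The blocking behavior described above can be illustrated without MindSpore. The following is a minimal conceptual sketch, not the library's implementation: `SyncGate` and `run_pipeline` are hypothetical names introduced only for this illustration, and the sketch assumes that each update releases exactly one further batch. A producer thread may emit `num_batch` batches freely, then waits until the consumer calls `update`, which also forwards the data to the callback.

```python
import queue
import threading


class SyncGate:
    """Hypothetical stand-in for a blocking condition; not MindSpore code."""

    def __init__(self, num_batch=1, callback=None):
        # Permits available before the first update: the unblocked batches.
        self._permits = threading.Semaphore(num_batch)
        self._callback = callback

    def wait(self):
        # Block the producer until a permit is available.
        self._permits.acquire()

    def update(self, data=None, num_batch=1):
        # Forward data to the callback, then release further batches.
        if self._callback is not None and data is not None:
            self._callback(data)
        for _ in range(num_batch):
            self._permits.release()


def run_pipeline(total_batches=5, num_batch=1):
    seen = {}  # state written by the callback
    gate = SyncGate(num_batch=num_batch, callback=seen.update)
    out = queue.Queue()

    def producer():
        for i in range(total_batches):
            gate.wait()   # blocks until the consumer allows the next batch
            out.put(i)
        out.put(None)     # sentinel: end of data

    worker = threading.Thread(target=producer, daemon=True)
    worker.start()

    consumed = []
    while True:
        batch = out.get()
        if batch is None:
            break
        consumed.append(batch)
        # Analogous to calling sync_update after processing a batch.
        gate.update(data={"loss": batch * 10})
    worker.join()
    return consumed, seen
```

In this sketch the consumer plays the role of the training loop: if it never calls `update`, the producer stalls after the first `num_batch` batches, which is the effect the blocking condition is meant to achieve.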