mindflow.cell.DDPMPipeline
- class mindflow.cell.DDPMPipeline(model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32)
Pipeline that drives the DDPM sampling (iterative denoising) process.
- Parameters:
model (nn.Cell) - The trained model.
scheduler (DDPMScheduler) - Noise scheduler used for denoising.
batch_size (int) - Batch size.
seq_len (int) - Sequence length.
num_inference_steps (int) - Number of sampling steps. Default: 1000.
compute_dtype (mindspore.dtype) - Data type used for computation. Default: mstype.float32, i.e. mindspore.float32.
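As a rough illustration of how batch_size, seq_len and num_inference_steps fit together, here is a minimal NumPy sketch of the standard DDPM reverse (sampling) loop that a pipeline like this runs. It starts from Gaussian noise of shape (batch_size, seq_len, channels) and applies num_inference_steps denoising steps. This is not MindFlow's implementation. `ddpm_sample`, the `model(x, t)` call signature, the `channels` argument and the linear beta schedule are all assumptions made for illustration. The sketch also leaves out conditioning and the sample clipping (`clip_sample`) that DDPMScheduler supports.

```python
import numpy as np


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    # Assumed linear beta schedule; the MindFlow example below uses "squaredcos_cap_v2" instead.
    return np.linspace(beta_start, beta_end, num_steps)


def ddpm_sample(model, batch_size, seq_len, channels, num_steps, seed=0):
    """Hypothetical DDPM sampler; `model(x, t)` is assumed to predict the added noise."""
    rng = np.random.default_rng(seed)
    betas = linear_betas(num_steps)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)

    # Start from pure Gaussian noise x_T with shape (batch_size, seq_len, channels).
    x = rng.standard_normal((batch_size, seq_len, channels))
    for t in reversed(range(num_steps)):
        eps = model(x, t)  # predicted noise epsilon_theta(x_t, t)
        alpha_bar_t = alphas_cumprod[t]
        alpha_bar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bar_t) * eps) / np.sqrt(alphas[t])
        if t > 0:
            # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
            var = betas[t] * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
            x = mean + np.sqrt(var) * rng.standard_normal(x.shape)
        else:
            # The last step returns the mean without adding noise.
            x = mean
    return x
```

Take a dummy model that always predicts zero noise. Calling `ddpm_sample(lambda x, t: np.zeros_like(x), batch_size=8, seq_len=256, channels=16, num_steps=100)` then returns an array of shape (8, 256, 16), the same shape that the example below prints.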
- Raises:
TypeError - If scheduler is not of type DDPMScheduler.
ValueError - If num_inference_steps is not equal to scheduler.num_train_timesteps.
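The two documented error conditions can be sketched as argument checks. `check_pipeline_args` and the stand-in `DDPMScheduler` class below are hypothetical and exist only to mirror those conditions; they are not MindFlow code. In practice, num_inference_steps has to match the scheduler's num_train_timesteps. That is why the example passes num_inference_steps=num_train_timesteps (100) rather than relying on the default of 1000.

```python
class DDPMScheduler:
    """Hypothetical stand-in for mindflow.cell.DDPMScheduler; only num_train_timesteps is modeled."""

    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps


def check_pipeline_args(scheduler, num_inference_steps):
    """Hypothetical checks mirroring the documented TypeError and ValueError conditions."""
    # Documented TypeError: scheduler must be a DDPMScheduler.
    if not isinstance(scheduler, DDPMScheduler):
        raise TypeError(f"scheduler must be a DDPMScheduler, got {type(scheduler).__name__}")
    # Documented ValueError: step counts must match.
    if num_inference_steps != scheduler.num_train_timesteps:
        raise ValueError(
            f"num_inference_steps ({num_inference_steps}) must equal "
            f"scheduler.num_train_timesteps ({scheduler.num_train_timesteps})"
        )
```

For example, `check_pipeline_args(DDPMScheduler(100), 100)` passes. `check_pipeline_args(DDPMScheduler(100), 1000)` raises ValueError because the default of 1000 steps does not match a scheduler trained with 100 timesteps.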
- Supported Platforms:
Ascend
Examples:
>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDPMPipeline, DDPMScheduler, ConditionDiffusionTransformer
>>> # init params
>>> in_dim, out_dim, hidden_dim, cond_dim, layers, heads, seq_len, batch_size = 16, 16, 256, 4, 3, 4, 256, 8
>>> # init condition
>>> cond = ops.randn([8, 4])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=out_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> num_train_timesteps = 100
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init pipeline
>>> pipe = DDPMPipeline(model=net, scheduler=scheduler,
...                     batch_size=batch_size, seq_len=seq_len, num_inference_steps=num_train_timesteps)
>>> # run pipeline in inference (sample random noise and denoise)
>>> image = pipe(cond)
>>> print(image.shape)
(8, 256, 16)