mindflow.cell.DDIMPipeline

class mindflow.cell.DDIMPipeline(model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32)

Pipeline that controls the DDIM sampling process.

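Parameters:
  • model (nn.Cell) - The trained model used for denoising.

  • scheduler (DDIMScheduler) - The noise scheduler used for denoising.

  • batch_size (int) - The batch size.

  • seq_len (int) - The sequence length.

  • num_inference_steps (int) - The number of sampling steps. Default: 1000.

  • compute_dtype (mindspore.dtype) - The data type. Default: mstype.float32, which indicates mindspore.float32.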

Raises:
  • TypeError - If scheduler is not of type DDIMScheduler.

  • ValueError - If num_inference_steps is greater than scheduler.num_train_timesteps.
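
The ValueError condition above can be motivated with a small sketch. It assumes the "leading" timestep spacing convention found in common DDIM implementations, where the inference timesteps are integer multiples of num_train_timesteps // num_inference_steps. The function name leading_timesteps and the positivity guard are hypothetical and illustrative only; this is not MindFlow's internal code.

```python
def leading_timesteps(num_train_timesteps, num_inference_steps):
    """Illustrative 'leading' spacing: descending timesteps used for sampling."""
    # Assumption: a non-positive step count is rejected as well.
    if num_inference_steps <= 0:
        raise ValueError("num_inference_steps must be positive")
    # Mirrors the documented constraint: more inference steps than training
    # steps would give a stride of 0 and repeat the same timestep.
    if num_inference_steps > num_train_timesteps:
        raise ValueError(
            f"num_inference_steps ({num_inference_steps}) must not exceed "
            f"num_train_timesteps ({num_train_timesteps})"
        )
    # Integer stride between consecutive sampled timesteps.
    step_ratio = num_train_timesteps // num_inference_steps
    # Ascending multiples of the stride, reversed so sampling runs
    # from the noisiest timestep down to 0.
    return [i * step_ratio for i in range(num_inference_steps)][::-1]


print(leading_timesteps(100, 10))
# [90, 80, 70, 60, 50, 40, 30, 20, 10, 0]
```

With num_inference_steps equal to num_train_timesteps, as in the example on this page, the stride is 1 and every training timestep is visited once.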

Supported Platforms:

Ascend

Examples:

>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDIMPipeline, DDIMScheduler, ConditionDiffusionTransformer
>>> # init params
>>> in_dim, out_dim, hidden_dim, cond_dim, layers, heads, seq_len, batch_size = 16, 16, 256, 4, 3, 4, 256, 8
>>> # init condition
>>> cond = ops.randn([8, 4])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> num_train_timesteps = 100
>>> scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           prediction_type='epsilon',
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           sample_max_value=1.,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init pipeline
>>> pipe = DDIMPipeline(model=net, scheduler=scheduler,
...                     batch_size=batch_size, seq_len=seq_len, num_inference_steps=num_train_timesteps)
>>> # run pipeline in inference (sample random noise and denoise)
>>> image = pipe(cond)
>>> print(image.shape)
(8, 256, 16)
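
For intuition about what "sample random noise and denoise" involves, the following is a minimal sketch of one deterministic DDIM update (eta = 0) with prediction_type='epsilon', written for a single scalar value in plain Python. The function ddim_step and its arguments are hypothetical names chosen for illustration, and the clipping step only loosely mirrors clip_sample and clip_sample_range. It follows the standard DDIM formulation and is not taken from the MindFlow source.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, clip_range=1.0):
    """One deterministic DDIM update (eta = 0) for a scalar sample.

    alpha_bar_t and alpha_bar_prev are the cumulative products of (1 - beta)
    at the current and the previous (less noisy) timestep.
    """
    # Estimate the clean sample x_0 from the noisy sample and predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Clip the estimate to [-clip_range, clip_range].
    x0_pred = max(-clip_range, min(clip_range, x0_pred))
    # Move to the previous timestep along the predicted noise direction.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Build a noisy sample from a known clean value and known noise.
x0, eps, alpha_bar_t = 0.5, -0.3, 0.4
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

# With a perfect noise prediction and alpha_bar_prev = 1 (no noise left),
# a single step recovers the clean value.
print(round(ddim_step(x_t, eps, alpha_bar_t, 1.0), 6))
# 0.5
```

A pipeline of this kind would typically start from Gaussian noise, call the model once per timestep to obtain the noise prediction, and apply such an update num_inference_steps times.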