mindflow.cell.DDIMPipeline
- class mindflow.cell.DDIMPipeline(model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32)[source]
Pipeline for DDIM generation.
- Parameters
model (nn.Cell) – The diffusion backbone model.
scheduler (DDIMScheduler) – A scheduler to be used in combination with model to denoise samples.
batch_size (int) – The number of images to generate.
seq_len (int) – Sequence length of inputs.
num_inference_steps (int) – Number of denoising steps. Default: 1000.
compute_dtype (mindspore.dtype) – The dtype of compute; it can be mstype.float32 or mstype.float16. Default: mstype.float32, which indicates mindspore.float32.
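A minimal construction sketch (not part of the official example): the names net, scheduler and cond are placeholders for a diffusion backbone (nn.Cell), a DDIMScheduler and a condition Tensor built exactly as in the Examples section below; it only illustrates how batch_size and seq_len map to the shape of the generated samples.
>>> # Assumption: `net`, `scheduler` and `cond` are constructed as in the Examples below.
>>> pipe = DDIMPipeline(model=net, scheduler=scheduler,
...                     batch_size=8, seq_len=256,
...                     num_inference_steps=100, compute_dtype=mstype.float32)
>>> samples = pipe(cond)
>>> # samples.shape == (batch_size, seq_len, model out_channels) == (8, 256, 16)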
- Raises
TypeError – If scheduler is not of DDIMScheduler type.
ValueError – If num_inference_steps is greater than scheduler.num_train_timesteps (a sketch of this check follows the Examples).
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDIMPipeline, DDIMScheduler, ConditionDiffusionTransformer
>>> # init params
>>> in_dim, out_dim, hidden_dim, cond_dim, layers, heads, seq_len, batch_size = 16, 16, 256, 4, 3, 4, 256, 8
>>> # init condition
>>> cond = ops.randn([8, 4])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> num_train_timesteps = 100
>>> scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           prediction_type='epsilon',
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           sample_max_value=1.,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init pipeline
>>> pipe = DDIMPipeline(model=net, scheduler=scheduler,
...                     batch_size=batch_size, seq_len=seq_len, num_inference_steps=num_train_timesteps)
>>> # run pipeline in inference (sample random noise and denoise)
>>> image = pipe(cond)
>>> print(image.shape)
(8, 256, 16)
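A follow-up sketch, reusing net, scheduler, batch_size, seq_len and num_train_timesteps from the example above: asking for more denoising steps than the scheduler's training horizon is expected to raise the ValueError documented under Raises (this snippet is an illustration, not part of the official example).
>>> # Assumption: num_inference_steps > scheduler.num_train_timesteps triggers
>>> # the ValueError documented under Raises.
>>> try:
...     _ = DDIMPipeline(model=net, scheduler=scheduler,
...                      batch_size=batch_size, seq_len=seq_len,
...                      num_inference_steps=num_train_timesteps + 1)
... except ValueError:
...     print("num_inference_steps exceeds num_train_timesteps")
num_inference_steps exceeds num_train_timesteps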