mindflow.cell.DDPMScheduler

class mindflow.cell.DDPMScheduler(num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = 'squaredcos_cap_v2', prediction_type: str = 'epsilon', variance_type: str = 'fixed_small_log', clip_sample: bool = True, clip_sample_range: float = 1.0, thresholding: bool = False, sample_max_value: float = 1.0, dynamic_thresholding_ratio: float = 0.995, rescale_betas_zero_snr: bool = False, timestep_spacing: str = 'leading', compute_dtype=mstype.float32)

DDPMScheduler implements the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs). See the paper Denoising Diffusion Probabilistic Models (Ho et al., 2020) for more information.
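For orientation, the standard DDPM relations from Ho et al. (2020) are summarized below in the paper's notation (β_t, α_t, ᾱ_t, x_0, x_t, ε, z). These symbols are not MindFlow identifiers, and how exactly each option of this class maps onto them is an assumption. The second line is the closed-form forward process that add_noise is expected to evaluate; the remaining lines describe one reverse step with prediction_type=epsilon, where the model's noise prediction yields an estimate of x_0 that is plugged into the posterior mean, and z is set to zero at the final step.

```latex
\begin{aligned}
\alpha_t &= 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \\
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\hat{x}_0 &= \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} \\
\tilde{\mu}_t(x_t, \hat{x}_0) &= \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, \hat{x}_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t \\
\tilde{\beta}_t &= \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t \\
x_{t-1} &= \tilde{\mu}_t(x_t, \hat{x}_0) + \sqrt{\tilde{\beta}_t}\, z, \qquad z \sim \mathcal{N}(0, I)
\end{aligned}
```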

Parameters
  • num_train_timesteps (int) – The number of diffusion steps used to train the model. Default: 1000.

  • beta_start (float) – The starting beta value of the noise schedule. Default: 0.0001.

  • beta_end (float) – The final beta value of the noise schedule. Default: 0.02.

  • beta_schedule (str) – The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from linear, scaled_linear, or squaredcos_cap_v2. Default: squaredcos_cap_v2.

  • prediction_type (str) – Prediction type of the scheduler function; can be epsilon (predicts the noise of the diffusion process), sample (directly predicts the original, noise-free sample) or v_prediction (see Section 2.4 of the Imagen Video paper). Default: epsilon.

  • variance_type (str) – How the variance is computed when adding noise to the denoised sample at each step. Choose from fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range. Default: fixed_small_log.

  • clip_sample (bool) – Whether to clip the predicted sample for numerical stability. Default: True.

  • clip_sample_range (float) – The maximum magnitude for sample clipping. Valid only when clip_sample=True. Default: 1.0.

  • thresholding (bool) – Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. Default: False.

  • sample_max_value (float) – The threshold value for dynamic thresholding. Valid only when thresholding=True. Default: 1.0.

  • dynamic_thresholding_ratio (float) – The ratio for the dynamic thresholding method. Valid only when thresholding=True. Default: 0.995.

  • rescale_betas_zero_snr (bool) – Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and very dark samples instead of limiting it to samples of medium brightness. Loosely related to offset_noise. Default: False.

  • timestep_spacing (str) – How the timesteps are spaced. Refer to Table 2 of the paper Common Diffusion Noise Schedules and Sample Steps are Flawed for more information. Choose from linspace, leading or trailing. Default: leading.

  • compute_dtype (mindspore.dtype) – The data type used for computation, either mstype.float32 or mstype.float16. Default: mstype.float32 (mindspore.float32).
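The three beta_schedule options can be sketched in plain Python as below. This is a minimal sketch that assumes MindFlow follows the conventions of the Hugging Face diffusers DDPMScheduler: linear spacing of beta, linear spacing of sqrt(beta) for scaled_linear, and the capped squared-cosine schedule of Nichol and Dhariwal (2021) for squaredcos_cap_v2. The helper names make_betas and alphas_cumprod and the max_beta cap are hypothetical and used only for illustration.

```python
import math


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values from start to stop, inclusive (like numpy.linspace)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def make_betas(num_train_timesteps: int = 1000,
               beta_start: float = 0.0001,
               beta_end: float = 0.02,
               beta_schedule: str = "squaredcos_cap_v2",
               max_beta: float = 0.999) -> list[float]:
    """Hypothetical helper: build the beta sequence for a named schedule."""
    if beta_schedule == "linear":
        return linspace(beta_start, beta_end, num_train_timesteps)
    if beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        roots = linspace(math.sqrt(beta_start), math.sqrt(beta_end),
                         num_train_timesteps)
        return [r * r for r in roots]
    if beta_schedule == "squaredcos_cap_v2":
        # Cosine schedule: derive each beta from a squared-cosine alpha_bar
        # curve and cap it at max_beta. beta_start/beta_end are unused here.
        def alpha_bar(t: float) -> float:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        n = num_train_timesteps
        return [min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta)
                for i in range(n)]
    raise ValueError(f"unsupported beta_schedule: {beta_schedule}")


def alphas_cumprod(betas: list[float]) -> list[float]:
    """Cumulative product of alpha_t = 1 - beta_t (alpha_bar_t in the paper)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out
```

With the defaults, make_betas() returns 1000 betas whose last entry is capped at 0.999, because the squared-cosine curve reaches zero at the end of the schedule.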
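The thresholding, sample_max_value and dynamic_thresholding_ratio parameters refer to the per-sample percentile thresholding introduced in the Imagen paper (Saharia et al., 2022). The sketch below follows the way the Hugging Face diffusers schedulers implement it; that MindFlow interprets dynamic_thresholding_ratio as the same quantile is an assumption, and the function names quantile and dynamic_threshold are hypothetical.

```python
import math


def quantile(values: list[float], q: float) -> float:
    """q-th quantile with linear interpolation between sorted values."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def dynamic_threshold(x0: list[float], ratio: float = 0.995,
                      sample_max_value: float = 1.0) -> list[float]:
    """Hypothetical dynamic thresholding of one flattened predicted sample."""
    # Threshold s is the `ratio` quantile of |x0|, kept within [1, sample_max_value].
    s = quantile([abs(v) for v in x0], ratio)
    s = min(max(s, 1.0), sample_max_value)
    # Clip to [-s, s], then rescale back into [-1, 1].
    return [max(-s, min(s, v)) / s for v in x0]
```

Note that in this sketch, with sample_max_value=1.0 the threshold s is always 1, so the method reduces to plain clipping to [-1, 1]; values of sample_max_value above 1 are what let the threshold adapt per sample.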

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops, dtype as mstype
>>> from mindflow.cell import DDPMScheduler
>>> scheduler = DDPMScheduler(num_train_timesteps=1000,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           prediction_type='epsilon',
...                           variance_type='fixed_small_log',
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           sample_max_value=1.,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> batch_size, seq_len, in_dim = 4, 256, 16
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, 100, (batch_size, 1))
>>> noised_sample = scheduler.add_noise(original_samples, noise, timesteps)
>>> print(noised_sample.shape)
(4, 256, 16)
>>> sample_timesteps = Tensor(np.array([60]*batch_size), dtype=mstype.int32)
>>> x_prev = scheduler.step(noise, noised_sample, sample_timesteps)
>>> print(x_prev.shape)
(4, 256, 16)
add_noise(original_samples: Tensor, noise: Tensor, timesteps: Tensor)

Run the forward diffusion process: add noise to the original samples at the given timesteps.

Parameters
  • original_samples (Tensor) – The original (clean) samples to be noised.

  • noise (Tensor) – Random noise to be added to the samples.

  • timesteps (Tensor) – The current discrete timestep in the diffusion chain.

Returns

Tensor, the noised samples at the given timesteps.
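The forward process behind add_noise can be sketched in plain Python using the closed form x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 - ᾱ_t) · noise from the DDPM paper. This is a minimal sketch, not the MindFlow implementation: the function name, the list-of-lists data layout, the one-integer-timestep-per-sample convention and the alphas_cumprod argument (the precomputed ᾱ_t values) are hypothetical.

```python
import math


def add_noise(original_samples, noise, timesteps, alphas_cumprod):
    """Hypothetical pure-Python forward process:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

    original_samples and noise are lists of per-sample lists of floats;
    timesteps holds one integer timestep per sample.
    """
    noised = []
    for x0, eps, t in zip(original_samples, noise, timesteps):
        signal_scale = math.sqrt(alphas_cumprod[t])
        noise_scale = math.sqrt(1.0 - alphas_cumprod[t])
        noised.append([signal_scale * xi + noise_scale * ei
                       for xi, ei in zip(x0, eps)])
    return noised
```

For example, add_noise([[1.0, 2.0]], [[1.0, -1.0]], [0], [0.36]) scales the signal by 0.6 and the noise by 0.8, giving approximately [[1.4, 0.4]].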

set_timesteps(num_inference_steps)

Set the discrete timesteps used for DDPM inference.

Parameters

num_inference_steps (int) – The denoising step number.

Raises

ValueError – If num_inference_steps is not equal to num_train_timesteps.

step(model_output, sample, timestep, predicted_variance=None)

Perform one DDPM denoising step, predicting the sample at the previous timestep by reversing the diffusion process.

Parameters
  • model_output (Tensor) – The direct output of the learned diffusion model.

  • sample (Tensor) – A current instance of a sample created by the diffusion process.

  • timestep (Tensor) – The current discrete timestep in the diffusion chain.

  • predicted_variance (Tensor) – The predicted variance. Default: None.

Returns

Tensor, the denoised sample at the previous timestep.
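One reverse step can be sketched in plain Python for a single configuration: prediction_type=epsilon, clip_sample=True, and the fixed_small variance β̃_t from the DDPM paper. This is a hedged sketch under those assumptions, not MindFlow's code: ddpm_step, the flat-list layout, the scalar integer timestep and the alphas_cumprod argument (the precomputed ᾱ_t values) are hypothetical, and the class default fixed_small_log as well as the learned variance types are not modeled here.

```python
import math
import random


def ddpm_step(model_output, sample, t, alphas_cumprod, clip_sample=True,
              clip_sample_range=1.0, rng=None):
    """Hypothetical single DDPM reverse step for prediction_type='epsilon'
    with a fixed_small variance, on a flat list of floats."""
    if rng is None:
        rng = random.Random(0)
    abar_t = alphas_cumprod[t]
    abar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    beta_t = 1.0 - abar_t / abar_prev  # beta of the current step
    alpha_t = 1.0 - beta_t

    # 1. Estimate x_0 from the predicted noise, optionally clipping it.
    x0 = [(x - math.sqrt(1.0 - abar_t) * eps) / math.sqrt(abar_t)
          for x, eps in zip(sample, model_output)]
    if clip_sample:
        x0 = [max(-clip_sample_range, min(clip_sample_range, v)) for v in x0]

    # 2. Posterior mean: weighted combination of the x_0 estimate and x_t.
    c0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    ct = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    mean = [c0 * a + ct * b for a, b in zip(x0, sample)]

    # 3. Add Gaussian noise with variance beta_tilde_t, except at t == 0.
    if t == 0:
        return mean
    var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    return [m + math.sqrt(var) * rng.gauss(0.0, 1.0) for m in mean]
```

At t == 0 the posterior-mean weights reduce to 1 for the x_0 estimate and 0 for x_t, so the sketch returns the (clipped) x_0 estimate without adding noise.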