mindflow.cell.DiffusionTrainer

class mindflow.cell.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1, loss_type='l1')

Base class for training diffusion models. It pairs a backbone model with a diffusion scheduler and computes the training loss.

Parameters
  • model (nn.Cell) – The diffusion backbone model.

  • scheduler (DiffusionScheduler) – DDPM or DDIM scheduler.

  • objective (str) – The prediction target of the model. One of pred_noise (predict the noise added by the diffusion process), pred_x0 (predict the original sample) or pred_v (predict the velocity v, see section 2.4 of the Imagen Video paper). Default: pred_noise.

  • p2_loss_weight_gamma (float) – The gamma exponent of the P2 loss weighting from Perception Prioritized Training of Diffusion Models. Default: 0.

  • p2_loss_weight_k (float) – The k offset of the P2 loss weighting from Perception Prioritized Training of Diffusion Models. Default: 1.

  • loss_type (str) – The type of loss, either l1 or l2. Default: l1.
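The regression target implied by each objective, and the per-timestep weight controlled by p2_loss_weight_gamma and p2_loss_weight_k, can be sketched per scalar element as below. This is a minimal pure-Python sketch, not MindFlow's implementation. It assumes a variance-preserving forward process x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise, the v definition from the Imagen Video paper, and the P2 weight 1 / (k + SNR)^gamma with SNR = alpha_bar / (1 - alpha_bar). The helper names objective_target and p2_weight are hypothetical.

```python
import math


def objective_target(objective, x0, noise, alpha_bar):
    """Hypothetical helper: regression target for one scalar element.

    Assumes x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise.
    """
    if objective == "pred_noise":
        # The model regresses the injected noise
        return noise
    if objective == "pred_x0":
        # The model regresses the clean sample directly
        return x0
    if objective == "pred_v":
        # Velocity: v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x0
        return math.sqrt(alpha_bar) * noise - math.sqrt(1.0 - alpha_bar) * x0
    raise ValueError(f"unknown objective: {objective!r}")


def p2_weight(alpha_bar, gamma=0.0, k=1.0):
    """Hypothetical helper: assumed P2 weight 1 / (k + SNR)^gamma.

    Requires 0 < alpha_bar < 1 so that SNR is finite.
    """
    snr = alpha_bar / (1.0 - alpha_bar)
    return (k + snr) ** -gamma


print(objective_target("pred_v", 0.5, -1.0, 0.25))
print(p2_weight(0.5, gamma=1.0, k=1.0))
```

Under this assumed formula, gamma = 0 yields a weight of 1 for every timestep, so the default p2_loss_weight_gamma=0. leaves the loss unweighted. Larger gamma down-weights high-SNR (low-noise) timesteps.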

Raises

TypeError – If scheduler is not an instance of DiffusionScheduler.

Supported Platforms:

Ascend

Examples

>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> hidden_dim, layers, heads = 256, 3, 4  # illustrative transformer sizes
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, [batch_size, 1])
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)
get_loss(original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None)

Calculate the training loss of the diffusion model for the given samples, noise and timesteps.

Parameters
  • original_samples (Tensor) – The original (clean) samples to which noise is added.

  • noise (Tensor) – The noise added to original_samples by the forward diffusion process.

  • timesteps (Tensor) – The discrete diffusion timesteps at which the noise is added, one per sample.

  • condition (Tensor) – The conditioning input passed to the model. Default: None.

Returns

Tensor, the model forward loss.
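The steps behind get_loss can be sketched for objective='pred_noise' as follows: noise the clean samples with the forward process, let the model predict the noise, compute an elementwise l1 or l2 error, average it per sample, scale it by a per-sample weight, and average over the batch. This is a hypothetical pure-Python sketch rather than the MindFlow source. The function name, the model(x_t, t, condition) calling convention, the alphas_cumprod lookup table and the reduction order are all assumptions, and the real method works on MindSpore Tensors through the scheduler.

```python
import math


def get_loss_sketch(model, original_samples, noise, timesteps, alphas_cumprod,
                    condition=None, loss_type="l1", weights=None):
    """Hypothetical pred_noise training loss on plain Python lists.

    original_samples, noise: batch of per-sample feature lists
    timesteps: one integer timestep per sample
    alphas_cumprod: assumed table of cumulative alpha products, indexed by timestep
    model: hypothetical callable (x_t, t, condition) -> predicted noise
    weights: optional per-sample loss weights (e.g. P2 weights); default 1.0
    """
    if loss_type not in ("l1", "l2"):
        raise ValueError(f"loss_type must be 'l1' or 'l2', got {loss_type!r}")
    if weights is None:
        weights = [1.0] * len(original_samples)

    per_sample_losses = []
    for x0, eps, t, w in zip(original_samples, noise, timesteps, weights):
        a_bar = alphas_cumprod[t]
        # Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
        x_t = [math.sqrt(a_bar) * a + math.sqrt(1.0 - a_bar) * e
               for a, e in zip(x0, eps)]
        pred = model(x_t, t, condition)
        # With objective='pred_noise' the regression target is the injected noise
        if loss_type == "l1":
            errors = [abs(p - e) for p, e in zip(pred, eps)]
        else:
            errors = [(p - e) ** 2 for p, e in zip(pred, eps)]
        # Mean over features, then scale by this sample's weight
        per_sample_losses.append(w * sum(errors) / len(errors))

    # Mean over the batch
    return sum(per_sample_losses) / len(per_sample_losses)


def zero_model(x_t, t, condition):
    # Dummy model that always predicts zero noise
    return [0.0] * len(x_t)


print(get_loss_sketch(zero_model, [[1.0, 2.0]], [[0.5, -0.5]], [0], [0.5]))
```

A model that predicts the injected noise exactly drives this loss to zero, while the weights argument is where P2 weights computed from the sampled timesteps would enter.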