mindflow.cell.DiffusionTrainer
- class mindflow.cell.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1, loss_type='l1')[source]
Diffusion Trainer base class.
- Parameters
model (nn.Cell) – The diffusion backbone model.
scheduler (DiffusionScheduler) – DDPM or DDIM scheduler.
objective (str) – Prediction type of the scheduler function; it can be pred_noise (predicts the noise of the diffusion process), pred_x0 (predicts the original sample) or pred_v (see section 2.4 of the Imagen Video paper). Default: pred_noise. A sketch of the three targets follows this parameter list.
p2_loss_weight_gamma (float) – p2 loss weight gamma, from Perception Prioritized Training of Diffusion Models. Default: 0.
p2_loss_weight_k (float) – p2 loss weight k, from Perception Prioritized Training of Diffusion Models. Default: 1.
loss_type (str) – The type of loss; it can be l1 or l2. Default: l1.
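The objective and p2 weighting parameters map onto standard DDPM quantities. Below is a minimal numpy sketch of the training target for each objective and of the Perception Prioritized (p2) loss weight, written under the usual DDPM notation; it is an illustration only, not MindFlow's exact internals, and all names are local to the sketch.

import numpy as np

num_train_timesteps = 100
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)            # alpha_bar_t

# p2 weight (Perception Prioritized Training): (k + SNR(t)) ** -gamma,
# with SNR(t) = alpha_bar_t / (1 - alpha_bar_t).
# gamma = 0 recovers the usual unweighted loss.
gamma, k = 1.0, 1.0
snr = alphas_cumprod / (1.0 - alphas_cumprod)
p2_loss_weight = (k + snr) ** -gamma

# Training target per objective, for a clean sample x0 and noise eps at step t.
t = 10
x0, eps = np.random.randn(16), np.random.randn(16)
sqrt_ab = np.sqrt(alphas_cumprod[t])
sqrt_1m_ab = np.sqrt(1.0 - alphas_cumprod[t])
target_pred_noise = eps                             # objective='pred_noise'
target_pred_x0 = x0                                 # objective='pred_x0'
target_pred_v = sqrt_ab * eps - sqrt_1m_ab * x0     # objective='pred_v' (v-parameterization)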
- Raises
TypeError – If scheduler is not a DiffusionScheduler instance.
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> hidden_dim, layers, heads = 256, 3, 4
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, [batch_size, 1])
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)
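Building on the Examples above, the following is a minimal training-step sketch. It is hypothetical glue code rather than part of the DiffusionTrainer API, and it assumes the net, trainer and num_train_timesteps names constructed in the Examples.

import mindspore as ms
from mindspore import nn, ops

optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-4)

def forward_fn(samples, noise, timesteps, cond):
    return trainer.get_loss(samples, noise, timesteps, cond)

# Differentiate the loss with respect to the optimizer's parameters.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(samples, cond):
    noise = ops.randn_like(samples)
    timesteps = ops.randint(0, num_train_timesteps, [samples.shape[0], 1])
    loss, grads = grad_fn(samples, noise, timesteps, cond)
    optimizer(grads)          # apply the gradient update
    return loss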
- get_loss(original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None)[source]
Calculate the forward loss of the diffusion process.
- Parameters
original_samples (Tensor) – The original (clean) samples that serve as the starting point of the forward diffusion process.
noise (Tensor) – The noise sample added to original_samples by the forward diffusion process.
timesteps (Tensor) – The current discrete timestep in the diffusion chain.
condition (Tensor) – The condition for desired outputs. Default: None.
- Returns
Tensor, the model forward loss.
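For orientation, the forward loss follows the standard DDPM recipe: noise the clean samples, run the backbone, and compare the prediction against the objective's target under the p2 weight. The numpy outline below is a hypothetical sketch of that recipe, not MindFlow's exact internals; model_fn and target_fn stand in for the backbone and the objective-dependent target.

import numpy as np

def forward_loss_sketch(x0, eps, t, alphas_cumprod, model_fn, target_fn,
                        p2_loss_weight, loss_type="l2"):
    ab = alphas_cumprod[t]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps    # q(x_t | x_0): noised sample
    pred = model_fn(x_t, t)                             # backbone prediction
    target = target_fn(x0, eps)                         # depends on `objective`
    err = np.abs(pred - target) if loss_type == "l1" else (pred - target) ** 2
    return (p2_loss_weight[t] * err).mean()             # p2-weighted l1/l2 loss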