mindflow.cell.DiffusionTrainer
- class mindflow.cell.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1, loss_type='l1')
Implements the training procedure for diffusion models.
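- Parameters:
model (nn.Cell) - The backbone network trained as the diffusion model.
scheduler (DiffusionScheduler) - The noise scheduler used during training.
objective (str) - The quantity the diffusion model predicts. Default: 'pred_noise'. Supported values: 'pred_noise', 'pred_v' and 'pred_x0'.
p2_loss_weight_gamma (float) - The gamma of the P2 loss weighting; see Perception Prioritized Training of Diffusion Models for details. Default: 0.0.
p2_loss_weight_k (float) - The k of the P2 loss weighting; see Perception Prioritized Training of Diffusion Models for details. Default: 1.
loss_type (str) - The type of loss function. Default: 'l1'. Supported values: 'l1' and 'l2'.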
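The parameters above determine what the backbone is asked to predict and how each timestep's loss is weighted. The following is a minimal pure-Python sketch, assuming the common DDPM training formulation (as used, for example, by the denoising-diffusion-pytorch package); it is not MindFlow's implementation, and the helper names `add_noise`, `training_target`, `p2_weight` and `weighted_loss` are hypothetical. `alpha_bar` stands for the scheduler's cumulative product of (1 - beta) at the sampled timestep, and the P2 weight is assumed to take the form (k + SNR) ** -gamma.

```python
import math

# Hypothetical helpers illustrating the common DDPM training formulation;
# this is NOT MindFlow's code. `alpha_bar` is the cumulative product of
# (1 - beta) at the sampled timestep, assumed to lie strictly in (0, 1).


def add_noise(x0, noise, alpha_bar):
    """Forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise."""
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + s * e for x, e in zip(x0, noise)]


def training_target(objective, x0, noise, alpha_bar):
    """Regression target the backbone is trained to predict for a given objective."""
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    if objective == 'pred_noise':
        return list(noise)  # predict the injected noise
    if objective == 'pred_x0':
        return list(x0)  # predict the clean sample directly
    if objective == 'pred_v':
        # v-parameterization: v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x0
        return [a * e - s * x for x, e in zip(x0, noise)]
    raise ValueError(f"unsupported objective: {objective}")


def p2_weight(alpha_bar, gamma=0.0, k=1.0):
    """Assumed P2 weight (k + SNR) ** -gamma, with SNR = alpha_bar / (1 - alpha_bar)."""
    snr = alpha_bar / (1.0 - alpha_bar)
    return (k + snr) ** -gamma  # equals 1.0 whenever gamma == 0


def weighted_loss(pred, target, loss_type='l1', weight=1.0):
    """Mean absolute ('l1') or squared ('l2') error, scaled by the P2 weight."""
    if loss_type == 'l1':
        errors = [abs(p - t) for p, t in zip(pred, target)]
    elif loss_type == 'l2':
        errors = [(p - t) ** 2 for p, t in zip(pred, target)]
    else:
        raise ValueError(f"unsupported loss_type: {loss_type}")
    return weight * sum(errors) / len(errors)


# Usage: one noisy sample, a dummy all-zero prediction, gamma = 1 weighting.
x0, noise, alpha_bar = [1.0, -1.0], [0.5, 0.5], 0.5
x_t = add_noise(x0, noise, alpha_bar)  # the backbone would receive x_t and t
target = training_target('pred_v', x0, noise, alpha_bar)
loss = weighted_loss([0.0, 0.0], target, 'l2', p2_weight(alpha_bar, gamma=1.0, k=1.0))
print(loss)
```

Under this formulation, the default p2_loss_weight_gamma=0.0 gives a weight of 1 at every timestep, so the loss reduces to a plain L1 or L2 loss; a positive gamma down-weights low-noise (high-SNR) timesteps.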
- Raises:
TypeError - If scheduler is not a DiffusionScheduler.
- Supported Platforms:
Ascend
Examples:
>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> hidden_dim, layers, heads = 256, 3, 4  # backbone sizes (illustrative values)
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, (batch_size, 1))
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)