mindflow.cell.DiffusionTrainer

class mindflow.cell.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1, loss_type='l1')

Implements training control for diffusion models.

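Parameters:
  • model (nn.Cell) - The backbone network used for diffusion model training.

  • scheduler (DiffusionScheduler) - The noise scheduler used for training.

  • objective (str) - The type of quantity the diffusion model predicts. Default: pred_noise. Supported values: pred_noise, pred_v, pred_x0.

  • p2_loss_weight_gamma (float) - The gamma of the P2 loss weight; see Perception Prioritized Training of Diffusion Models for details. Default: 0.0.

  • p2_loss_weight_k (float) - The k of the P2 loss weight; see Perception Prioritized Training of Diffusion Models for details. Default: 1.

  • loss_type (str) - The type of loss function. Default: l1. Supported values: l1, l2.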
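Assuming the trainer follows the P2 weighting from the cited paper, each timestep's loss is scaled by lambda_t = (k + SNR_t)^(-gamma), where SNR_t = alpha_bar_t / (1 - alpha_bar_t) and alpha_bar_t is the cumulative product of (1 - beta_t). The pure-Python sketch below uses hypothetical helpers (`linear_alphas_cumprod`, `p2_loss_weight`) and a linear beta schedule chosen only for illustration; it is not MindFlow code. It shows that the default gamma of 0.0 weights every timestep by 1, while gamma > 0 down-weights the low-noise, high-SNR timesteps.

```python
import math


def linear_alphas_cumprod(num_steps, beta_start=1e-4, beta_end=0.02):
    # alpha_bar_t = prod_{s <= t} (1 - beta_s) for a linear beta schedule.
    # Illustrative assumption: the example below uses "squaredcos_cap_v2" instead.
    alphas_cumprod, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def p2_loss_weight(alphas_cumprod, gamma=0.0, k=1.0):
    # lambda_t = (k + SNR_t) ** (-gamma), with SNR_t = alpha_bar_t / (1 - alpha_bar_t).
    # gamma = 0 yields 1.0 for every t, i.e. an unweighted loss.
    return [(k + a / (1.0 - a)) ** (-gamma) for a in alphas_cumprod]


alphas_cumprod = linear_alphas_cumprod(100)
uniform = p2_loss_weight(alphas_cumprod, gamma=0.0, k=1.0)  # default setting
p2 = p2_loss_weight(alphas_cumprod, gamma=1.0, k=1.0)       # P2 weighting enabled
```

With gamma = 1 and k = 1, the first timestep (SNR close to 9999) receives a weight of about 1e-4, so nearly clean samples contribute very little to the loss.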

Raises:
  • TypeError - If scheduler is not of type DiffusionScheduler.

Supported Platforms:

Ascend

Examples:

>>> from mindspore import ops, dtype as mstype
>>> from mindflow.cell import DDPMScheduler, ConditionDiffusionTransformer, DiffusionTrainer
>>> # init params
>>> batch_size, seq_len, in_dim, cond_dim, num_train_timesteps = 4, 256, 16, 4, 100
>>> hidden_dim, layers, heads = 256, 3, 4  # illustrative backbone sizes
>>> original_samples = ops.randn([batch_size, seq_len, in_dim])
>>> noise = ops.randn([batch_size, seq_len, in_dim])
>>> timesteps = ops.randint(0, num_train_timesteps, [batch_size, 1])
>>> cond = ops.randn([batch_size, cond_dim])
>>> # init model and scheduler
>>> net = ConditionDiffusionTransformer(in_channels=in_dim,
...                                     out_channels=in_dim,
...                                     cond_channels=cond_dim,
...                                     hidden_channels=hidden_dim,
...                                     layers=layers,
...                                     heads=heads,
...                                     time_token_cond=True,
...                                     compute_dtype=mstype.float32)
>>> scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps,
...                           beta_start=0.0001,
...                           beta_end=0.02,
...                           beta_schedule="squaredcos_cap_v2",
...                           clip_sample=True,
...                           clip_sample_range=1.0,
...                           thresholding=False,
...                           dynamic_thresholding_ratio=0.995,
...                           rescale_betas_zero_snr=False,
...                           timestep_spacing="leading",
...                           compute_dtype=mstype.float32)
>>> # init trainer
>>> trainer = DiffusionTrainer(net,
...                            scheduler,
...                            objective='pred_noise',
...                            p2_loss_weight_gamma=0,
...                            p2_loss_weight_k=1,
...                            loss_type='l2')
>>> loss = trainer.get_loss(original_samples, noise, timesteps, cond)

get_loss(original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None)

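Computes the forward loss of the diffusion process.

Parameters:
  • original_samples (Tensor) - The original (noise-free) samples.

  • noise (Tensor) - The random noise.

  • timesteps (Tensor) - The diffusion timesteps.

  • condition (Tensor) - The conditioning input. Default: None.

Returns:
  • Tensor - The forward loss.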
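A minimal pure-Python sketch of what a forward loss of this kind typically computes, assuming the standard DDPM forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise and the usual definitions of the three objective targets. `q_sample`, `regression_target`, and `get_loss_sketch` are hypothetical helpers, not MindFlow functions, and the reduction (mean over elements, then an optional per-timestep weight) is an assumption. The sketch works on one flat sample instead of a batched Tensor.

```python
import math


def q_sample(x0, noise, abar_t):
    # Forward diffusion q(x_t | x_0): x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.
    a, b = math.sqrt(abar_t), math.sqrt(1.0 - abar_t)
    return [a * x + b * n for x, n in zip(x0, noise)]


def regression_target(x0, noise, abar_t, objective="pred_noise"):
    # What the backbone is trained to output for each assumed objective value.
    if objective == "pred_noise":
        return list(noise)
    if objective == "pred_x0":
        return list(x0)
    if objective == "pred_v":
        # v-parameterization: v = sqrt(abar_t) * noise - sqrt(1 - abar_t) * x0
        a, b = math.sqrt(abar_t), math.sqrt(1.0 - abar_t)
        return [a * n - b * x for x, n in zip(x0, noise)]
    raise ValueError(f"unsupported objective: {objective}")


def get_loss_sketch(model, x0, noise, t, alphas_cumprod, condition=None,
                    objective="pred_noise", loss_type="l1", p2_weights=None):
    # Hypothetical single-sample analogue of get_loss: noise x0, predict, compare.
    abar_t = alphas_cumprod[t]
    x_t = q_sample(x0, noise, abar_t)
    pred = model(x_t, t, condition)
    target = regression_target(x0, noise, abar_t, objective)
    if loss_type == "l1":
        errs = [abs(p - y) for p, y in zip(pred, target)]
    elif loss_type == "l2":
        errs = [(p - y) ** 2 for p, y in zip(pred, target)]
    else:
        raise ValueError(f"unsupported loss_type: {loss_type}")
    # Mean over elements, scaled by an optional per-timestep weight (e.g. P2).
    weight = 1.0 if p2_weights is None else p2_weights[t]
    return weight * sum(errs) / len(errs)


def zero_model(x_t, t, condition):
    # Stand-in backbone that always predicts zeros.
    return [0.0] * len(x_t)
```

In the real API, original_samples, noise, timesteps and condition are batched Tensors, and the noising step is presumably delegated to the configured scheduler.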