mindflow.cell.DiffusionTransformer
- class mindflow.cell.DiffusionTransformer(in_channels, out_channels, hidden_channels, layers, heads, time_token_cond=True, compute_dtype=mstype.float32)[source]
Implementation of a diffusion model with a Transformer backbone.
- Parameters
in_channels (int) – The number of input channels.
out_channels (int) – The number of output channels.
hidden_channels (int) – The number of hidden channels.
layers (int) – The number of Transformer block layers.
heads (int) – The number of attention heads.
time_token_cond (bool) – Whether to use the timestep as a condition token. Default: True.
compute_dtype (mindspore.dtype) – The dtype used for computation; can be mstype.float32 or mstype.float16. Default: mstype.float32, which indicates mindspore.float32.
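The time_token_cond flag controls whether the integer timestep is injected into the sequence as an extra condition token. Diffusion Transformers conventionally first map the timestep to a dense vector via a sinusoidal embedding before conditioning on it; the NumPy sketch below illustrates that standard embedding. It is an assumption-laden illustration of the common technique, not the exact internals of mindflow.cell.DiffusionTransformer.

```python
import numpy as np

def timestep_embedding(timesteps: np.ndarray, dim: int, max_period: int = 10000) -> np.ndarray:
    """Map integer diffusion timesteps to sinusoidal embeddings of size `dim`.

    Mirrors the embedding commonly used by diffusion models; the exact scheme
    inside mindflow.cell.DiffusionTransformer may differ.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)   # (half,)
    args = timesteps[:, None].astype(np.float64) * freqs[None, :]  # (batch, half)
    # Concatenate cosine and sine components to form the embedding.
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)   # (batch, dim)

# Embed a small batch of timesteps into hidden_channels-sized vectors.
emb = timestep_embedding(np.array([0, 10, 999]), 256)
print(emb.shape)  # (3, 256)
```

With time_token_cond=True, such an embedding would be prepended to the input sequence as one additional token, letting self-attention propagate the timestep information to every position.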
- Inputs:
x (Tensor) - The input tensor of shape (batch_size, sequence_len, in_channels).
timestep (Tensor) - The timestep tensor of shape (batch_size,).
- Outputs:
output (Tensor) - The output tensor of shape (batch_size, sequence_len, out_channels).
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops
>>> from mindflow.cell import DiffusionTransformer
>>> in_channels, out_channels, hidden_channels, layers, heads, batch_size, seq_len = 16, 16, 256, 3, 4, 8, 256
>>> model = DiffusionTransformer(in_channels=in_channels,
...                              out_channels=out_channels,
...                              hidden_channels=hidden_channels,
...                              layers=layers,
...                              heads=heads)
>>> x = ops.rand((batch_size, seq_len, in_channels))
>>> timestep = ops.randint(0, 1000, (batch_size,))
>>> output = model(x, timestep)
>>> print(output.shape)
(8, 256, 16)