mindflow.cell.ConditionDiffusionTransformer
- class mindflow.cell.ConditionDiffusionTransformer(in_channels, out_channels, cond_channels, hidden_channels, layers, heads, time_token_cond=True, cond_as_token=True, compute_dtype=mstype.float32)[source]
Conditioned Diffusion Transformer implementation.
- Parameters
in_channels (int) – The number of input channels.
out_channels (int) – The number of output channels.
cond_channels (int) – The number of condition channels.
hidden_channels (int) – The number of hidden channels.
layers (int) – The number of transformer block layers.
heads (int) – The number of attention heads.
time_token_cond (bool) – Whether to use the timestep as a condition token. Default: True.
cond_as_token (bool) – Whether to use the condition as a token. Default: True.
compute_dtype (mindspore.dtype) – The dtype of computation; it can be mstype.float32 or mstype.float16. Default: mstype.float32, which indicates mindspore.float32.
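The effect of time_token_cond and cond_as_token can be illustrated with a minimal NumPy sketch: when both are True, the timestep embedding and the condition embedding are prepended to the input sequence as extra tokens before the transformer blocks, then dropped from the output. This is a conceptual illustration of the token-conditioning pattern, not the actual mindflow internals; all names here are hypothetical.

```python
# Conceptual sketch (NumPy), NOT the mindflow implementation:
# with time_token_cond=True and cond_as_token=True, the timestep and
# condition embeddings are treated as extra sequence tokens.
import numpy as np

batch, seq_len, hidden = 8, 256, 32
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, seq_len, hidden))  # embedded input tokens
t_emb = rng.standard_normal((batch, hidden))       # hypothetical timestep embedding
c_emb = rng.standard_normal((batch, hidden))       # hypothetical condition embedding

# Prepend both embeddings as tokens: (batch, 2 + seq_len, hidden)
tokens = np.concatenate([t_emb[:, None, :], c_emb[:, None, :], x], axis=1)
print(tokens.shape)  # (8, 258, 32)

# After the transformer blocks, the extra condition tokens are dropped,
# so the output sequence length matches the input.
out = tokens[:, 2:, :]
print(out.shape)  # (8, 256, 32)
```

With cond_as_token=False, the condition would instead be injected without lengthening the sequence (e.g. added to every token), which is the usual alternative in diffusion transformers.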
- Inputs:
x (Tensor) - The input, with shape of (batch_size, seq_len, in_channels).
timestep (Tensor) - The timestep input, with shape of (batch_size,).
condition (Tensor) - The condition input, with shape of (batch_size, cond_channels). Default: None.
- Outputs:
output (Tensor) - The output, with shape of (batch_size, seq_len, out_channels).
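Diffusion transformers typically map the integer timestep input to a continuous embedding before it is used as a condition token. A common choice is the sinusoidal embedding below; this is a hedged sketch of that standard technique, and the function name and mindflow's actual embedding scheme are assumptions, not confirmed internals.

```python
# Sketch of a standard sinusoidal timestep embedding, as commonly used in
# diffusion models. The exact scheme used by mindflow is an assumption.
import numpy as np

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000.0):
    """Map integer timesteps (batch,) to (batch, dim) sinusoidal embeddings."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = timesteps.astype(np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = sinusoidal_timestep_embedding(np.arange(8), 32)
print(emb.shape)  # (8, 32)
```

At timestep 0 the cosine half of the embedding is all ones and the sine half is all zeros, which makes the embedding easy to sanity-check.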
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops
>>> from mindflow.cell import ConditionDiffusionTransformer
>>> in_channels, out_channels, cond_channels, hidden_channels = 16, 16, 10, 256
>>> layers, heads, batch_size, seq_len = 3, 4, 8, 256
>>> model = ConditionDiffusionTransformer(in_channels=in_channels,
...                                       out_channels=out_channels,
...                                       cond_channels=cond_channels,
...                                       hidden_channels=hidden_channels,
...                                       layers=layers,
...                                       heads=heads)
>>> x = ops.rand((batch_size, seq_len, in_channels))
>>> cond = ops.rand((batch_size, cond_channels))
>>> timestep = ops.randint(0, 1000, (batch_size,))
>>> output = model(x, timestep, cond)
>>> print(output.shape)
(8, 256, 16)