mindflow.cell.ViT

class mindflow.cell.ViT(image_size=(192, 384), in_channels=7, out_channels=3, patch_size=16, encoder_depths=12, encoder_embed_dim=768, encoder_num_heads=12, decoder_depths=8, decoder_embed_dim=512, decoder_num_heads=16, dropout_rate=0.0, compute_dtype=mstype.float16)

This module is based on the ViT backbone and consists of an encoder, a decoding embedding layer (decoding_embedding), a decoder and a dense output layer.
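The stages listed above can be followed as a chain of tensor shapes. The pure-Python sketch below traces them under stated assumptions that are not taken from the MindFlow source: the encoder emits one token per non-overlapping patch_size x patch_size patch, the decoding embedding and decoder keep the token count, and the final dense layer maps each token to patch_size * patch_size * out_channels values. The function trace_vit_shapes is hypothetical and does not exist in MindFlow.

```python
# Hypothetical helper: trace the tensor shape after each ViT stage.
# Assumptions (not confirmed by MindFlow): one token per square patch,
# token count unchanged through the decoder, and a final dense layer that
# outputs patch_size * patch_size * out_channels values per token.

def trace_vit_shapes(batch_size, image_size=(192, 384), patch_size=16,
                     encoder_embed_dim=768, decoder_embed_dim=512,
                     out_channels=3):
    """Return a list of (stage name, shape) pairs for one forward pass."""
    height, width = image_size
    # Number of patches (tokens) produced by splitting the image into a grid.
    num_patches = (height // patch_size) * (width // patch_size)
    return [
        ("encoder", (batch_size, num_patches, encoder_embed_dim)),
        ("decoding_embedding", (batch_size, num_patches, decoder_embed_dim)),
        ("decoder", (batch_size, num_patches, decoder_embed_dim)),
        ("dense", (batch_size, num_patches, patch_size * patch_size * out_channels)),
    ]


for stage, shape in trace_vit_shapes(32):
    print(f"{stage:>20}: {shape}")
```

With the default sizes and a batch of 32, the last entry is ("dense", (32, 288, 768)), which agrees with the shape printed in the example at the end of this page.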

Parameters
  • image_size (tuple[int]) – The (height, width) of the input image. Default: (192, 384).

  • in_channels (int) – The number of channels (features) of the input. Default: 7.

  • out_channels (int) – The number of channels (features) of the output. Default: 3.

  • patch_size (int) – The side length of each square image patch. Default: 16.

  • encoder_depths (int) – The number of encoder layers. Default: 12.

  • encoder_embed_dim (int) – The embedding dimension of the encoder. Default: 768.

  • encoder_num_heads (int) – The number of attention heads in each encoder layer. Default: 12.

  • decoder_depths (int) – The number of decoder layers. Default: 8.

  • decoder_embed_dim (int) – The embedding dimension of the decoder. Default: 512.

  • decoder_num_heads (int) – The number of attention heads in each decoder layer. Default: 16.

  • dropout_rate (float) – The dropout rate. Default: 0.0.

  • compute_dtype (dtype) – The data type used for computation in the encoder, decoding embedding, decoder and dense layers. Default: mstype.float16.
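Several of these parameters have to agree with each other in a standard ViT: the image height and width are split into whole patches, and multi-head attention divides each embedding evenly across its heads. The sketch below checks those usual constraints before a model is built. check_vit_config is a hypothetical helper, not part of MindFlow, and whether ViT itself validates these values is not stated here.

```python
# Hypothetical helper: check the usual ViT divisibility constraints.
# MindFlow's ViT may or may not perform these checks itself.

def check_vit_config(image_size=(192, 384), patch_size=16,
                     encoder_embed_dim=768, encoder_num_heads=12,
                     decoder_embed_dim=512, decoder_num_heads=16,
                     dropout_rate=0.0):
    """Return a list of problems; an empty list means the config looks consistent."""
    problems = []
    # Each image dimension must split into a whole number of patches.
    for name, size in zip(("image_height", "image_width"), image_size):
        if size % patch_size != 0:
            problems.append(f"{name}={size} is not divisible by patch_size={patch_size}")
    # Multi-head attention splits the embedding evenly across the heads.
    if encoder_embed_dim % encoder_num_heads != 0:
        problems.append("encoder_embed_dim must be divisible by encoder_num_heads")
    if decoder_embed_dim % decoder_num_heads != 0:
        problems.append("decoder_embed_dim must be divisible by decoder_num_heads")
    # A dropout rate is a probability of dropping an element.
    if not 0.0 <= dropout_rate < 1.0:
        problems.append(f"dropout_rate={dropout_rate} must lie in [0, 1)")
    return problems


print(check_vit_config())
print(check_vit_config(image_size=(200, 384)))
```

With the defaults the helper returns an empty list, and each attention head works on 768 / 12 = 64 dimensions in the encoder and 512 / 16 = 32 in the decoder.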

Inputs:
  • input (Tensor) - Tensor of shape \((batch\_size, feature\_size, image\_height, image\_width)\), where feature_size equals in_channels and (image_height, image_width) matches image_size.

Outputs:
  • output (Tensor) - Tensor of shape \((batch\_size, patchify\_size, embed\_dim)\), where \(patchify\_size = (image\_height \times image\_width) / (patch\_size \times patch\_size)\) and \(embed\_dim = patch\_size \times patch\_size \times out\_channels\).
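Because each output token holds the values of one patch, a per-sample output of shape (patchify_size, patch_size * patch_size * out_channels) can be folded back into an (out_channels, image_height, image_width) field. The pure-Python sketch below does this with index arithmetic. It assumes, without confirmation from the MindFlow docs, that patches are numbered in row-major order over the patch grid and that each token stores its values in (row, column, channel) order, as in MAE-style unpatchify routines. The function unpatchify is hypothetical.

```python
# Hypothetical sketch: fold one sample's per-patch tokens back into an image.
# Assumed layout (not confirmed by MindFlow): patches in row-major grid order,
# values inside a token ordered as (row in patch, column in patch, channel).

def unpatchify(tokens, image_size, patch_size, out_channels):
    """Convert a list of token vectors into a nested [channel][row][column] list."""
    height, width = image_size
    grid_h, grid_w = height // patch_size, width // patch_size
    if len(tokens) != grid_h * grid_w:
        raise ValueError("number of tokens does not match the patch grid")
    image = [[[0.0] * width for _ in range(height)] for _ in range(out_channels)]
    for n, token in enumerate(tokens):
        if len(token) != patch_size * patch_size * out_channels:
            raise ValueError("token length does not match patch_size**2 * out_channels")
        gi, gj = divmod(n, grid_w)  # patch row and column in the grid
        for k, value in enumerate(token):
            # k = (pi * patch_size + pj) * out_channels + c
            pixel, c = divmod(k, out_channels)
            pi, pj = divmod(pixel, patch_size)
            image[c][gi * patch_size + pi][gj * patch_size + pj] = value
    return image


# A 2 x 4 single-channel image split into two 2 x 2 patches.
print(unpatchify([[0, 1, 2, 3], [4, 5, 6, 7]], (2, 4), 2, 1))
```

The usage line prints [[[0, 1, 4, 5], [2, 3, 6, 7]]]. With real tensors the same mapping is usually written as a reshape followed by a transpose rather than explicit loops.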

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore import ops
>>> from mindflow.cell import ViT
>>> input_tensor = ops.rand(32, 3, 192, 384)
>>> print(input_tensor.shape)
(32, 3, 192, 384)
>>> model = ViT(in_channels=3,
...             out_channels=3,
...             encoder_depths=6,
...             encoder_embed_dim=768,
...             encoder_num_heads=12,
...             decoder_depths=6,
...             decoder_embed_dim=512,
...             decoder_num_heads=16,
...             )
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
(32, 288, 768)