mindflow.cell.UNet2D

class mindflow.cell.UNet2D(in_channels, out_channels, base_channels, n_layers=4, data_format='NHWC', kernel_size=2, stride=2, activation='relu', enable_bn=True)

The two-dimensional U-Net model. U-Net is a U-shaped convolutional neural network originally proposed for biomedical image segmentation. It consists of a contracting path that captures context and a symmetric expansive path that enables precise localization, with skip connections passing features from each contracting stage to the matching expansive stage. For details, see "U-Net: Convolutional Networks for Biomedical Image Segmentation" (https://arxiv.org/abs/1505.04597).
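The following sketch traces feature-map sizes through the contracting and expansive paths to make the U shape concrete. It assumes the conventional U-Net scheme from the paper: each downsampling stage halves height and width and doubles the channel count, and each upsampling stage reverses this. `trace_unet_shapes` is a hypothetical helper, not part of MindFlow, and the exact widths used inside `UNet2D` may differ.

```python
from typing import List, Tuple


def trace_unet_shapes(resolution: int, base_channels: int,
                      n_layers: int) -> List[Tuple[str, int, int]]:
    """Return (stage, spatial_size, channels) for a conventional U-Net.

    Assumption: every downsampling stage halves the spatial size and doubles
    the channels, and every upsampling stage does the opposite. This mirrors
    the original U-Net paper, not necessarily MindFlow's implementation.
    """
    stages = []
    size, channels = resolution, base_channels
    stages.append(("input block", size, channels))
    # Contracting path: capture context at progressively coarser resolution.
    for i in range(n_layers):
        size //= 2
        channels *= 2
        stages.append((f"down {i + 1}", size, channels))
    # Expansive path: recover resolution. In a real U-Net each stage would also
    # concatenate the matching contracting-path features (skip connection).
    for i in range(n_layers):
        size *= 2
        channels //= 2
        stages.append((f"up {i + 1}", size, channels))
    return stages


for stage, size, channels in trace_unet_shapes(resolution=128, base_channels=3, n_layers=4):
    print(f"{stage:>11}: {size}x{size}, {channels} channels")
```

Under these assumptions, with base_channels=3 and n_layers=4, a 128x128 input bottoms out at 8x8 with 48 channels before being expanded back to 128x128.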

Parameters
  • in_channels (int) – The number of input channels.

  • out_channels (int) – The number of output channels.

  • base_channels (int) – The number of base channels of UNet2D, i.e. the channel width of the first (highest-resolution) stage of the network.

  • n_layers (int) – The number of downsampling and upsampling layers. Default: 4.

  • data_format (str) – The layout of the input data. Default: 'NHWC'.

  • kernel_size (int) – Specifies the height and width of the 2D convolution kernel. Default: 2.

  • stride (Union[int, tuple[int]]) – The stride of the kernel: either an int, which uses the same stride for height and width, or a tuple of two ints giving the height and width strides respectively. Default: 2.

  • activation (Union[str, class]) – The activation function, given either as a string name or as a class. Default: 'relu'.

  • enable_bn (bool) – Specifies whether to use batch normalization in the convolution blocks. Default: True.
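The `stride` and `n_layers` parameters together determine how much the input is shrunk. The sketch below uses hypothetical helpers that are not part of MindFlow: `normalize_stride` expands an int `stride` into a (height, width) pair as described above, and `check_resolution` verifies that the input size survives `n_layers` downsampling steps without remainder. It assumes each stage divides the spatial size by the stride, which is the usual requirement for skip connections to line up; `UNet2D` itself may handle other sizes differently.

```python
from typing import Tuple, Union


def normalize_stride(stride: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Expand an int stride to (height, width); validate a tuple stride."""
    if isinstance(stride, int):
        return (stride, stride)
    if len(stride) != 2:
        raise ValueError(f"stride tuple must have 2 elements, got {len(stride)}")
    return (int(stride[0]), int(stride[1]))


def check_resolution(height: int, width: int,
                     stride: Union[int, Tuple[int, int]] = 2,
                     n_layers: int = 4) -> None:
    """Raise ValueError if the input cannot be downsampled n_layers times exactly.

    Assumption: each downsampling stage divides the spatial size by the
    stride, so height and width must be multiples of stride ** n_layers.
    """
    stride_h, stride_w = normalize_stride(stride)
    factor_h, factor_w = stride_h ** n_layers, stride_w ** n_layers
    if height % factor_h or width % factor_w:
        raise ValueError(
            f"input {height}x{width} is not divisible by "
            f"{factor_h}x{factor_w} (stride ** n_layers)")


print(normalize_stride(2))        # (2, 2)
print(normalize_stride((2, 1)))   # (2, 1)
check_resolution(128, 128)        # defaults: 128 is a multiple of 2 ** 4 = 16
```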

Inputs:
  • x (Tensor) - Tensor of shape \((batch\_size, resolution, resolution, in\_channels)\) for the default 'NHWC' data format.

Outputs:
  • output (Tensor) - Tensor of shape \((batch\_size, resolution, resolution, out\_channels)\) for the default 'NHWC' data format.
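Because UNet2D maps an image to an image, only the channel dimension of the shape changes between input and output. The sketch below states that contract as a hypothetical helper, `unet2d_output_shape`, which is not part of MindFlow. It assumes channels are the last axis for 'NHWC' and axis 1 for an assumed 'NCHW' alternative.

```python
from typing import Tuple


def unet2d_output_shape(input_shape: Tuple[int, int, int, int],
                        in_channels: int,
                        out_channels: int,
                        data_format: str = "NHWC") -> Tuple[int, ...]:
    """Return the expected output shape for a given input shape.

    Assumption: UNet2D preserves the batch and spatial dimensions and only
    replaces the channel dimension (last axis for "NHWC", axis 1 for "NCHW").
    """
    if len(input_shape) != 4:
        raise ValueError(f"expected a 4-D input, got shape {input_shape}")
    channel_axis = {"NHWC": 3, "NCHW": 1}.get(data_format)
    if channel_axis is None:
        raise ValueError(f"unsupported data_format {data_format!r}")
    if input_shape[channel_axis] != in_channels:
        raise ValueError(
            f"input has {input_shape[channel_axis]} channels, "
            f"expected in_channels={in_channels}")
    output_shape = list(input_shape)
    output_shape[channel_axis] = out_channels  # only the channel count changes
    return tuple(output_shape)


print(unet2d_output_shape((2, 128, 128, 3), in_channels=3, out_channels=1))
# (2, 128, 128, 1)
```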

Supported Platforms:

Ascend, GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> import mindspore.common.dtype as mstype
>>> from mindflow.cell import UNet2D
>>> ms.set_context(mode=ms.GRAPH_MODE, save_graphs=False, device_target="GPU")
>>> # NHWC input: a batch of 2 images, 128x128 pixels, 3 channels
>>> x = Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
>>> unet = UNet2D(in_channels=3, out_channels=3, base_channels=3)
>>> output = unet(x)
>>> print(output.shape)
(2, 128, 128, 3)