mindearth.cell.DgmrDiscriminator

查看源文件
class mindearth.cell.DgmrDiscriminator(in_channels=1, num_spatial_frames=8, conv_type='standard')[源代码]

Dgmr判别器基于时间判别器和空间判别器,其中包含深度残差块。 有关更多详细信息,请参考论文 Skilful precipitation nowcasting using deep generative models of radar

参数:
  • in_channels (int) - 输入中的通道数。默认值: 1

  • num_spatial_frames (int) - 待进行空间判别的时间步数。默认值: 8

  • conv_type (str) - 卷积核类型。默认值: standard

输入:
  • x (Tensor) - shape为 \((2, frames\_size, channels, height\_size, width\_size)\) 的Tensor。

输出:

Tensor,Dgmr Discriminator网络的输出。

  • output (Tensor) - shape为 \((2, 2, 1)\) 的Tensor。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, Tensor
>>> from mindspore.nn import Cell
>>> from mindearth.cell.dgmr.dgmrnet import DgmrDiscriminator
>>> real_and_generator = np.random.rand(2, 22, 1, 256, 256).astype(np.float32)
>>> net = DgmrDiscriminator(in_channels=1, num_spatial_frames=8, con_type="standard")
>>> out = net(Tensor(real_and_generator, ms.float32))
>>> print(out.shape)
(2, 2, 1)