mindearth.cell.DgmrDiscriminator
- class mindearth.cell.DgmrDiscriminator(in_channels=1, num_spatial_frames=8, conv_type='standard')[源代码]
Dgmr判别器基于时间判别器和空间判别器,其中包含深度残差块。 有关更多详细信息,请参考论文 Skilful precipitation nowcasting using deep generative models of radar 。
- 参数:
in_channels (int) - 输入中的通道数。默认值:
1
。num_spatial_frames (int) - 待进行空间判别的时间步数。默认值:
8
。conv_type (str) - 卷积核类型。默认值:
standard
。
- 输入:
x (Tensor) - shape为 \((2, frames\_size, channels, height\_size, width\_size)\) 的Tensor。
- 输出:
Tensor,Dgmr Discriminator网络的输出。
output (Tensor) - shape为 \((2, 2, 1)\) 的Tensor。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import ops, Tensor >>> from mindspore.nn import Cell >>> from mindearth.cell.dgmr.dgmrnet import DgmrDiscriminator >>> real_and_generator = np.random.rand(2, 22, 1, 256, 256).astype(np.float32) >>> net = DgmrDiscriminator(in_channels=1, num_spatial_frames=8, con_type="standard") >>> out = net(Tensor(real_and_generator, ms.float32)) >>> print(out.shape) (2, 2, 1)