mindearth.cell.DgmrGenerator
- class mindearth.cell.DgmrGenerator(forecast_steps=18, in_channels=1, out_channels=256, conv_type='standard', latent_channels=768, context_channels=384, generation_steps=1)[源代码]
Dgmr 生成器基于 Conditional_Stack、Latent_Stack、Upsample_Stack 和 ConvGRU,其中包含深度残差块。 有关更多详细信息,请参考论文 Skilful precipitation nowcasting using deep generative models of radar 。
- 参数:
forecast_steps (int) - 待预测的步数。默认值:
18
。in_channels (int) - 输入张量的通道数。默认值:
1
。out_channels (int) - 输出张量的通道数。默认值:
256
。conv_type (str) - 卷积核类型。默认值:
standard
。latent_channels (int) - 网络隐变量的通道数。默认值:
768
。context_channels (int) - 上下文信息的通道数。默认值:
384
。generation_steps (int) - 前向生成的样本数目。默认值:
1
。
- 输入:
x (Tensor) - shape为 \((batch\_size, input\_frames, channels, height\_size, width\_size)\) 的Tensor。
- 输出:
Tensor,Dgmr Generator网络的输出。
output (Tensor) - shape为 \((batch\_size, output\_frames, channels, height\_size, width\_size)\) 的Tensor。
- 支持平台:
Ascend
GPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import ops, Tensor >>> from mindspore.nn import Cell >>> from mindearth.cell.dgmr.dgmrnet import DgmrGenerator >>> input_frames = np.random.rand(1, 4, 1, 256, 256).astype(np.float32) >>> net = DgmrGenerator( >>> forecast_steps = 18, >>> in_channels = 1, >>> out_channels = 256, >>> conv_type = "standard", >>> latent_channels = 768, >>> context_channels = 384, >>> generation_steps = 1 >>> ) >>> out = net(Tensor(input_frames, ms.float32)) >>> print(out.shape) (1, 18, 1, 256, 256)