mindearth.cell.DgmrGenerator

查看源文件
class mindearth.cell.DgmrGenerator(forecast_steps=18, in_channels=1, out_channels=256, conv_type='standard', latent_channels=768, context_channels=384, generation_steps=1)[源代码]

Dgmr 生成器基于 Conditional_Stack、Latent_Stack、Upsample_Stack 和 ConvGRU,其中包含深度残差块。 有关更多详细信息,请参考论文 Skilful precipitation nowcasting using deep generative models of radar

参数:
  • forecast_steps (int) - 待预测的步数。默认值: 18

  • in_channels (int) - 输入张量的通道数。默认值: 1

  • out_channels (int) - 输出张量的通道数。默认值: 256

  • conv_type (str) - 卷积核类型。默认值: standard

  • latent_channels (int) - 网络隐变量的通道数。默认值: 768

  • context_channels (int) - 上下文信息的通道数。默认值: 384

  • generation_steps (int) - 前向生成的样本数目。默认值: 1

输入:
  • x (Tensor) - shape为 \((batch\_size, input\_frames, channels, height\_size, width\_size)\) 的Tensor。

输出:

Tensor,Dgmr Generator网络的输出。

  • output (Tensor) - shape为 \((batch\_size, output\_frames, channels, height\_size, width\_size)\) 的Tensor。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, Tensor
>>> from mindspore.nn import Cell
>>> from mindearth.cell.dgmr.dgmrnet import DgmrGenerator
>>> input_frames = np.random.rand(1, 4, 1, 256, 256).astype(np.float32)
>>> net = DgmrGenerator(
>>>         forecast_steps = 18,
>>>         in_channels = 1,
>>>         out_channels = 256,
>>>         conv_type = "standard",
>>>         latent_channels = 768,
>>>         context_channels = 384,
>>>         generation_steps = 1
>>>     )
>>> out = net(Tensor(input_frames, ms.float32))
>>> print(out.shape)
(1, 18, 1, 256, 256)