mindflow.loss.WaveletTransformLoss

class mindflow.loss.WaveletTransformLoss(wave_level=2, regroup=False)

The multi-level wavelet transformation losses.

Parameters
  • wave_level (int) – The number of wavelet transformation levels. Must be a positive integer. Default: 2.

  • regroup (bool) – Whether to use the regrouped error combination form of the wavelet transformation losses. Default: False.

Inputs:
  • input - Tuple of Tensors of shape \((B*H*W/(P*P), P*P*C)\), where B denotes the batch size, H and W denote the height and width of the image, respectively, P denotes the patch size, and C denotes the number of feature channels.
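
The shape formula above can be illustrated with a minimal NumPy sketch that rearranges a batch of images into rows of flattened patches. The helper name `to_patches`, the channels-last `(B, H, W, C)` layout, and the patch ordering are assumptions made for illustration only; this page does not specify how MindFlow builds these tensors.

```python
import numpy as np

def to_patches(images, patch_size):
    """Rearrange (B, H, W, C) images into rows of flattened P x P patches.

    Hypothetical helper: the layout and ordering are assumptions,
    not MindFlow's documented preprocessing.
    """
    b, h, w, c = images.shape
    p = patch_size
    if h % p != 0 or w % p != 0:
        raise ValueError("H and W must be divisible by the patch size P")
    # Split H and W into (number of patches, patch size) pairs.
    x = images.reshape(b, h // p, p, w // p, p, c)
    # Put the two patch-grid axes first, then the two within-patch axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # One row per patch, with P*P*C features per row.
    return x.reshape(b * (h // p) * (w // p), p * p * c)

images = np.zeros((2, 32, 64, 3), dtype=np.float32)  # B=2, H=32, W=64, C=3
patches = to_patches(images, patch_size=16)
print(patches.shape)  # (16, 768), i.e. (B*H*W/(P*P), P*P*C)
```

With B=2, H=32, W=64, P=16 and C=3, the first dimension is 2*32*64/(16*16) = 16 and the second is 16*16*3 = 768.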

Outputs:

Tensor.

Raises
Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindflow.loss import WaveletTransformLoss
>>> import mindspore
>>> from mindspore import Tensor
>>> net = WaveletTransformLoss(wave_level=2)
>>> input1 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> input2 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> output = net((input1, input2))
>>> print(output)
2.0794415