mindflow.loss.WaveletTransformLoss
- class mindflow.loss.WaveletTransformLoss(wave_level=2, regroup=False)
The multi-level wavelet transformation losses.
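- Parameters
wave_level (int) - The number of wavelet transformation levels; should be a positive integer. Default: 2.
regroup (bool) - Whether to use the regrouped form of error combination for the wavelet transformation losses. Default: False.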
- Inputs:
input - tuple of Tensors. Tensors of shape \((B*H*W/(P*P), P*P*C)\), where B denotes the batch size, H and W denote the height and width of the image, respectively, P denotes the patch size, and C denotes the number of feature channels.
- Outputs:
Tensor.
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> from mindflow.loss import WaveletTransformLoss
>>> import mindspore
>>> from mindspore import Tensor
>>> net = WaveletTransformLoss(wave_level=2)
>>> input1 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> input2 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> output = net((input1, input2))
>>> print(output)
2.0794415
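The input shape \((B*H*W/(P*P), P*P*C)\) described under Inputs can be easier to picture with a small NumPy sketch. The helper `patchify` below is hypothetical and is not part of mindflow; it assumes a channels-last image batch of shape (B, H, W, C) and a row-major ordering of patches, which may differ from the layout a particular model produces.

```python
import numpy as np


def patchify(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Rearrange (B, H, W, C) images into (B*H*W/(P*P), P*P*C) patch rows.

    Hypothetical helper for illustration only; the patch ordering is an assumption.
    """
    b, h, w, c = images.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    # Split H and W into (number of patches, patch size) pairs.
    x = images.reshape(b, h // p, p, w // p, p, c)
    # Group the patch-grid axes first, then the within-patch axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # One row per patch, with all pixels and channels of that patch flattened.
    return x.reshape(b * (h // p) * (w // p), p * p * c)


images = np.zeros((2, 32, 64, 3), dtype=np.float32)  # B=2, H=32, W=64, C=3
patches = patchify(images, patch_size=16)            # P=16
print(patches.shape)  # (16, 768)
```

Here B*H*W/(P*P) = 2*32*64/256 = 16 rows and P*P*C = 16*16*3 = 768 columns, so each row holds one flattened patch.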