mindspore.nn.ResizeBilinear
- class mindspore.nn.ResizeBilinear
Resizes the input tensor to the given size, or by the given scale_factor, using bilinear interpolation.
- Inputs:
x (Tensor) - The tensor to be resized. It must be a 4-D tensor with shape \((batch, channels, height, width)\) and data type float16 or float32.
size (Union[tuple[int], list[int]]) - A tuple or list of 2 ints \((new\_height, new\_width)\) giving the new size of the tensor. Exactly one of size and scale_factor must be None. Default: None.
scale_factor (int) - The factor by which the height and width of the tensor are scaled. The value should be a positive integer. Exactly one of size and scale_factor must be None. Default: None.
align_corners (bool) - If True, the input is rescaled by \((new\_height - 1) / (height - 1)\), which exactly aligns the four corners of the input and resized images. If False, it is rescaled by \(new\_height / height\). The width is handled in the same way. Default: False.
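The two align_corners modes differ only in how an output index is mapped back to a fractional input coordinate. The helper below is a minimal sketch of that inverse mapping along one axis, assuming the standard "output index times inverse scale" formulation; `source_coord` is a hypothetical name, not a MindSpore API, and the guard for `out_size == 1` is an assumption added to avoid division by zero.

```python
def source_coord(dst: int, in_size: int, out_size: int, align_corners: bool) -> float:
    """Map output index `dst` along one axis to a fractional input coordinate.

    Hypothetical helper for illustration; not part of MindSpore.
    """
    if align_corners:
        # Inverse of (out_size - 1) / (in_size - 1): the first and last
        # samples of input and output coincide exactly.
        if out_size == 1:
            return 0.0  # assumption: avoid dividing by zero for a 1-pixel output
        return dst * (in_size - 1) / (out_size - 1)
    # Inverse of out_size / in_size.
    return dst * in_size / out_size


if __name__ == "__main__":
    # Height 2 resized to 5, as in the example at the end of this page.
    print([source_coord(i, 2, 5, False) for i in range(5)])  # [0.0, 0.4, 0.8, 1.2, 1.6]
    print([source_coord(i, 2, 5, True) for i in range(5)])   # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With align_corners=False the last output rows map past the last input row (1.2 and 1.6 for a height of 2), so an implementation has to clamp them; with align_corners=True the last output row maps exactly onto the last input row.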
- Outputs:
Resized tensor. If size is set, the result is a 4-D tensor with shape \((batch, channels, new\_height, new\_width)\) and the same data type as x. If scale_factor is set, the result is a 4-D tensor with shape \((batch, channels, scale\_factor * height, scale\_factor * width)\) and the same data type as x.
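The shape rule in Outputs can be written as a small function. This is a sketch that assumes the arguments have already been validated (exactly one of size and scale_factor is set); `output_shape` is a hypothetical helper, not part of MindSpore.

```python
from typing import Optional, Sequence, Tuple


def output_shape(x_shape: Sequence[int],
                 size: Optional[Sequence[int]] = None,
                 scale_factor: Optional[int] = None) -> Tuple[int, int, int, int]:
    """Return the output shape for a (batch, channels, height, width) input.

    Hypothetical helper; assumes exactly one of size / scale_factor is set.
    """
    batch, channels, height, width = x_shape
    if size is not None:
        new_height, new_width = size
        return (batch, channels, new_height, new_width)
    # scale_factor multiplies both spatial dimensions.
    return (batch, channels, scale_factor * height, scale_factor * width)


if __name__ == "__main__":
    print(output_shape((1, 1, 2, 4), size=(5, 5)))       # (1, 1, 5, 5)
    print(output_shape((1, 1, 2, 4), scale_factor=2))    # (1, 1, 4, 8)
```

Batch and channel dimensions are never changed; only the last two dimensions depend on size or scale_factor.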
- Raises:
TypeError – If size is not one of tuple, list, None.
TypeError – If scale_factor is neither int nor None.
TypeError – If align_corners is not a bool.
TypeError – If dtype of x is neither float16 nor float32.
ValueError – If size and scale_factor are both None or not None.
ValueError – If length of shape of x is not equal to 4.
ValueError – If scale_factor is an int which is less than 0.
ValueError – If size is a list or tuple whose length is not equal to 2.
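The conditions in the Raises list can be read as a sequence of argument checks. The function below is a hypothetical sketch that mirrors the list as written (including the "less than 0" bound for scale_factor); it is not MindSpore's validation code, and the order in which MindSpore performs these checks is not documented here. The dtype is passed as a plain string for illustration.

```python
from typing import Sequence


def check_resize_args(x_shape: Sequence[int], dtype: str,
                      size=None, scale_factor=None, align_corners=False) -> None:
    """Hypothetical checks mirroring the documented Raises list."""
    # TypeError cases
    if size is not None and not isinstance(size, (tuple, list)):
        raise TypeError("size must be a tuple, a list or None")
    if scale_factor is not None and not isinstance(scale_factor, int):
        raise TypeError("scale_factor must be an int or None")
    if not isinstance(align_corners, bool):
        raise TypeError("align_corners must be a bool")
    if dtype not in ("float16", "float32"):
        raise TypeError("dtype of x must be float16 or float32")
    # ValueError cases
    if (size is None) == (scale_factor is None):
        raise ValueError("exactly one of size and scale_factor must be None")
    if len(x_shape) != 4:
        raise ValueError("x must be a 4-D tensor")
    if scale_factor is not None and scale_factor < 0:
        raise ValueError("scale_factor must not be less than 0")
    if size is not None and len(size) != 2:
        raise ValueError("size must contain exactly 2 elements")
```

For example, `check_resize_args((1, 1, 2, 4), "float32", size=(5, 5))` returns without raising, while omitting both size and scale_factor raises ValueError.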
- Supported Platforms:
Ascend
CPU
GPU
- Examples:
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
>>> resize_bilinear = nn.ResizeBilinear()
>>> result = resize_bilinear(x, size=(5, 5))
>>> print(x)
[[[[1. 2. 3. 4.]
   [5. 6. 7. 8.]]]]
>>> print(result)
[[[[1.        1.8       2.6       3.4       4.       ]
   [2.6       3.4       4.2000003 5.        5.6000004]
   [4.2       5.0000005 5.8       6.6       7.2      ]
   [5.        5.8       6.6       7.4       8.       ]
   [5.        5.8       6.6       7.4000006 8.       ]]]]
>>> print(result.shape)
(1, 1, 5, 5)
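The printed values can be reproduced with a pure-Python sketch of bilinear interpolation on a single (height, width) plane. It assumes the align_corners=False mapping \(src = dst \cdot height / new\_height\) (and likewise for width), with source indices clamped to the last row and column. `resize_bilinear_2d` and `_coord` are hypothetical helpers, not MindSpore's kernel, and the align_corners=True branch has not been checked against MindSpore's own output.

```python
import math
from typing import List


def _coord(dst: int, in_size: int, out_size: int, align_corners: bool) -> float:
    """Map an output index to a fractional input coordinate (hypothetical helper)."""
    if align_corners:
        return 0.0 if out_size == 1 else dst * (in_size - 1) / (out_size - 1)
    return dst * in_size / out_size


def resize_bilinear_2d(plane: List[List[float]], new_h: int, new_w: int,
                       align_corners: bool = False) -> List[List[float]]:
    """Bilinearly resize one 2-D plane (a sketch, not MindSpore's implementation)."""
    h, w = len(plane), len(plane[0])
    out = []
    for i in range(new_h):
        y = _coord(i, h, new_h, align_corners)
        y0 = min(math.floor(y), h - 1)  # top neighbour, clamped to the last row
        y1 = min(y0 + 1, h - 1)         # bottom neighbour, clamped as well
        dy = y - y0
        row = []
        for j in range(new_w):
            x = _coord(j, w, new_w, align_corners)
            x0 = min(math.floor(x), w - 1)
            x1 = min(x0 + 1, w - 1)
            dx = x - x0
            # Interpolate horizontally on both rows, then vertically between them.
            top = plane[y0][x0] * (1 - dx) + plane[y0][x1] * dx
            bottom = plane[y1][x0] * (1 - dx) + plane[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


if __name__ == "__main__":
    for r in resize_bilinear_2d([[1., 2., 3., 4.], [5., 6., 7., 8.]], 5, 5):
        print([round(v, 4) for v in r])
```

For the 2 x 4 input above, the rows agree with the printed result up to floating-point rounding (for example 4.2 instead of the float32 value 4.2000003). The last two output rows are identical because their source coordinates, 1.2 and 1.6, both clamp to the last input row.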