mindspore.nn.ResizeBilinear

class mindspore.nn.ResizeBilinear[source]

Resizes the input tensor to the given size, or by the given scale_factor, using bilinear interpolation.

Inputs:
  • x (Tensor) - Tensor to be resized. Must be a 4-D tensor of shape (batch, channels, height, width) with data type float16 or float32.

  • size (Union[tuple[int], list[int]]): A tuple or list of 2 int elements ‘(new_height, new_width)’ giving the new size of the tensor. Exactly one of size and scale_factor must be None. Default: None.

  • scale_factor (int): The factor by which the height and width of the tensor are scaled. Must be a positive integer. Exactly one of size and scale_factor must be None. Default: None.

  • align_corners (bool): If True, rescale input by ‘(new_height - 1) / (height - 1)’, which exactly aligns the 4 corners of the input and resized images. If False, rescale by ‘new_height / height’. Default: False.
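
The effect of align_corners on the sampling grid can be sketched in plain Python. This is a minimal illustration for a single 2-D plane, not MindSpore's implementation: the helper name resize_bilinear_2d is hypothetical, and it assumes that the source coordinate of an output pixel is its index divided by the rescale factor described above, with neighbours clamped at the border.

```python
def resize_bilinear_2d(plane, new_height, new_width, align_corners=False):
    """Resize one 2-D plane (a list of rows) with bilinear interpolation."""
    height, width = len(plane), len(plane[0])

    def step(in_size, out_size):
        # Source-coordinate step per output pixel (inverse of the rescale factor).
        if align_corners and out_size > 1:
            return (in_size - 1) / (out_size - 1)
        return in_size / out_size

    h_step, w_step = step(height, new_height), step(width, new_width)
    out = []
    for i in range(new_height):
        y = i * h_step
        y0 = min(int(y), height - 1)   # upper neighbour row
        y1 = min(y0 + 1, height - 1)   # lower neighbour row, clamped
        dy = y - y0
        row = []
        for j in range(new_width):
            x = j * w_step
            x0 = min(int(x), width - 1)   # left neighbour column
            x1 = min(x0 + 1, width - 1)   # right neighbour column, clamped
            dx = x - x0
            top = plane[y0][x0] * (1 - dx) + plane[y0][x1] * dx
            bottom = plane[y1][x0] * (1 - dx) + plane[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


plane = [[1, 2, 3, 4], [5, 6, 7, 8]]
for row in resize_bilinear_2d(plane, 5, 5):
    print(row)
```

Under these assumptions, the align_corners=False call on the 2 x 4 plane reproduces the values printed in the Examples section, up to floating-point rounding.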

Outputs:

Resized tensor. If size is set, the result is a 4-D tensor of shape (batch, channels, new_height, new_width) in float32. If scale_factor is set, the result is a 4-D tensor of shape (batch, channels, scale_factor * height, scale_factor * width) in float32.
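
The output shape rule can be written as a small helper. This is only a sketch for illustration; resized_shape is a hypothetical name, not part of the MindSpore API.

```python
def resized_shape(x_shape, size=None, scale_factor=None):
    """Return the output shape for a 4-D input shape (hypothetical helper)."""
    batch, channels, height, width = x_shape
    if size is not None:
        new_height, new_width = size
    else:
        new_height, new_width = scale_factor * height, scale_factor * width
    return (batch, channels, new_height, new_width)


print(resized_shape((1, 1, 2, 4), size=(5, 5)))
print(resized_shape((1, 1, 2, 4), scale_factor=2))
```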

Raises
  • TypeError – If size is not one of tuple, list, None.

  • TypeError – If scale_factor is neither int nor None.

  • TypeError – If align_corners is not a bool.

  • TypeError – If dtype of x is neither float16 nor float32.

  • ValueError – If size and scale_factor are both None or both not None.

  • ValueError – If length of shape of x is not equal to 4.

  • ValueError – If scale_factor is an int which is less than 1.

  • ValueError – If size is a list or tuple whose length is not equal to 2.
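
The conditions above can be mirrored by a small argument check. This is a hedged sketch, not the library's validation code: check_resize_args is a hypothetical helper, the order of the checks is an assumption, and the dtype check on x is omitted because the sketch only looks at the shape.

```python
def check_resize_args(x_shape, size=None, scale_factor=None, align_corners=False):
    """Hypothetical validator mirroring the documented error conditions."""
    if size is not None and not isinstance(size, (tuple, list)):
        raise TypeError("size must be a tuple, a list, or None")
    # Assumption: bool is rejected even though it subclasses int in Python.
    if scale_factor is not None and (
        isinstance(scale_factor, bool) or not isinstance(scale_factor, int)
    ):
        raise TypeError("scale_factor must be an int or None")
    if not isinstance(align_corners, bool):
        raise TypeError("align_corners must be a bool")
    # Exactly one of size and scale_factor has to be given.
    if (size is None) == (scale_factor is None):
        raise ValueError("exactly one of size and scale_factor must be set")
    if len(x_shape) != 4:
        raise ValueError("x must be a 4-D tensor")
    if scale_factor is not None and scale_factor < 1:
        raise ValueError("scale_factor must be at least 1")
    if size is not None and len(size) != 2:
        raise ValueError("size must have exactly 2 elements")


check_resize_args((1, 1, 2, 4), size=(5, 5))    # passes silently
check_resize_args((1, 1, 2, 4), scale_factor=2)  # passes silently
```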

Supported Platforms:

Ascend CPU

Examples

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
>>> resize_bilinear = nn.ResizeBilinear()
>>> result = resize_bilinear(x, size=(5, 5))
>>> print(x)
[[[[1. 2. 3. 4.]
   [5. 6. 7. 8.]]]]
>>> print(result)
[[[[1.        1.8       2.6       3.4       4.       ]
   [2.6       3.4       4.2000003 5.        5.6000004]
   [4.2       5.0000005 5.8       6.6       7.2      ]
   [5.        5.8       6.6       7.4       8.       ]
   [5.        5.8       6.6       7.4000006 8.       ]]]]
>>> print(result.shape)
(1, 1, 5, 5)