mindspore.nn.ResizeBilinear

class mindspore.nn.ResizeBilinear

Resizes the input tensor to the given size, or by the given scale_factor, using bilinear interpolation.

Inputs:
  • x (Tensor) - The tensor to be resized. It must be a 4-D tensor of shape (batch, channels, height, width) with data type float16 or float32.

  • size (Union[tuple[int], list[int]]) - A tuple or list of 2 int elements (new_height, new_width) giving the new spatial size of the tensor. Exactly one of size and scale_factor must be set; the other must be None. Default: None.

  • scale_factor (int) - The factor by which both height and width are multiplied. It must be a positive integer. Exactly one of size and scale_factor must be set; the other must be None. Default: None.

  • align_corners (bool) - If True, rescale the input by (new_height - 1) / (height - 1), which exactly aligns the 4 corners of the input and resized images. If False, rescale by new_height / height. Default: False.
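To make the two align_corners modes concrete, the plain-Python sketch below computes, for each output row index, the fractional input row it samples from. It assumes that align_corners=False maps output index i to i * height / new_height and that align_corners=True maps it to i * (height - 1) / (new_height - 1). The helper name source_coords is hypothetical and is not part of the MindSpore API.

```python
def source_coords(out_len, in_len, align_corners):
    """Hypothetical helper: fractional input coordinate sampled by each output index."""
    if align_corners:
        # First and last output indices land exactly on the first and last input indices;
        # a length-1 output simply samples index 0.
        scale = (in_len - 1) / (out_len - 1) if out_len > 1 else 0.0
    else:
        # Plain ratio of the two sizes.
        scale = in_len / out_len
    return [i * scale for i in range(out_len)]


# A height of 2 resized to 5, as in the Examples section.
print([round(c, 4) for c in source_coords(5, 2, align_corners=False)])  # [0.0, 0.4, 0.8, 1.2, 1.6]
print(source_coords(5, 2, align_corners=True))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With align_corners=False, the coordinates 1.2 and 1.6 lie past the last input row (index 1), so they have no lower neighbour to blend with. Clamping them to that row is consistent with the last two rows of the example output both equalling the input's bottom row.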

Outputs:

Resized tensor with the same data type as x. If size is set, the result is a 4-D tensor of shape (batch, channels, new_height, new_width). If scale_factor is set, the result is a 4-D tensor of shape (batch, channels, scale_factor * height, scale_factor * width).
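The shape rule above can be sketched as a small pure-Python function. The name resized_shape is hypothetical and only restates the documented rule; it assumes the arguments have already been validated.

```python
def resized_shape(x_shape, size=None, scale_factor=None):
    """Hypothetical helper: output shape implied by the Outputs description."""
    batch, channels, height, width = x_shape
    if size is not None:
        # size gives the new spatial dimensions directly.
        new_height, new_width = size
        return (batch, channels, new_height, new_width)
    # Otherwise scale_factor multiplies both spatial dimensions.
    return (batch, channels, scale_factor * height, scale_factor * width)


print(resized_shape((1, 1, 2, 4), size=(5, 5)))     # (1, 1, 5, 5)
print(resized_shape((1, 1, 2, 4), scale_factor=2))  # (1, 1, 4, 8)
```

Batch and channel dimensions are never touched; only height and width change.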

Raises
  • TypeError – If size is not one of tuple, list, None.

  • TypeError – If scale_factor is neither int nor None.

  • TypeError – If align_corners is not a bool.

  • TypeError – If dtype of x is neither float16 nor float32.

  • ValueError – If size and scale_factor are both None or not None.

  • ValueError – If length of shape of x is not equal to 4.

  • ValueError – If scale_factor is an int which is less than 0.

  • ValueError – If size is a list or tuple whose length is not equal to 2.
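The error conditions above can be summarized as a hypothetical argument check, sketched below in plain Python. This is not MindSpore's implementation: the function name, the order of the checks, and the representation of the dtype as a string are all assumptions. It also treats 0 as an invalid scale_factor, following the "positive integer" wording of the scale_factor description.

```python
def check_resize_args(x_shape, x_dtype, size=None, scale_factor=None, align_corners=False):
    """Hypothetical validator mirroring the Raises list; not MindSpore's code."""
    # Type checks.
    if size is not None and not isinstance(size, (tuple, list)):
        raise TypeError("size must be a tuple, a list or None")
    # bool is a subclass of int in Python, so reject it explicitly.
    if scale_factor is not None and (isinstance(scale_factor, bool) or not isinstance(scale_factor, int)):
        raise TypeError("scale_factor must be an int or None")
    if not isinstance(align_corners, bool):
        raise TypeError("align_corners must be a bool")
    if x_dtype not in ("float16", "float32"):
        raise TypeError("x must have dtype float16 or float32")
    # Value checks.
    if (size is None) == (scale_factor is None):
        raise ValueError("exactly one of size and scale_factor must be set")
    if len(x_shape) != 4:
        raise ValueError("x must be a 4-D tensor")
    if scale_factor is not None and scale_factor <= 0:
        raise ValueError("scale_factor must be a positive int")
    if size is not None and len(size) != 2:
        raise ValueError("size must have exactly 2 elements")


def raised(**kwargs):
    """Return the exception type raised for the given arguments, or None."""
    args = {"x_shape": (1, 1, 2, 4), "x_dtype": "float32"}
    args.update(kwargs)
    try:
        check_resize_args(**args)
    except (TypeError, ValueError) as err:
        return type(err)
    return None


print(raised(size=(5, 5)))                   # None
print(raised())                              # <class 'ValueError'>
print(raised(scale_factor=2.0))              # <class 'TypeError'>
print(raised(size=(5, 5), scale_factor=2))   # <class 'ValueError'>
```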

Supported Platforms:

Ascend CPU GPU

Examples

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
>>> resize_bilinear = nn.ResizeBilinear()
>>> result = resize_bilinear(x, size=(5,5))
>>> print(x)
[[[[1. 2. 3. 4.]
   [5. 6. 7. 8.]]]]
>>> print(result)
[[[[1.        1.8       2.6       3.4       4.       ]
   [2.6       3.4       4.2000003 5.        5.6000004]
   [4.2       5.0000005 5.8       6.6       7.2      ]
   [5.        5.8       6.6       7.4       8.       ]
   [5.        5.8       6.6       7.4000006 8.       ]]]]
>>> print(result.shape)
(1, 1, 5, 5)
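For readers who want to check the numbers by hand, the NumPy sketch below is a reference implementation written for illustration, not MindSpore's kernel. It assumes the coordinate mapping described under align_corners (output index i samples input coordinate i * height / new_height when align_corners is False) and clamps the upper neighbour to the last row or column. Under those assumptions it reproduces the values printed above up to float32 rounding. NumPy and the function names source_coords and resize_bilinear_ref are assumptions of this sketch.

```python
import numpy as np


def source_coords(out_len, in_len, align_corners):
    """Fractional input coordinate sampled by each output index (assumed mapping)."""
    if align_corners:
        scale = (in_len - 1) / (out_len - 1) if out_len > 1 else 0.0
    else:
        scale = in_len / out_len
    return np.arange(out_len) * scale


def resize_bilinear_ref(x, new_height, new_width, align_corners=False):
    """Bilinear resize of an NCHW NumPy array; an illustrative sketch only."""
    _, _, height, width = x.shape
    ys = source_coords(new_height, height, align_corners)
    xs = source_coords(new_width, width, align_corners)

    # Lower neighbour index and upper neighbour index, both clamped to the last row/column.
    y0 = np.minimum(np.floor(ys).astype(np.int64), height - 1)
    x0 = np.minimum(np.floor(xs).astype(np.int64), width - 1)
    y1 = np.minimum(y0 + 1, height - 1)
    x1 = np.minimum(x0 + 1, width - 1)

    # Interpolation weights, shaped to broadcast over (new_height, new_width).
    wy = (ys - y0)[:, None]
    wx = (xs - x0)[None, :]

    # Fancy indexing gathers each neighbour as an (N, C, new_height, new_width) array.
    rows0, rows1 = y0[:, None], y1[:, None]
    cols0, cols1 = x0[None, :], x1[None, :]
    top = x[:, :, rows0, cols0] * (1 - wx) + x[:, :, rows0, cols1] * wx
    bottom = x[:, :, rows1, cols0] * (1 - wx) + x[:, :, rows1, cols1] * wx
    return (top * (1 - wy) + bottom * wy).astype(x.dtype)


x = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], dtype=np.float32)
result = resize_bilinear_ref(x, 5, 5)
print(result.shape)  # (1, 1, 5, 5)
print(result)
```

Interpolating along the width first and then along the height is one common way to write bilinear interpolation; the opposite order gives the same values up to rounding.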