mindspore.ops.grid_sample

mindspore.ops.grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros', align_corners=False)

Given an input input_x and a flow-field grid, computes the output by sampling input_x values at the pixel locations given by grid. Only spatial (4-D) and volumetric (5-D) input_x are supported.

In the spatial (4-D) case, for input_x with shape \((N, C, H_{in}, W_{in})\) and grid with shape \((N, H_{out}, W_{out}, 2)\), the output will have shape \((N, C, H_{out}, W_{out})\).

For each output location output[n, :, h, w], the size-2 vector grid[n, h, w] specifies the input_x pixel locations x and y, which are used to interpolate the output value output[n, :, h, w]. For 5-D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations used to interpolate output[n, :, d, h, w]. The interpolation_mode argument selects the method used to sample the input pixels: “nearest”, “bilinear” or “bicubic” (4-D case only).
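As a rough illustration of the two simplest modes, the sketch below samples a single (H, W) channel at one pixel-space location (ix, iy), that is, after the normalized grid value has been converted to pixel coordinates. Here ix runs along the width (columns) and iy along the height (rows), which is consistent with the example at the end of this page. This is a NumPy sketch modeled on PyTorch's grid_sample reference behaviour, not MindSpore's kernel: the helper name sample_point is hypothetical, out-of-bound neighbours read as 0 (as with padding_mode “zeros”), the tie-breaking rule for “nearest” (round half to even) is an assumption, and bicubic is omitted.

```python
import numpy as np


def sample_point(plane, ix, iy, mode="bilinear"):
    """Sample one (H, W) plane at pixel coordinates (ix, iy).

    Hypothetical helper: ix indexes columns (width), iy indexes rows (height).
    Neighbours outside the plane read as 0, mimicking padding_mode="zeros".
    """
    h, w = plane.shape

    def at(r, c):
        # Explicit bounds check so negative indices never wrap around
        return plane[r, c] if 0 <= r < h and 0 <= c < w else 0.0

    if mode == "nearest":
        # Round to the closest pixel; np.rint rounds halves to even (assumed tie rule)
        return at(int(np.rint(iy)), int(np.rint(ix)))
    if mode == "bilinear":
        x0, y0 = int(np.floor(ix)), int(np.floor(iy))
        fx, fy = ix - x0, iy - y0  # fractional offsets inside the 2x2 cell
        return ((1 - fx) * (1 - fy) * at(y0, x0) + fx * (1 - fy) * at(y0, x0 + 1)
                + (1 - fx) * fy * at(y0 + 1, x0) + fx * fy * at(y0 + 1, x0 + 1))
    raise ValueError(f"unsupported mode: {mode}")


plane = np.array([[0.0, 1.0], [2.0, 3.0]])
print(sample_point(plane, 0.6, 0.65, "bilinear"))  # approximately 1.9
print(sample_point(plane, 0.6, 0.65, "nearest"))   # 3.0
```

Bilinear mode blends the four surrounding pixels with weights given by the fractional offsets, while nearest mode simply picks the single closest pixel.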

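grid specifies the sampling pixel locations normalized by the input_x spatial dimensions, so most of its values should lie in the range \([-1, 1]\). For example, x = -1, y = -1 refers to the left-top pixel of input_x, and x = 1, y = 1 to the right-bottom pixel.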

If grid has values outside the range \([-1, 1]\), the corresponding outputs are handled as defined by padding_mode. If padding_mode is “zeros”, \(0\) is used for out-of-bound grid locations. If padding_mode is “border”, border values are used for out-of-bound grid locations. If padding_mode is “reflection”, values at locations reflected by the border are used for out-of-bound grid locations; locations far from the border keep being reflected until they fall in bound.
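The per-axis coordinate rules behind the three padding modes can be sketched in plain Python. This sketch assumes MindSpore follows the same rules as PyTorch's grid_sample, whose description this page closely mirrors; it operates on pixel-space coordinates (after conversion from \([-1, 1]\)), and the helper names reflect and apply_padding are hypothetical. The closed-form modulo in reflect gives the same in-bound location as reflecting repeatedly, because repeated reflection between two borders is periodic.

```python
def reflect(coord, lo, hi):
    """Reflect coord back into [lo, hi], as if mirrored repeatedly at both borders."""
    span = hi - lo
    if span <= 0:
        return lo
    # Repeated reflections repeat with period 2 * span
    d = abs(coord - lo) % (2 * span)
    return lo + (d if d <= span else 2 * span - d)


def apply_padding(coord, size, padding_mode, align_corners):
    """Bring one pixel-space coordinate along an axis of `size` pixels in bound."""
    if padding_mode == "border":
        # Clamp to the first/last pixel so border values are repeated
        return min(max(coord, 0.0), size - 1.0)
    if padding_mode == "reflection":
        if align_corners:
            coord = reflect(coord, 0.0, size - 1.0)   # mirror about the corner pixel centers
        else:
            coord = reflect(coord, -0.5, size - 0.5)  # mirror about the outer pixel edges
        return min(max(coord, 0.0), size - 1.0)
    # "zeros": coordinate left unchanged; interpolation taps outside the input read as 0
    return coord


for mode in ("border", "reflection"):
    print(mode, [apply_padding(c, 4, mode, align_corners=True) for c in (-1.0, 4.0, 7.0)])
# border [0.0, 3.0, 3.0]
# reflection [1.0, 2.0, 1.0]
```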

Parameters
  • input_x (Tensor) – input with shape of \((N, C, H_{in}, W_{in})\) (4-D case) or \((N, C, D_{in}, H_{in}, W_{in})\) (5-D case) and dtype of float32 or float64.

  • grid (Tensor) – flow-field with shape of \((N, H_{out}, W_{out}, 2)\) (4-D case) or \((N, D_{out}, H_{out}, W_{out}, 3)\) (5-D case) and same dtype as input_x.

  • interpolation_mode (str) – An optional string specifying the interpolation method. The optional values are “bilinear”, “nearest” or “bicubic”. Default: “bilinear”. Note: bicubic supports only 4-D input. When interpolation_mode is “bilinear” and the input is 5-D, the interpolation actually performed internally is trilinear; when the input is 4-D, it is genuinely bilinear.

  • padding_mode (str) – An optional string specifying the pad method. The optional values are “zeros”, “border” or “reflection”. Default: “zeros”.

  • align_corners (bool) – An optional bool. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. Default: False.
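The two align_corners conventions amount to two different maps from a normalized coordinate to a pixel index along an axis of size pixels. The sketch below assumes the same formulas as PyTorch's grid_sample; they are consistent with the example at the end of this page, where align_corners=True maps x = 0.2 on a width-2 input to pixel 0.6. The helper name unnormalize is hypothetical.

```python
def unnormalize(coord, size, align_corners):
    """Convert a normalized coordinate in [-1, 1] to a pixel index along an axis of `size` pixels."""
    if align_corners:
        # -1 and 1 hit the centers of the first and last pixels (indices 0 and size - 1)
        return (coord + 1) / 2 * (size - 1)
    # -1 and 1 hit the outer edges of the corner pixels (indices -0.5 and size - 0.5)
    return ((coord + 1) * size - 1) / 2


for align_corners in (True, False):
    print(align_corners, [unnormalize(c, 4, align_corners) for c in (-1.0, 0.0, 1.0)])
# True [0.0, 1.5, 3.0]
# False [-0.5, 1.5, 3.5]
```

With align_corners=False, a grid value of ±1 lands half a pixel beyond the outermost pixel centers, so how it is sampled also depends on padding_mode.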

Returns

Tensor with the same dtype as input_x, of shape \((N, C, H_{out}, W_{out})\) in the 4-D case or \((N, C, D_{out}, H_{out}, W_{out})\) in the 5-D case.

Raises
  • TypeError – If input_x or grid is not a Tensor.

  • TypeError – If the dtypes of input_x and grid are inconsistent.

  • TypeError – If the dtype of input_x or grid is not a valid type (float32 or float64).

  • TypeError – If align_corners is not a boolean value.

  • ValueError – If the rank of input_x or grid is not equal to 4 (4-D case) or 5 (5-D case).

  • ValueError – If the first dimension of input_x is not equal to that of grid.

  • ValueError – If the last dimension of grid is not equal to 2 (4-D case) or 3 (5-D case).

  • ValueError – If interpolation_mode is not a string, or is not one of “bilinear”, “nearest” or “bicubic”.

  • ValueError – If padding_mode is not a string, or is not one of “zeros”, “border” or “reflection”.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
>>> grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
>>> output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
...                          align_corners=True)
>>> print(output)
[[[[ 1.9      ]
   [ 2.1999998]]
  [[ 5.9      ]
   [ 6.2      ]]]
 [[[10.5      ]
   [10.8      ]]
  [[14.5      ]
   [14.8      ]]]]
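The printed values can be cross-checked against a plain NumPy reference for the 4-D, bilinear, zeros-padding case. This is a sketch under the assumption that MindSpore follows PyTorch's coordinate conventions (grid[..., 0] indexes the width, grid[..., 1] the height); it is not MindSpore's implementation, and the names unnormalize and grid_sample_bilinear_zeros are hypothetical. With align_corners=True it should reproduce the output above up to float32 rounding.

```python
import numpy as np


def unnormalize(coord, size, align_corners):
    # Map a normalized coordinate in [-1, 1] to a pixel-space index
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2


def grid_sample_bilinear_zeros(x, grid, align_corners=False):
    """Reference 4-D grid sampling with bilinear interpolation and "zeros" padding.

    x: (N, C, H_in, W_in); grid: (N, H_out, W_out, 2) holding (x, y) pairs.
    Returns an array of shape (N, C, H_out, W_out).
    """
    n_batch, channels, h_in, w_in = x.shape
    _, h_out, w_out, _ = grid.shape
    out = np.zeros((n_batch, channels, h_out, w_out), dtype=x.dtype)
    for n in range(n_batch):
        for h in range(h_out):
            for w in range(w_out):
                # grid[..., 0] runs along the width, grid[..., 1] along the height
                ix = unnormalize(float(grid[n, h, w, 0]), w_in, align_corners)
                iy = unnormalize(float(grid[n, h, w, 1]), h_in, align_corners)
                x0, y0 = int(np.floor(ix)), int(np.floor(iy))
                fx, fy = ix - x0, iy - y0
                # Weighted sum over the four neighbouring pixels, all channels at once
                for yi, wy in ((y0, 1 - fy), (y0 + 1, fy)):
                    for xi, wx in ((x0, 1 - fx), (x0 + 1, fx)):
                        # Out-of-bound neighbours contribute 0 ("zeros" padding)
                        if 0 <= yi < h_in and 0 <= xi < w_in:
                            out[n, :, h, w] += wy * wx * x[n, :, yi, xi]
    return out


input_x = np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32)
grid = np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32)
result = grid_sample_bilinear_zeros(input_x, grid, align_corners=True)
print(result)
```

For instance, output[0, 0, 0, 0] samples x = 0.2, y = 0.3, which map to pixel (0.6, 0.65); on the channel [[0, 1], [2, 3]] the bilinear weights reduce to 0.6 + 2 × 0.65 = 1.9.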