mindspore.numpy.take_along_axis

mindspore.numpy.take_along_axis(arr, indices, axis)

Takes values from the input array by matching 1d index and data slices.

This iterates over matching 1d slices oriented along the specified axis in the index and data arrays, and uses each index slice to look up values in the corresponding data slice. The index and data slices can have different lengths.
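The slice-matching rule can be sketched with explicit loops. The following is a minimal illustration that uses plain NumPy as a stand-in; it is not MindSpore's implementation. The helper name `take_along_axis_reference` is hypothetical, and the sketch assumes `indices` already has the same shape as `arr` on every dimension other than `axis` (no broadcasting).

```python
import numpy as np

def take_along_axis_reference(arr, indices, axis):
    """Hypothetical reference: loop over matching 1d slices along `axis`."""
    arr = np.asarray(arr)
    indices = np.asarray(indices)
    if arr.ndim != indices.ndim:
        raise ValueError("arr and indices must have the same number of dimensions")
    # Move the working axis last so that every 1d slice is a trailing row.
    arr_m = np.moveaxis(arr, axis, -1)
    idx_m = np.moveaxis(indices, axis, -1)
    out = np.empty(idx_m.shape, dtype=arr.dtype)
    # Visit every position outside the working axis and index one slice with the other.
    for pos in np.ndindex(idx_m.shape[:-1]):
        out[pos] = arr_m[pos][idx_m[pos]]
    # Put the working axis back where it was.
    return np.moveaxis(out, -1, axis)

x = np.arange(12).reshape(3, 4)           # data slices along axis 1 have length 4
idx = np.array([[3, 0], [1, 1], [2, 0]])  # index slices along axis 1 have length 2
result = take_along_axis_reference(x, idx, 1)
print(result)
```

Each row of `idx` selects from the row of `x` at the same position, so the data slices (length 4) and index slices (length 2) do not need to match in length.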

Parameters
  • arr (Tensor) – Source array with shape (Ni…, M, Nk…).

  • indices (Tensor) – Indices with shape (Ni…, J, Nk…) to take along each 1d slice of arr. This must have the same number of dimensions as arr, but dimensions Ni and Nk only need to broadcast against arr.

  • axis (int) – The axis to take 1d slices along. If axis is None, the input array is treated as if it had first been flattened to 1d.
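The shape rules for these parameters can be checked with a short sketch. It uses NumPy's `take_along_axis` as a stand-in; the assumption that MindSpore behaves the same way follows from the parameter descriptions above and has not been verified here.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)           # shape (Ni, M) = (3, 4)

# indices has size 1 on the non-axis dimension, so it broadcasts over all 3 rows.
idx = np.array([[0, 3]])                  # shape (1, J) = (1, 2)
picked = np.take_along_axis(x, idx, axis=1)
print(picked.shape)                       # (3, 2), i.e. (Ni, J)

# axis=None: arr is treated as flattened to 1d, so indices are 1d positions into it.
flat_idx = np.array([0, 5, 11])
flat_picked = np.take_along_axis(x, flat_idx, axis=None)
print(flat_picked)
```

The result always takes its length along `axis` from `indices` (J), not from `arr` (M).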

Returns

Tensor, the indexed result, with shape (Ni…, J, Nk…).

Raises
  • ValueError – If the input array and indices have a different number of dimensions.

  • TypeError – If the input is not a Tensor.

Supported Platforms:

Ascend GPU CPU

Example

>>> import mindspore.numpy as np
>>> x = np.arange(12).reshape(3, 4)
>>> indices = np.arange(3).reshape(1, 3)
>>> output = np.take_along_axis(x, indices, 1)
>>> print(output)
[[ 0  1  2]
 [ 4  5  6]
 [ 8  9 10]]