mindspore.ops.TopK
- class mindspore.ops.TopK(sorted=True)
Finds values and indices of the k largest entries along the last dimension.
Warning
If sorted is set to False, the aicpu operator is used, and performance may be reduced. In addition, because memory layouts and traversal methods differ across platforms, the order of the results may be inconsistent when sorted is False.
If input_x is a one-dimensional Tensor, TopK finds the k largest entries in the Tensor and outputs their values and indices as Tensors. When sorted is True, values[i] is the (i+1)-th largest item in input_x (counting from i = 0), and its index is indices[i].
For a multi-dimensional Tensor, the k largest entries are computed in each row (the vector along the last dimension), therefore: values.shape = indices.shape = input_x.shape[:-1] + [k].
If two compared elements are equal, the one with the smaller index is returned first.
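To make the semantics above concrete, here is a minimal pure-Python reference sketch that works on nested lists. It is not MindSpore's implementation; the function name `topk_reference` is hypothetical, and it models only the sorted=True behavior: each innermost list is one row along the last dimension, rows are ranked by descending value with ties broken by the smaller index, and the output keeps the leading dimensions so its shape is input_x.shape[:-1] + [k].

```python
def topk_reference(x, k):
    """Hypothetical reference for TopK(sorted=True) on nested Python lists.

    Returns (values, indices) for the k largest entries along the last
    dimension. Equal values are ordered by the smaller index first.
    """
    if not isinstance(x[0], list):
        # 1-D row: rank positions by descending value, then ascending index.
        order = sorted(range(len(x)), key=lambda i: (-x[i], i))[:k]
        return [x[i] for i in order], order
    # N-D input: recurse over the leading dimensions, so the result has
    # shape x.shape[:-1] + [k] for both values and indices.
    pairs = [topk_reference(row, k) for row in x]
    return [v for v, _ in pairs], [i for _, i in pairs]


# 1-D: same input as the Examples section.
print(topk_reference([1, 2, 3, 4, 5], 3))
# 2-D: row 0 contains a tie between indices 1 and 2.
print(topk_reference([[1, 3, 3, 2], [7, 5, 6, 4]], 2))
```

The recursion mirrors the rule that TopK treats every vector along the last dimension independently; the tie-breaking key `(-x[i], i)` encodes "smaller index first".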
- Parameters
sorted (bool, optional) – If True, the obtained elements will be sorted by value in descending order. If False, the obtained elements will not be sorted. Default: True.
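The difference between sorted=True and sorted=False can be illustrated with a small 1-D sketch: both settings select the same k elements, and only their order changes. The function `topk_1d` and its `sorted_output` flag are hypothetical stand-ins for the operator's `sorted` attribute. Because the documented order for sorted=False is unspecified (and may differ across platforms), the sketch arbitrarily returns the selected elements in their original positional order as one possible result.

```python
import heapq


def topk_1d(x, k, sorted_output=True):
    """Hypothetical 1-D model of TopK's `sorted` attribute."""
    # heapq.nlargest returns the chosen items in descending key order;
    # keying on (value, -index) makes equal values prefer the smaller index.
    chosen = heapq.nlargest(k, range(len(x)), key=lambda i: (x[i], -i))
    if not sorted_output:
        # Assumption: any order is valid here, so use positional order.
        chosen = sorted(chosen)
    return [x[i] for i in chosen], chosen


x = [4, 1, 5, 2, 5]
print(topk_1d(x, 3))         # descending by value
print(topk_1d(x, 3, False))  # same elements, one possible unsorted order
```

Code that consumes the output of TopK(sorted=False) should therefore rely only on the set of (value, index) pairs, not on their order.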
- Inputs:
input_x (Tensor) - Input to be computed. 0-D input is supported on GPU, but not on Ascend or CPU. Supported dtypes:
Ascend: int8, uint8, int32, int64, float16, float32.
GPU: float16, float32.
CPU: all numeric types.
k (Union(Tensor, int)) - The number of top elements to be computed along the last dimension. If k is a Tensor, the supported dtype is int32 and it should be 0-D or 1-D with shape (1,).
- Outputs:
A tuple consisting of values and indices.
values (Tensor) - The k largest elements in each slice of the last dimension.
indices (Tensor) - The indices of values within the last dimension of input.
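The two outputs are linked: gathering input_x along its last dimension with indices reproduces values, i.e. values[..., j] == input_x[..., indices[..., j]]. The sketch below checks this on nested Python lists; `gather_last_dim` is a hypothetical helper, not a MindSpore API, and the sample tensors are the input and outputs from the Examples section.

```python
def gather_last_dim(x, indices):
    """Hypothetical helper: pick x[..., indices[..., j]] for every j."""
    if not isinstance(x[0], list):
        # Innermost row: index directly into the last dimension.
        return [x[i] for i in indices]
    # Pair each leading-dimension row with its own row of indices.
    return [gather_last_dim(row, idx) for row, idx in zip(x, indices)]


# Input and outputs from the Examples section.
input_x = [1, 2, 3, 4, 5]
values, indices = [5, 4, 3], [4, 3, 2]
print(gather_last_dim(input_x, indices) == values)
```

The same relationship holds row by row for multi-dimensional inputs, which is why values and indices always share the shape input_x.shape[:-1] + [k].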
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> import mindspore
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = ops.TopK(sorted=True)(input_x, k)
>>> print((values, indices))
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3], dtype=Int32, value= [4, 3, 2]))