比较与torch.argsort的功能差异

查看源文件

torch.argsort

class torch.argsort(
    input,
    dim=-1,
    descending=False
)

更多内容详见 torch.argsort

mindspore.ops.Sort

class mindspore.ops.Sort(
    axis=-1,
    descending=False
)(x)

更多内容详见 mindspore.ops.Sort

使用方式

PyTorch: 返回按值升序沿给定维度对张量进行排序的索引。

MindSpore: 按值升序沿给定维度对输入张量的元素进行排序。 返回一个张量,其值为排序后的值,以及原始输入张量中元素的索引。

代码示例

import numpy as np
import torch
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype

# MindSpore
x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mstype.float16)
sort = ops.Sort()
output = sort(x)
print(output)
# Out:
# (Tensor(shape=[3, 3], dtype=Float16, value=
# [[ 1.0000e+00,  2.0000e+00,  8.0000e+00],
#  [ 3.0000e+00,  5.0000e+00,  9.0000e+00],
#  [ 4.0000e+00,  6.0000e+00,  7.0000e+00]]), Tensor(shape=[3, 3], dtype=Int32, value=
# [[2, 1, 0],
#  [2, 0, 1],
#  [0, 1, 2]]))

# Pytorch
a = torch.tensor([[8, 2, 1], [5, 9, 3], [4, 6, 7]], dtype=torch.int8)
torch.argsort(a, dim=1)
# Out:
# tensor([[2, 1, 0],
#         [2, 0, 1],
#         [0, 1, 2]])