mindspore.numpy

MindSpore NumPy工具包提供了一系列类NumPy接口。用户可以使用类NumPy语法在MindSpore上进行模型的搭建。

MindSpore Numpy具有四大功能模块:Array生成、Array操作、逻辑运算和数学运算。

在API示例中,常用的模块导入方法如下:

import mindspore.numpy as np

说明

MindSpore numpy通过组装底层算子来提供与numpy一致的编程体验接口,方便开发人员使用和代码移植。相比于MindSpore的function和ops接口,与原始numpy的接口格式及行为一致性更好,以便于用户理解和使用。注意:由于兼容numpy的考虑,部分接口的性能可能弱于function和ops接口。使用者可以按需选择不同类型的接口。

Array生成

生成类算子用来生成和构建具有指定数值、类型和形状的数组(Tensor)。

构建数组代码示例:

import mindspore.numpy as np
import mindspore.ops as ops
input_x = np.array([1, 2, 3], np.float32)
print("input_x =", input_x)
print("type of input_x =", ops.typeof(input_x))

运行结果如下:

input_x = [1. 2. 3.]
type of input_x = Tensor[Float32]

除了使用上述方法来创建外,也可以通过以下几种方式创建。

  • 生成具有相同元素的数组

    生成具有相同元素的数组代码示例:

    input_x = np.full((2, 3), 6, np.float32)
    print(input_x)
    

    运行结果如下:

    [[6. 6. 6.]
     [6. 6. 6.]]
    

    生成指定形状的全1数组,示例:

    input_x = np.ones((2, 3), np.float32)
    print(input_x)
    

    运行结果如下:

    [[1. 1. 1.]
     [1. 1. 1.]]
    
  • 生成具有某个范围内的数值的数组

    生成指定范围内的等差数组代码示例:

    input_x = np.arange(0, 5, 1)
    print(input_x)
    

    运行结果如下:

    [0 1 2 3 4]
    
  • 生成特殊类型的数组

    生成给定对角线处下方元素为1,上方元素为0的矩阵,示例:

    input_x = np.tri(3, 3, 1)
    print(input_x)
    

    运行结果如下:

    [[1. 1. 0.]
     [1. 1. 1.]
     [1. 1. 1.]]
    

    生成对角线为1,其他元素为0的二维矩阵,示例:

    input_x = np.eye(2, 2)
    print(input_x)
    

    运行结果如下:

    [[1. 0.]
     [0. 1.]]
    

接口名

概述

支持平台

mindspore.numpy.arange

返回给定区间内均匀间隔的值。

Ascend GPU CPU

mindspore.numpy.array

该函数接受一个类似数组的对象创建Tensor。

Ascend GPU CPU

mindspore.numpy.asarray

该函数将一个类似数组的对象转换为Tensor。

Ascend GPU CPU

mindspore.numpy.asfarray

类似于 asarray ,将输入转换为float tensor。

Ascend GPU CPU

mindspore.numpy.bartlett

用于生成Bartlett窗口。

Ascend GPU CPU

mindspore.numpy.blackman

用于生成Blackman窗口。

Ascend GPU CPU

mindspore.numpy.copy

返回给定对象的Tensor副本。

Ascend GPU CPU

mindspore.numpy.diag

用于提取或构造对角线数组。

Ascend GPU CPU

mindspore.numpy.diag_indices

返回一个可以访问数组的主对角线的索引数组。

Ascend GPU CPU

mindspore.numpy.diagflat

返回一个二维数组,其数组输入作为新输出数组的对角线。

Ascend GPU CPU

mindspore.numpy.diagonal

返回数组指定的对角线。

Ascend GPU CPU

mindspore.numpy.empty

返回一个给定shape和类型的新数组,而不进行初始化。

Ascend GPU CPU

mindspore.numpy.empty_like

返回一个shape和类型与给定数组相同的新数组。

Ascend GPU CPU

mindspore.numpy.eye

返回一个对角线上值为1,其他位置为0的二维Tensor。

Ascend GPU CPU

mindspore.numpy.full

返回一个给定shape、类型,并用 fill_value 填充的新数组。

Ascend GPU CPU

mindspore.numpy.full_like

返回一个与给定数组具有相同shape和类型的完整数组。

Ascend GPU CPU

mindspore.numpy.geomspace

返回在对数刻度(几何级数)上均匀间隔的数字。

Ascend GPU CPU

mindspore.numpy.hamming

返回一个Hamming窗口函数。

Ascend GPU CPU

mindspore.numpy.hanning

返回一个Hanning窗口函数。

Ascend GPU CPU

mindspore.numpy.histogram_bin_edges

计算 histogram 函数需要使用的 bins 的边界值。

Ascend GPU CPU

mindspore.numpy.identity

返回单位数组。

Ascend GPU CPU

mindspore.numpy.indices

返回一个表示网格索引的数组。

Ascend GPU CPU

mindspore.numpy.ix_

从多个序列构建一个开放式网格,用于索引的坐标数组。

Ascend GPU CPU

mindspore.numpy.linspace

返回给定区间内均匀间隔的值。

Ascend GPU CPU

mindspore.numpy.logspace

返回在对数刻度上均匀间隔的值。

Ascend GPU CPU

mindspore.numpy.meshgrid

返回由坐标向量生成的坐标矩阵。

Ascend GPU CPU

mindspore.numpy.mgrid

返回一个密集矩阵 NdGrid 实例,其中 sparse=False

Ascend GPU CPU

mindspore.numpy.ogrid

返回一个稀疏矩阵 NdGrid 实例,其中 sparse=True

Ascend GPU CPU

mindspore.numpy.ones

返回一个给定shape和类型的新Tensor,其中所有元素用1来填充。

Ascend GPU CPU

mindspore.numpy.ones_like

返回一个与给定数组 a 具有相同shape和类型的Tensor,其中所有元素用1来填充。

Ascend GPU CPU

mindspore.numpy.pad

对矩阵进行填充。

Ascend GPU CPU

mindspore.numpy.rand

返回一个给定shape和类型的新Tensor,其中所有元素以区间 \([0,1)\) 上均匀分布的随机数来填充。

Ascend GPU CPU

mindspore.numpy.randint

返回从 minval (包括)到 maxval (不包括)的随机整数。

Ascend GPU CPU

mindspore.numpy.randn

返回一个给定shape和类型的新Tensor,并填充来自标准正态分布中的一个(或多个)样本。

Ascend GPU CPU

mindspore.numpy.trace

返回张量沿对角线的元素之和。

Ascend GPU CPU

mindspore.numpy.tri

返回一个Tensor,在给定的对角线处及以下元素值为1,在其他位置为0。

Ascend GPU CPU

mindspore.numpy.tril

返回张量的下三角部分。

Ascend GPU CPU

mindspore.numpy.tril_indices

返回shape为 (n, m) 数组的下三角的索引。

Ascend GPU CPU

mindspore.numpy.tril_indices_from

返回 arr 下三角的索引。

Ascend GPU CPU

mindspore.numpy.triu

返回张量的上三角部分。

Ascend GPU CPU

mindspore.numpy.triu_indices

返回shape为 (n, m) 数组的上三角的索引。

Ascend GPU CPU

mindspore.numpy.triu_indices_from

返回 arr 上三角的索引。

Ascend GPU CPU

mindspore.numpy.vander

生成一个范德蒙德矩阵。

Ascend GPU CPU

mindspore.numpy.zeros

返回一个给定shape和类型的新Tensor,其中所有元素以0来填充。

Ascend GPU CPU

mindspore.numpy.zeros_like

返回一个与给定数组 a 具有相同shape和类型的Tensor,其中所有元素用0来填充。

Ascend GPU CPU

Array操作

操作类算子主要进行数组的维度变换,分割和拼接等。

  • 数组维度变换

    矩阵转置,代码示例:

    input_x = np.arange(10).reshape(5, 2)
    output = np.transpose(input_x)
    print(output)
    

    运行结果如下:

    [[0 2 4 6 8]
     [1 3 5 7 9]]
    

    交换指定轴,代码示例:

    input_x = np.ones((1, 2, 3))
    output = np.swapaxes(input_x, 0, 1)
    print(output.shape)
    

    运行结果如下:

    (2, 1, 3)
    
  • 数组分割

    将输入数组平均切分为多个数组,代码示例:

    input_x = np.arange(9)
    output = np.split(input_x, 3)
    print(output)
    

    运行结果如下:

    (Tensor(shape=[3], dtype=Int32, value= [0, 1, 2]), Tensor(shape=[3], dtype=Int32, value= [3, 4, 5]), Tensor(shape=[3], dtype=Int32, value= [6, 7, 8]))
    
  • 数组拼接

    将两个数组按照指定轴进行拼接,代码示例:

    input_x = np.arange(0, 5)
    input_y = np.arange(10, 15)
    output = np.concatenate((input_x, input_y), axis=0)
    print(output)
    

    运行结果如下:

    [ 0  1  2  3  4 10 11 12 13 14]
    

接口名

概述

支持平台

mindspore.numpy.append

将值添加到Tensor的末尾。

Ascend GPU CPU

mindspore.numpy.apply_along_axis

在指定轴的一维切片上调用给定函数。

Ascend GPU CPU

mindspore.numpy.apply_over_axes

在多个轴上重复应用 func 函数。

Ascend GPU CPU

mindspore.numpy.argwhere

返回Tensor中非零元素的索引,并按元素分组。

Ascend GPU CPU

mindspore.numpy.array_split

将一个Tensor切分为多个Sub-Tensors。

Ascend GPU CPU

mindspore.numpy.array_str

返回数组中数据的字符串表示形式。

Ascend GPU CPU

mindspore.numpy.atleast_1d

将输入转换为至少一维的数组。

Ascend GPU CPU

mindspore.numpy.atleast_2d

将输入重新调整为至少二维的数组。

Ascend GPU CPU

mindspore.numpy.atleast_3d

将输入重新调整为至少三维的数组。

Ascend GPU CPU

mindspore.numpy.broadcast_arrays

将任意数量的数组广播到共同的shape。

Ascend GPU CPU

mindspore.numpy.broadcast_to

将数组广播到新shape。

Ascend GPU CPU

mindspore.numpy.choose

从索引数组和要选择的数组列表构造一个新数组。

Ascend GPU CPU

mindspore.numpy.column_stack

将一维Tensor作为列堆叠为二维Tensor。

Ascend GPU CPU

mindspore.numpy.concatenate

沿现有轴连接一系列Tensor。

Ascend GPU CPU

mindspore.numpy.dsplit

沿第三轴(深度)将Tensor分割为多个sub-tensor。

Ascend GPU CPU

mindspore.numpy.dstack

按顺序在深度方向(沿第三轴)堆叠Tensor。

Ascend GPU CPU

mindspore.numpy.expand_dims

扩展Tensor的shape。

Ascend GPU CPU

mindspore.numpy.flip

沿给定轴反转数组中的元素顺序。

GPU CPU

mindspore.numpy.fliplr

在左右方向上翻转每行中的元素。

GPU CPU

mindspore.numpy.flipud

在上下方向上翻转每列中的元素。

GPU CPU

mindspore.numpy.hsplit

水平(按列)将Tensor分割为多个sub-tensor。

Ascend GPU CPU

mindspore.numpy.hstack

按顺序水平堆叠Tensor。

Ascend GPU CPU

mindspore.numpy.intersect1d

查找两个Tensor的交集。

Ascend GPU CPU

mindspore.numpy.moveaxis

将数组的轴移动到新位置。

Ascend GPU CPU

mindspore.numpy.piecewise

执行分段定义的函数。

Ascend GPU CPU

mindspore.numpy.ravel

返回一个连续的展平Tensor。

Ascend GPU CPU

mindspore.numpy.repeat

重复数组的元素。

Ascend GPU CPU

mindspore.numpy.reshape

在不改变数据的情况下重塑一个Tensor。

Ascend GPU CPU

mindspore.numpy.roll

将Tensor沿给定的轴进行滚动。

Ascend GPU CPU

mindspore.numpy.rollaxis

将指定轴向后滚动,直到它位于给定的位置。

Ascend GPU CPU

mindspore.numpy.rot90

将Tensor在指定的轴平面内旋转90度。

GPU

mindspore.numpy.select

根据条件从 choicelist 中的元素中返回数组。

Ascend GPU CPU

mindspore.numpy.setdiff1d

计算两个Tensor的差集。

Ascend GPU CPU

mindspore.numpy.size

返回沿给定轴的元素数量。

Ascend GPU CPU

mindspore.numpy.split

将一个Tensor沿指定轴分割为多个sub-tensor。

Ascend GPU CPU

mindspore.numpy.squeeze

从Tensor的shape中移除单维元素。

Ascend GPU CPU

mindspore.numpy.stack

沿新轴连接一系列数组。

Ascend GPU CPU

mindspore.numpy.swapaxes

交换Tensor的两个轴。

Ascend GPU CPU

mindspore.numpy.take

从数组中沿指定轴提取元素。

Ascend GPU CPU

mindspore.numpy.take_along_axis

根据一维索引和数据切片从输入数组中提取值。

Ascend GPU CPU

mindspore.numpy.tile

通过重复 a 指定次数构造一个数组,次数由 reps 给出。

Ascend GPU CPU

mindspore.numpy.transpose

反转或交换Tensor的轴;返回修改后的Tensor。

Ascend GPU CPU

mindspore.numpy.unique

返回去重后的Tensor元素。

Ascend GPU CPU

mindspore.numpy.unravel_index

将一维索引或一维索引数组转换为坐标数组的tuple。

Ascend GPU CPU

mindspore.numpy.vsplit

垂直(按行)将Tensor分割为多个sub-tensor。

Ascend GPU CPU

mindspore.numpy.vstack

按顺序垂直堆叠Tensor。

Ascend GPU CPU

mindspore.numpy.where

根据 conditionxy 中选择元素。

Ascend GPU CPU

逻辑运算

逻辑运算类算子主要进行各类逻辑相关的运算。

相等(equal)和小于(less)计算代码示例如下:

input_x = np.arange(0, 5)
input_y = np.arange(0, 10, 2)
output = np.equal(input_x, input_y)
print("output of equal:", output)
output = np.less(input_x, input_y)
print("output of less:", output)

运行结果如下:

output of equal: [ True False False False False]
output of less: [False  True  True  True  True]

接口名

概述

支持平台

mindspore.numpy.array_equal

当输入数组shape相同且所有元素相等时,返回 True

GPU CPU Ascend

mindspore.numpy.array_equiv

当输入数组shape一致且所有元素相等时,返回 True

Ascend GPU CPU

mindspore.numpy.equal

逐元素返回 \((x1 == x2)\) 的真值。

Ascend GPU CPU

mindspore.numpy.greater

逐元素返回 \((x1 > x2)\) 的真值。

Ascend GPU CPU

mindspore.numpy.greater_equal

逐元素返回 \((x1 >= x2)\) 的真值。

Ascend GPU CPU

mindspore.numpy.in1d

测试一维数组的每个元素是否也存在于第二个数组中。

Ascend GPU CPU

mindspore.numpy.isclose

返回一个bool类型的Tensor,用于表示两个Tensor在给定的容差范围内是否逐元素相等。

Ascend GPU CPU

mindspore.numpy.isfinite

逐元素测试是否为有限数(不是无穷大或非数值)。

Ascend GPU CPU

mindspore.numpy.isin

test_elements 中的元素上计算,并仅在 element 上进行广播。

Ascend GPU CPU

mindspore.numpy.isinf

逐元素测试是否为正无穷大或负无穷大。

GPU CPU

mindspore.numpy.isnan

逐元素测试是否为NaN,并将结果返回为bool数组。

GPU CPU

mindspore.numpy.isneginf

逐元素测试是否为负无穷大,并将结果返回为bool数组。

GPU CPU

mindspore.numpy.isposinf

逐元素测试是否为正无穷大,并将结果返回为bool数组。

GPU CPU

mindspore.numpy.isscalar

如果元素的类型是标量类型,则返回True。

Ascend GPU CPU

mindspore.numpy.less

逐元素返回 \((x1 < x2)\) 的真值。

Ascend GPU CPU

mindspore.numpy.less_equal

逐元素返回 \((x1 <= x2)\) 的真值。

Ascend GPU CPU

mindspore.numpy.logical_and

逐元素计算 x1x2 的逻辑与(AND)的真值。

Ascend GPU CPU

mindspore.numpy.logical_not

逐元素计算 a 的逻辑非(NOT)的真值。

Ascend GPU CPU

mindspore.numpy.logical_or

逐元素计算 x1x2 的逻辑或(OR)的真值。

Ascend GPU CPU

mindspore.numpy.logical_xor

逐元素计算 x1x2 的逻辑异或(XOR)的真值。

Ascend GPU CPU

mindspore.numpy.not_equal

逐元素返回 (x1 != x2) 的真值。

Ascend GPU CPU

mindspore.numpy.signbit

逐元素扫描元素的符号位,如果符号位为1(即元素小于0)则返回True。

Ascend GPU CPU

mindspore.numpy.sometrue

测试沿给定轴是否有任意数组元素为True。

Ascend GPU CPU

数学运算

数学运算类算子包括各类数学相关的运算:加减乘除乘方,以及指数、对数等常见函数等。

数学计算支持类似NumPy的广播特性。

  • 加法

    以下代码实现了 input_xinput_y 两数组相加的操作:

    input_x = np.full((3, 2), [1, 2])
    input_y = np.full((3, 2), [3, 4])
    output = np.add(input_x, input_y)
    print(output)
    

    运行结果如下:

    [[4 6]
     [4 6]
     [4 6]]
    
  • 矩阵乘法

    以下代码实现了 input_xinput_y 两矩阵相乘的操作:

    input_x = np.arange(2*3).reshape(2, 3).astype('float32')
    input_y = np.arange(3*4).reshape(3, 4).astype('float32')
    output = np.matmul(input_x, input_y)
    print(output)
    

    运行结果如下:

    [[20. 23. 26. 29.]
     [56. 68. 80. 92.]]
    
  • 求平均值

    以下代码实现了求 input_x 所有元素的平均值的操作:

    input_x = np.arange(6).astype('float32')
    output = np.mean(input_x)
    print(output)
    

    运行结果如下:

    2.5
    
  • 指数

    以下代码实现了自然常数 einput_x 次方的操作:

    input_x = np.arange(5).astype('float32')
    output = np.exp(input_x)
    print(output)
    

    运行结果如下:

    [ 1.         2.7182817  7.389056  20.085537  54.59815  ]
    

接口名

概述

支持平台

mindspore.numpy.absolute

逐元素计算绝对值。

Ascend GPU CPU

mindspore.numpy.add

逐元素相加两个参数。

Ascend GPU CPU

mindspore.numpy.amax

返回数组的最大值或沿指定轴的最大值。

Ascend GPU CPU

mindspore.numpy.amin

返回数组的最小值或沿指定轴的最小值。

Ascend GPU CPU

mindspore.numpy.arccos

逐元素计算反余弦函数。

Ascend GPU CPU

mindspore.numpy.arccosh

逐元素计算反双曲余弦函数。

Ascend GPU CPU

mindspore.numpy.arcsin

逐元素计算反正弦函数。

Ascend GPU CPU

mindspore.numpy.arcsinh

逐元素计算反双曲正弦函数。

Ascend GPU CPU

mindspore.numpy.arctan

逐元素计算反正切函数。

Ascend GPU CPU

mindspore.numpy.arctan2

逐元素计算 \(x1/x2\) 的反正切,并正确选择象限。

Ascend GPU CPU

mindspore.numpy.arctanh

逐元素计算反双曲正切。

Ascend CPU

mindspore.numpy.argmax

返回沿指定轴最大值的索引。

Ascend GPU CPU

mindspore.numpy.argmin

返回沿指定轴最小值的索引。

Ascend GPU CPU

mindspore.numpy.around

向给定的小数位数四舍五入。

Ascend GPU CPU

mindspore.numpy.average

沿指定轴计算加权平均值。

Ascend GPU CPU

mindspore.numpy.bincount

计算数组中非负int值的出现次数。

Ascend GPU CPU

mindspore.numpy.bitwise_and

逐元素计算两个数组的按位与运算。

Ascend CPU

mindspore.numpy.bitwise_or

逐元素计算两个数组的按位或运算。

Ascend CPU

mindspore.numpy.bitwise_xor

逐元素计算两个数组的按位异或运算。

Ascend CPU

mindspore.numpy.cbrt

返回Tensor的立方根,逐元素计算。

Ascend GPU CPU

mindspore.numpy.ceil

返回输入的向上取整结果,逐元素计算。

Ascend GPU CPU

mindspore.numpy.clip

裁剪(限制)数组的值。

Ascend GPU CPU

mindspore.numpy.convolve

返回两个一维序列的离散线性卷积。

GPU CPU

mindspore.numpy.copysign

x1 的符号更改为 x2 的符号,逐元素执行。

Ascend GPU CPU

mindspore.numpy.corrcoef

返回皮尔逊积矩相关系数。

Ascend GPU CPU

mindspore.numpy.correlate

两个一维序列的互相关。

Ascend GPU CPU

mindspore.numpy.cos

逐元素计算余弦。

Ascend GPU CPU

mindspore.numpy.cosh

逐元素计算双曲余弦。

Ascend CPU

mindspore.numpy.count_nonzero

计算Tensor x 中的非零值数量。

Ascend GPU CPU

mindspore.numpy.cov

给定数据和权重,估算一个协方差矩阵。

Ascend GPU CPU

mindspore.numpy.cross

返回两个向量(数组)的叉积。

Ascend GPU CPU

mindspore.numpy.cumprod

返回沿给定 axis 的元素的累计乘积。

Ascend GPU

mindspore.numpy.cumsum

返回沿给定 axis 的元素的累计和。

Ascend GPU CPU

mindspore.numpy.deg2rad

将角度从角度制转换为弧度制。

Ascend GPU CPU

mindspore.numpy.diff

计算沿给定 axis 的n阶离散差分。

Ascend GPU CPU

mindspore.numpy.digitize

返回输入数组中每个值所属的桶的索引。

Ascend GPU CPU

mindspore.numpy.divide

返回输入的真除法,逐元素计算。

Ascend GPU CPU

mindspore.numpy.divmod

同时返回逐元素的商和余数。

Ascend GPU CPU

mindspore.numpy.dot

返回两个数组的点积。

Ascend GPU CPU

mindspore.numpy.ediff1d

计算Tensor中连续元素之间的差值。

Ascend GPU CPU

mindspore.numpy.exp

计算输入数组中所有元素的指数。

Ascend GPU CPU

mindspore.numpy.exp2

计算输入数组中所有值 p2**p

Ascend GPU CPU

mindspore.numpy.expm1

计算数组中所有元素的 exp(x) - 1

Ascend GPU CPU

mindspore.numpy.fix

舍入至最接近零的相邻整数。

Ascend GPU CPU

mindspore.numpy.float_power

第一个数组逐元素计算幂次方,指数为第二个数组中对应的元素。

Ascend GPU CPU

mindspore.numpy.floor

返回输入的向下取整,逐元素计算。

Ascend GPU CPU

mindspore.numpy.floor_divide

对输入进行除法运算,返回小于等于除法结果的最大整数。

Ascend GPU CPU

mindspore.numpy.fmod

返回除法的逐元素余数。

Ascend GPU CPU

mindspore.numpy.gcd

返回 |x1||x2| 的最大公约数。

Ascend GPU CPU

mindspore.numpy.gradient

返回一个N维数组的梯度。

Ascend GPU CPU

mindspore.numpy.heaviside

计算Heaviside阶跃函数。

Ascend GPU CPU

mindspore.numpy.histogram

计算数据集的直方图。

Ascend GPU CPU

mindspore.numpy.histogram2d

计算数据的二维直方图。

Ascend GPU CPU

mindspore.numpy.histogramdd

计算数据的多维直方图。

Ascend GPU CPU

mindspore.numpy.hypot

给定直角三角形的直角边,返回其斜边。

Ascend GPU CPU

mindspore.numpy.inner

返回两个Tensor的内积。

Ascend GPU CPU

mindspore.numpy.interp

用于单调递增的样本点的一维线性插值。

Ascend GPU CPU

mindspore.numpy.invert

逐元素计算按位取反或按位非。

Ascend

mindspore.numpy.kron

两个数组的Kronecker积。

Ascend GPU CPU

mindspore.numpy.lcm

返回 |x1||x2| 的最小公倍数。

Ascend GPU CPU

mindspore.numpy.log

返回自然对数,逐元素计算。

Ascend GPU CPU

mindspore.numpy.log10

x 的以10为底的对数。

Ascend GPU CPU

mindspore.numpy.log1p

返回1加上输入数组的自然对数,逐元素计算。

Ascend GPU CPU

mindspore.numpy.log2

x 的以2为底的对数。

Ascend GPU CPU

mindspore.numpy.logaddexp

计算输入指数取幂的和的对数。

Ascend GPU CPU

mindspore.numpy.logaddexp2

计算输入指数取幂的和的以2为底对数。

Ascend GPU CPU

mindspore.numpy.matmul

返回两个数组的矩阵乘积。

Ascend GPU CPU

mindspore.numpy.matrix_power

计算方阵以整数 n 为指数的幂。

Ascend GPU CPU

mindspore.numpy.maximum

逐元素比较两个数组,返回每对数组元素中的最大值。

Ascend GPU CPU

mindspore.numpy.mean

沿指定轴计算算术平均值。

Ascend GPU CPU

mindspore.numpy.minimum

逐元素比较两个数组,返回每对数组元素中的最小值。

Ascend GPU CPU

mindspore.numpy.multi_dot

计算两个或更多个数组的点积,同时自动选择最快的计算顺序。

Ascend GPU CPU

mindspore.numpy.multiply

参数逐元素相乘。

Ascend GPU CPU

mindspore.numpy.nancumsum

返回给定轴上数组元素的累积和,将NaN(非数值)视为零。

GPU CPU

mindspore.numpy.nanmax

返回数组的最大值或沿某个轴的最大值,忽略NaN。

GPU CPU

mindspore.numpy.nanmean

沿指定轴计算算术平均值,忽略NaN。

GPU CPU

mindspore.numpy.nanmin

返回数组的最大值或沿某个轴的最大值,忽略NaN。

GPU CPU

mindspore.numpy.nanstd

计算指定轴上元素的标准差,忽略NaN。

GPU CPU

mindspore.numpy.nansum

计算指定轴上元素的总和,将NaN(非数值)视为零。

GPU CPU

mindspore.numpy.nanvar

计算指定轴上的方差,忽略NaN。

GPU CPU

mindspore.numpy.negative

数值符号取反,逐元素操作。

Ascend GPU CPU

mindspore.numpy.norm

矩阵或向量的范数。

Ascend GPU CPU

mindspore.numpy.outer

计算两个向量的外积。

Ascend GPU CPU

mindspore.numpy.polyadd

找到两个多项式的和。

Ascend GPU CPU

mindspore.numpy.polyder

返回多项式指定阶数的导数。

Ascend GPU CPU

mindspore.numpy.polyint

返回多项式的一个反导数(不定积分)。

Ascend GPU CPU

mindspore.numpy.polymul

求两个多项式的积。

GPU

mindspore.numpy.polysub

两个多项式的差(减法)。

Ascend GPU CPU

mindspore.numpy.polyval

在特定值处求多项式的值。

Ascend GPU CPU

mindspore.numpy.positive

数值取正,逐元素计算。

Ascend GPU CPU

mindspore.numpy.power

第一个数组的元素以第二个数组的元素为指数逐元素求幂。

Ascend GPU CPU

mindspore.numpy.promote_types

返回 type1type2 都可以安全转换的最小位数和最小标量类型的数据类型。

Ascend GPU CPU

mindspore.numpy.ptp

沿某个轴的值范围(最大值 - 最小值)。

Ascend GPU CPU

mindspore.numpy.rad2deg

将角从弧度制转换为角度制。

Ascend GPU CPU

mindspore.numpy.radians

将角从角度制转换为弧度制。

Ascend GPU CPU

mindspore.numpy.ravel_multi_index

将元素为索引数组的tuple转换为展平的索引数组,并对多重索引应用边界模式。

GPU

mindspore.numpy.reciprocal

逐元素返回入参的倒数。

Ascend GPU CPU

mindspore.numpy.remainder

逐元素返回除法余数。

Ascend GPU CPU

mindspore.numpy.result_type

返回对入参使用类型提升规则所得的类型。

Ascend GPU CPU

mindspore.numpy.rint

将数组的元素四舍五入到最接近的整数。

Ascend GPU CPU

mindspore.numpy.searchsorted

找到每个元素的插入索引,使得插入后的数组保持原有升降序。

Ascend GPU CPU

mindspore.numpy.sign

逐元素返回数的符号。

Ascend GPU CPU

mindspore.numpy.sin

逐元素计算三角正弦函数。

Ascend GPU CPU

mindspore.numpy.sinh

逐元素计算双曲正弦函数。

Ascend CPU

mindspore.numpy.sqrt

逐元素返回数组的非负平方根。

Ascend GPU CPU

mindspore.numpy.square

逐元素返回输入的平方。

Ascend GPU CPU

mindspore.numpy.std

沿指定轴计算标准差。

Ascend GPU CPU

mindspore.numpy.subtract

逐元素减去给定参数。

Ascend GPU CPU

mindspore.numpy.sum

返回指定轴上数组元素的总和。

Ascend GPU CPU

mindspore.numpy.tan

逐元素计算正切值。

Ascend CPU

mindspore.numpy.tanh

逐元素计算双曲正切值。

Ascend GPU CPU

mindspore.numpy.tensordot

沿指定轴计算Tensor的点积。

Ascend GPU CPU

mindspore.numpy.trapz

使用复合梯形规则沿给定轴进行积分。

Ascend GPU CPU

mindspore.numpy.true_divide

返回输入的真除法,逐元素计算。

Ascend GPU CPU

mindspore.numpy.trunc

逐元素返回输入的截断值。

Ascend GPU CPU

mindspore.numpy.unwrap

通过加或减 2*pi ,改变数组相邻元素的差值实现解卷绕。

Ascend GPU CPU

mindspore.numpy.var

计算沿指定轴的方差。

Ascend GPU CPU

MindSpore Numpy与MindSpore特性结合

mindspore.numpy能够充分利用MindSpore的强大功能,实现算子的自动微分,并使用图模式加速运算,帮助用户快速构建高效的模型。同时,MindSpore还支持多种后端设备,包括Ascend、GPU和CPU等,用户可以根据自己的需求灵活设置。以下提供了几种常用方法:

  • jit 装饰器: 将代码包裹进图模式,用于提高代码运行效率。

  • GradOperation: 用于自动求导。

  • mindspore.set_context: 用于设置运行模式和后端设备等。

  • mindspore.nn.Cell: 用于建立深度学习模型。

使用示例如下:

  • jit 装饰器使用示例

    首先,以神经网络里经常使用到的矩阵乘与矩阵加算子为例:

    import mindspore.numpy as np
    
    x = np.arange(8).reshape(2, 4).astype('float32')
    w1 = np.ones((4, 8))
    b1 = np.zeros((8,))
    w2 = np.ones((8, 16))
    b2 = np.zeros((16,))
    w3 = np.ones((16, 4))
    b3 = np.zeros((4,))
    
    def forward(x, w1, b1, w2, b2, w3, b3):
        x = np.dot(x, w1) + b1
        x = np.dot(x, w2) + b2
        x = np.dot(x, w3) + b3
        return x
    
    print(forward(x, w1, b1, w2, b2, w3, b3))
    

    运行结果如下:

    [[ 768.  768.  768.  768.]
     [2816. 2816. 2816. 2816.]]
    

    对上述示例,我们可以借助 jit 装饰器将所有算子编译到一张静态图里以加快运行效率,示例如下:

    from mindspore import jit
    
    forward_compiled = jit(forward)
    print(forward(x, w1, b1, w2, b2, w3, b3))
    

    运行结果如下:

    [[ 768.  768.  768.  768.]
     [2816. 2816. 2816. 2816.]]
    

    说明

    目前静态图不支持在Python交互式模式下运行,并且有部分语法限制。

  • GradOperation使用示例

    GradOperation 可以实现自动求导。以下示例可以实现对上述没有用 jit 修饰的 forward 函数定义的计算求导。

    from mindspore import ops
    
    grad_all = ops.GradOperation(get_all=True)
    print(grad_all(forward)(x, w1, b1, w2, b2, w3, b3))
    

    运行结果如下:

    (Tensor(shape=[2, 4], dtype=Float32, value=
     [[ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02,  5.12000000e+02],
      [ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02,  5.12000000e+02]]),
     Tensor(shape=[4, 8], dtype=Float32, value=
     [[ 2.56000000e+02,  2.56000000e+02,  2.56000000e+02 ...  2.56000000e+02,  2.56000000e+02,  2.56000000e+02],
      [ 3.84000000e+02,  3.84000000e+02,  3.84000000e+02 ...  3.84000000e+02,  3.84000000e+02,  3.84000000e+02],
      [ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02 ...  5.12000000e+02,  5.12000000e+02,  5.12000000e+02]
      [ 6.40000000e+02,  6.40000000e+02,  6.40000000e+02 ...  6.40000000e+02,  6.40000000e+02,  6.40000000e+02]]),
      ...
     Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00]))
    

    如果要对 jit 修饰的 forward 计算求导,需要提前使用 set_context 设置运算模式为图模式,示例如下:

    from mindspore import jit, set_context, GRAPH_MODE, ops
    
    set_context(mode=GRAPH_MODE)
    grad_all = ops.GradOperation(get_all=True)
    print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))
    

    运行结果如下:

    (Tensor(shape=[2, 4], dtype=Float32, value=
     [[ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02,  5.12000000e+02],
      [ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02,  5.12000000e+02]]),
     Tensor(shape=[4, 8], dtype=Float32, value=
     [[ 2.56000000e+02,  2.56000000e+02,  2.56000000e+02 ...  2.56000000e+02,  2.56000000e+02,  2.56000000e+02],
      [ 3.84000000e+02,  3.84000000e+02,  3.84000000e+02 ...  3.84000000e+02,  3.84000000e+02,  3.84000000e+02],
      [ 5.12000000e+02,  5.12000000e+02,  5.12000000e+02 ...  5.12000000e+02,  5.12000000e+02,  5.12000000e+02]
      [ 6.40000000e+02,  6.40000000e+02,  6.40000000e+02 ...  6.40000000e+02,  6.40000000e+02,  6.40000000e+02]]),
      ...
     Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00]))
    

    更多细节可参考 API GradOperation

  • mindspore.set_context使用示例

    MindSpore支持多后端运算,可以通过 mindspore.set_context 进行设置。mindspore.numpy 的多数算子可以使用图模式或者PyNative模式运行,也可以运行在CPU,CPU或者Ascend等多种后端设备上。

    from mindspore import set_context, GRAPH_MODE, PYNATIVE_MODE
    
    # Execution in static graph mode
    set_context(mode=GRAPH_MODE)
    
    # Execution in PyNative mode
    set_context(mode=PYNATIVE_MODE)
    
    # Execution on CPU backend
    set_context(device_target="CPU")
    
    # Execution on GPU backend
    set_context(device_target="GPU")
    
    # Execution on Ascend backend
    set_context(device_target="Ascend")
    ...
    

    更多细节可参考 API mindspore.set_context

  • mindspore.numpy使用示例

    这里提供一个使用 mindspore.numpy 构建网络模型的示例。

    mindspore.numpy 接口可以定义在 nn.Cell 代码块内进行网络的构建,示例如下:

    import mindspore.numpy as np
    from mindspore import set_context, GRAPH_MODE
    from mindspore.nn import Cell
    
    set_context(mode=GRAPH_MODE)
    
    x = np.arange(8).reshape(2, 4).astype('float32')
    w1 = np.ones((4, 8))
    b1 = np.zeros((8,))
    w2 = np.ones((8, 16))
    b2 = np.zeros((16,))
    w3 = np.ones((16, 4))
    b3 = np.zeros((4,))
    
    class NeuralNetwork(Cell):
        def construct(self, x, w1, b1, w2, b2, w3, b3):
            x = np.dot(x, w1) + b1
            x = np.dot(x, w2) + b2
            x = np.dot(x, w3) + b3
            return x
    
    net = NeuralNetwork()
    
    print(net(x, w1, b1, w2, b2, w3, b3))
    

    运行结果如下:

    [[ 768.  768.  768.  768.]
     [2816. 2816. 2816. 2816.]]