mindspore.numpy
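The MindSpore NumPy toolkit provides a set of NumPy-like interfaces, so users can build models on MindSpore with NumPy-like syntax.
MindSpore NumPy has four main functional modules: array generation, array operations, logical operations, and mathematical operations.
In the API examples, the module is usually imported as follows: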
import mindspore.numpy as np
Note
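MindSpore NumPy builds its interfaces by composing low-level operators, which gives a programming experience consistent with NumPy and makes both day-to-day use and code porting easier. Compared with MindSpore's function and ops interfaces, its interface signatures and behavior match the original NumPy more closely, so they are easier to understand and use. Note: because of the NumPy-compatibility requirements, some interfaces may be slower than the corresponding function and ops interfaces. Choose the type of interface that best fits your needs.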
Array Generation
Generation operators create and construct arrays (Tensors) with specified values, types, and shapes.
Example of creating an array:
import mindspore.numpy as np
import mindspore.ops as ops
input_x = np.array([1, 2, 3], np.float32)
print("input_x =", input_x)
print("type of input_x =", ops.typeof(input_x))
The output is as follows:
input_x = [1. 2. 3.]
type of input_x = Tensor[Float32]
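Besides the method above, arrays can also be created in the following ways.
Generating an Array with Identical Elements
Example of generating an array whose elements all have the same value:
input_x = np.full((2, 3), 6, np.float32)
print(input_x)
The output is as follows:
[[6. 6. 6.]
 [6. 6. 6.]]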
Example of generating an array of ones with a specified shape:
input_x = np.ones((2, 3), np.float32)
print(input_x)
The output is as follows:
[[1. 1. 1.]
 [1. 1. 1.]]
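Because mindspore.numpy follows NumPy semantics, the two calls above can be cross-checked against standard NumPy. This sketch assumes the standard `numpy` package is installed and imports it as `onp` so it does not clash with the `np` alias used for mindspore.numpy; it does not run MindSpore itself.

```python
import numpy as onp

# Standard-NumPy reference for the two MindSpore calls above
full = onp.full((2, 3), 6, onp.float32)
ones = onp.ones((2, 3), onp.float32)

# A constant-filled array equals an array of ones scaled by the fill value
assert onp.array_equal(full, ones * 6)
print(full.dtype, full.shape)  # float32 (2, 3)
```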
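Generating an Array of Values Within a Range
Example of generating evenly spaced values (an arithmetic sequence) within a specified range:
input_x = np.arange(0, 5, 1)
print(input_x)
The output is as follows: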
[0 1 2 3 4]
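The stop value 5 does not appear in the output because arange draws from the half-open interval [start, stop). The number of elements is therefore ceil((stop - start) / step). The pure-Python sketch below only illustrates that rule; it is not MindSpore's implementation.

```python
import math

def arange_values(start, stop, step=1):
    """Illustrative sketch of arange semantics: values in [start, stop) spaced by step."""
    n = max(0, math.ceil((stop - start) / step))
    return [start + i * step for i in range(n)]

print(arange_values(0, 5, 1))    # [0, 1, 2, 3, 4]
print(arange_values(0, 10, 2))   # [0, 2, 4, 6, 8]
print(arange_values(5, 0, -1))   # [5, 4, 3, 2, 1]
```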
Generating Special Arrays
Example of generating a matrix with ones at and below the given diagonal and zeros above it:
input_x = np.tri(3, 3, 1)
print(input_x)
The output is as follows:
[[1. 1. 0.]
 [1. 1. 1.]
 [1. 1. 1.]]
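The rule behind tri can be stated directly: element (i, j) is 1 when j <= i + k, where k is the diagonal offset (the third argument, 1 in the example above), so each row has its ones extended one column past the main diagonal. The minimal pure-Python sketch below is illustrative only and reproduces the result above.

```python
def tri(n, m=None, k=0):
    """Illustrative sketch of tri: 1.0 where column j <= row i + k, else 0.0."""
    m = n if m is None else m
    return [[1.0 if j <= i + k else 0.0 for j in range(m)] for i in range(n)]

for row in tri(3, 3, 1):
    print(row)
# [1.0, 1.0, 0.0]
# [1.0, 1.0, 1.0]
# [1.0, 1.0, 1.0]
```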
Example of generating a 2-D matrix with ones on the diagonal and zeros elsewhere:
input_x = np.eye(2, 2)
print(input_x)
The output is as follows:
[[1. 0.]
 [0. 1.]]
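The eye rule is the equality counterpart of the tri rule: element (i, j) is 1 only when j == i + k. The sketch below assumes an offset parameter k that mirrors NumPy's `eye(N, M, k)`; a non-zero k shifts the ones above (k > 0) or below (k < 0) the main diagonal. It is a pure-Python illustration, not MindSpore code.

```python
def eye(n, m=None, k=0):
    """Illustrative sketch of eye: 1.0 where column j == row i + k, else 0.0."""
    m = n if m is None else m
    return [[1.0 if j == i + k else 0.0 for j in range(m)] for i in range(n)]

print(eye(2, 2))     # [[1.0, 0.0], [0.0, 1.0]]
print(eye(2, 3, 1))  # [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```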
API Name | Description | Supported Platforms
Returns evenly spaced values within a given interval.
Creates a tensor.
Converts the input to a tensor.
Similar to asarray, converts the input to a float tensor.
Returns the Bartlett window.
Returns the Blackman window.
Returns a tensor copy of the given object.
Extracts a diagonal or constructs a diagonal array.
Returns the indices to access the main diagonal of an array.
Creates a two-dimensional array with the flattened input as a diagonal.
Returns specified diagonals.
Returns a new array of given shape and type, without initializing entries.
Returns a new array with the same shape and type as a given array.
Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
Returns a new tensor of given shape and type, filled with fill_value.
Returns a full array with the same shape and type as a given array.
Returns numbers spaced evenly on a log scale (a geometric progression).
Returns the Hamming window.
Returns the Hanning window.
Calculates only the edges of the bins used by the histogram function.
Returns the identity tensor.
Returns an array representing the indices of a grid.
Constructs an open mesh from multiple sequences.
Returns evenly spaced values within a given interval.
Returns numbers spaced evenly on a log scale.
Returns coordinate matrices from coordinate vectors.
mgrid is an
ogrid is an
Returns a new tensor of given shape and type, filled with ones.
Returns an array of ones with the same shape and type as a given array.
Pads an array.
Returns a new Tensor with given shape and dtype, filled with random numbers from the uniform distribution on the interval [0, 1).
Returns random integers from minval (inclusive) to maxval (exclusive).
Returns a new Tensor with given shape and dtype, filled with a sample (or samples) from the standard normal distribution.
Returns the sum along diagonals of the array.
Returns a tensor with ones at and below the given diagonal and zeros elsewhere.
Returns a lower triangle of a tensor.
Returns the indices for the lower triangle of an (n, m) array.
Returns the indices for the lower triangle of arr.
Returns an upper triangle of a tensor.
Returns the indices for the upper triangle of an (n, m) array.
Returns the indices for the upper triangle of arr.
Generates a Vandermonde matrix.
Returns a new tensor of given shape and type, filled with zeros.
Returns an array of zeros with the same shape and type as a given array.
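Array Operations
Manipulation operators mainly change array dimensions, split arrays, and concatenate arrays.
Array Dimension Transformation
Example of transposing a matrix:
input_x = np.arange(10).reshape(5, 2)
output = np.transpose(input_x)
print(output)
The output is as follows:
[[0 2 4 6 8]
 [1 3 5 7 9]]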
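For a 2-D array, transposing moves element (i, j) to position (j, i), so the 5 x 2 input becomes a 2 x 5 output whose first row collects the first column of the input. A pure-Python illustration of that mapping, using the same values as `np.arange(10).reshape(5, 2)`:

```python
def transpose2d(rows):
    """Illustrative 2-D transpose: out[j][i] == rows[i][j]."""
    return [list(col) for col in zip(*rows)]

x = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]  # same values as np.arange(10).reshape(5, 2)
print(transpose2d(x))  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```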
Example of swapping two specified axes:
input_x = np.ones((1, 2, 3))
output = np.swapaxes(input_x, 0, 1)
print(output.shape)
The output is as follows:
(2, 1, 3)
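The resulting shape of swapaxes follows from a simple rule: the lengths of the two named axes trade places and every other axis stays where it is. The helper below is a hypothetical illustration of that rule (it computes only the shape, not the data); negative axis numbers count from the end, as in NumPy.

```python
def swapped_shape(shape, axis1, axis2):
    """Hypothetical helper: the shape produced by swapping two axes."""
    dims = list(shape)
    dims[axis1], dims[axis2] = dims[axis2], dims[axis1]
    return tuple(dims)

print(swapped_shape((1, 2, 3), 0, 1))   # (2, 1, 3)
print(swapped_shape((1, 2, 3), 0, -1))  # (3, 2, 1)
```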
Array Splitting
Example of splitting an input array evenly into multiple arrays:
input_x = np.arange(9)
output = np.split(input_x, 3)
print(output)
The output is as follows:
(Tensor(shape=[3], dtype=Int32, value= [0, 1, 2]),
 Tensor(shape=[3], dtype=Int32, value= [3, 4, 5]),
 Tensor(shape=[3], dtype=Int32, value= [6, 7, 8]))
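An integer section count only works when it divides the length of the split axis. The sketch below shows this in standard NumPy (assumption: `numpy` is installed, imported as `onp`); mindspore.numpy is expected to follow the same convention, but its API reference is the authority on MindSpore's exact error behavior. Passing a list of indices instead marks cut points, so the pieces may differ in size.

```python
import numpy as onp

# Reference behavior in standard NumPy (not MindSpore)
parts = onp.split(onp.arange(9), 3)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# 10 elements cannot be divided into 3 equal sections
uneven_error = None
try:
    onp.split(onp.arange(10), 3)
except ValueError as err:
    uneven_error = err
print(type(uneven_error).__name__)  # ValueError

# A list of indices gives cut points: [0:3], [3:7], [7:]
cuts = onp.split(onp.arange(10), [3, 7])
print([c.tolist() for c in cuts])  # [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]
```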
Array Concatenation
Example of concatenating two arrays along a specified axis:
input_x = np.arange(0, 5)
input_y = np.arange(10, 15)
output = np.concatenate((input_x, input_y), axis=0)
print(output)
The output is as follows:
[ 0 1 2 3 4 10 11 12 13 14]
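With 1-D inputs there is only axis 0, so the arrays are simply joined end to end. For 2-D inputs the axis decides which dimension grows: along axis 0 the row counts add up, along axis 1 the column counts add up, and all other dimensions must match. A standard-NumPy illustration (assumption: `numpy` installed, imported as `onp`):

```python
import numpy as onp

a = onp.arange(4).reshape(2, 2)     # [[0, 1], [2, 3]]
b = onp.arange(4, 8).reshape(2, 2)  # [[4, 5], [6, 7]]

rows = onp.concatenate((a, b), axis=0)  # stacks b below a
cols = onp.concatenate((a, b), axis=1)  # places b to the right of a
print(rows.shape)      # (4, 2)
print(cols.shape)      # (2, 4)
print(cols.tolist())   # [[0, 1, 4, 5], [2, 3, 6, 7]]
```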
API Name | Description | Supported Platforms
Appends values to the end of a tensor.
Applies a function to 1-D slices along the given axis.
Applies a function repeatedly over multiple axes.
Splits a tensor into multiple sub-tensors.
Returns a string representation of the data in an array.
Converts inputs to arrays with at least one dimension.
Reshapes inputs as arrays with at least two dimensions.
Reshapes inputs as arrays with at least three dimensions.
Broadcasts any number of arrays against each other.
Broadcasts an array to a new shape.
Constructs an array from an index array and a list of arrays to choose from.
Stacks 1-D tensors as columns into a 2-D tensor.
Joins a sequence of tensors along an existing axis.
Splits a tensor into multiple sub-tensors along the 3rd axis (depth).
Stacks tensors in sequence depth-wise (along the third axis).
Expands the shape of a tensor.
Reverses the order of elements in an array along the given axis.
Flips the entries in each row in the left/right direction.
Flips the entries in each column in the up/down direction.
Splits a tensor into multiple sub-tensors horizontally (column-wise).
Stacks tensors in sequence horizontally.
Moves axes of an array to new positions.
Evaluates a piecewise-defined function.
Returns a contiguous flattened tensor.
Repeats elements of an array.
Reshapes a tensor without changing its data.
Rolls a tensor along given axes.
Rolls the specified axis backwards until it lies in the given position.
Rotates a tensor by 90 degrees in the plane specified by axes.
Returns an array drawn from elements in choicelist, depending on conditions.
Returns the number of elements along a given axis.
Splits a tensor into multiple sub-tensors along the given axis.
Removes single-dimensional entries from the shape of a tensor.
Joins a sequence of arrays along a new axis.
Interchanges two axes of a tensor.
Takes elements from an array along an axis.
Takes values from the input array by matching 1-D index and data slices.
Constructs an array by repeating the input a the number of times given by reps.
Reverses or permutes the axes of a tensor; returns the modified tensor.
Finds the unique elements of a tensor.
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Splits a tensor into multiple sub-tensors vertically (row-wise).
Stacks tensors in sequence vertically.
Returns elements chosen from x or y depending on condition.
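Logical Operations
Logical operators perform logic-related computations such as comparisons and boolean combinations.
Example of the equal and less computations: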
input_x = np.arange(0, 5)
input_y = np.arange(0, 10, 2)
output = np.equal(input_x, input_y)
print("output of equal:", output)
output = np.less(input_x, input_y)
print("output of less:", output)
The output is as follows:
output of equal: [ True False False False False]
output of less: [False True True True True]
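Comparison functions return element-wise boolean arrays of the broadcast shape, which can be combined with the logical functions or used as masks. The standard-NumPy sketch below (assumption: `numpy` installed, imported as `onp`) reproduces the two results above and shows that equal OR less gives the same result as less_equal.

```python
import numpy as onp

x = onp.arange(0, 5)      # [0, 1, 2, 3, 4]
y = onp.arange(0, 10, 2)  # [0, 2, 4, 6, 8]

eq = onp.equal(x, y)
lt = onp.less(x, y)
print(eq.tolist())  # [True, False, False, False, False]
print(lt.tolist())  # [False, True, True, True, True]

# Boolean results combine element-wise: (x == y) OR (x < y) is x <= y
le = onp.logical_or(eq, lt)
print(onp.array_equal(le, onp.less_equal(x, y)))  # True

# A boolean array can also select elements
print(x[lt].tolist())  # [1, 2, 3, 4]
```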
API Name | Description | Supported Platforms
Returns True if input arrays have the same shape and all elements are equal.
Returns True if input arrays are shape-consistent and all elements are equal.
Returns the truth value of
Returns the truth value of
Returns the truth value of
Tests whether each element of a 1-D array is also present in a second array.
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
Tests element-wise for finiteness (neither infinity nor NaN).
Calculates element in test_elements, broadcasting over element only.
Tests element-wise for positive or negative infinity.
Tests element-wise for NaN and returns the result as a boolean array.
Tests element-wise for negative infinity and returns the result as a bool array.
Tests element-wise for positive infinity and returns the result as a bool array.
Returns True if the type of element is a scalar type.
Returns the truth value of
Returns the truth value of
Computes the truth value of x1 AND x2 element-wise.
Computes the truth value of NOT a element-wise.
Computes the truth value of x1 OR x2 element-wise.
Computes the truth value of x1 XOR x2 element-wise.
Returns (x1 != x2) element-wise.
Returns element-wise True where signbit is set (less than zero).
Tests whether any array element along a given axis evaluates to True.
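Mathematical Operations
Mathematical operators cover arithmetic (addition, subtraction, multiplication, division, and exponentiation) as well as common functions such as exponentials and logarithms.
Mathematical computations support NumPy-like broadcasting.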
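NumPy-style broadcasting compares shapes starting from the trailing dimension: two dimensions are compatible when they are equal or one of them is 1, missing leading dimensions count as 1, and the result takes the larger length. The pure-Python sketch below computes the broadcast result shape under exactly that rule; it is an illustration, not MindSpore's implementation.

```python
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Illustrative sketch of the NumPy broadcasting rule for result shapes."""
    result = []
    # Walk the dimensions right to left; absent leading dimensions act as 1
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = set(dims) - {1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

print(broadcast_shape((3, 2), (2,)))       # (3, 2)
print(broadcast_shape((4, 1, 3), (5, 1)))  # (4, 5, 3)
```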
Addition
The following code adds the two arrays input_x and input_y:
input_x = np.full((3, 2), [1, 2])
input_y = np.full((3, 2), [3, 4])
output = np.add(input_x, input_y)
print(output)
The output is as follows:
[[4 6]
 [4 6]
 [4 6]]
Matrix Multiplication
The following code multiplies the two matrices input_x and input_y:
input_x = np.arange(2*3).reshape(2, 3).astype('float32')
input_y = np.arange(3*4).reshape(3, 4).astype('float32')
output = np.matmul(input_x, input_y)
print(output)
The output is as follows:
[[20. 23. 26. 29.]
 [56. 68. 80. 92.]]
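Each output entry is the dot product of a row of input_x with a column of input_y; for the top-left entry, 0·0 + 1·4 + 2·8 = 20. The pure-Python sketch below writes out that definition (illustrative only, not how MindSpore computes it) and reproduces the result above.

```python
def matmul(a, b):
    """Illustrative matrix product: out[i][j] = sum over k of a[i][k] * b[k][j]."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions must match")
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
            for i in range(len(a))]

a = [[0, 1, 2], [3, 4, 5]]                        # values of np.arange(2*3).reshape(2, 3)
b = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]  # values of np.arange(3*4).reshape(3, 4)
print(matmul(a, b))  # [[20, 23, 26, 29], [56, 68, 80, 92]]
```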
Mean
The following code computes the mean of all elements of input_x:
input_x = np.arange(6).astype('float32')
output = np.mean(input_x)
print(output)
The output is as follows:
2.5
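Without an axis argument, mean averages over every element. Passing axis reduces along a single dimension instead, as the mean entry in the table below notes ("along the specified axis"). The standard-NumPy sketch below (assumption: `numpy` installed, imported as `onp`) reshapes the same six values into a 2 x 3 array to show the difference.

```python
import numpy as onp

x = onp.arange(6).astype('float32').reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

overall = float(onp.mean(x))             # all six elements
per_column = onp.mean(x, axis=0)         # down each column
per_row = onp.mean(x, axis=1)            # across each row
print(overall)                # 2.5
print(per_column.tolist())    # [1.5, 2.5, 3.5]
print(per_row.tolist())       # [1.0, 4.0]
```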
Exponential
The following code raises the natural constant e to the power of each element of input_x:
input_x = np.arange(5).astype('float32')
output = np.exp(input_x)
print(output)
The output is as follows:
[ 1. 2.7182817 7.389056 20.085537 54.59815 ]
API Name | Description | Supported Platforms
Calculates the absolute value element-wise.
Adds arguments element-wise.
Returns the maximum of an array or maximum along an axis.
Returns the minimum of an array or minimum along an axis.
Trigonometric inverse cosine, element-wise.
Inverse hyperbolic cosine, element-wise.
Inverse sine, element-wise.
Inverse hyperbolic sine, element-wise.
Trigonometric inverse tangent, element-wise.
Element-wise arc tangent of x1/x2, choosing the quadrant correctly.
Inverse hyperbolic tangent, element-wise.
Returns the indices of the maximum values along an axis.
Returns the indices of the minimum values along an axis.
Rounds evenly to the given number of decimals.
Computes the weighted average along the specified axis.
Counts the number of occurrences of each value in an array of non-negative ints.
Computes the bit-wise AND of two arrays element-wise.
Computes the bit-wise OR of two arrays element-wise.
Computes the bit-wise XOR of two arrays element-wise.
Returns the cube root of a tensor, element-wise.
Returns the ceiling of the input, element-wise.
Clips (limits) the values in an array.
Returns the discrete, linear convolution of two one-dimensional sequences.
Changes the sign of x1 to that of x2, element-wise.
Returns Pearson product-moment correlation coefficients.
Cross-correlation of two 1-dimensional sequences.
Cosine, element-wise.
Hyperbolic cosine, element-wise.
Counts the number of non-zero values in the tensor x.
Estimates a covariance matrix, given data and weights.
Returns the cross product of two (arrays of) vectors.
Returns the cumulative product of elements along a given axis.
Returns the cumulative sum of the elements along a given axis.
Converts angles from degrees to radians.
Calculates the n-th discrete difference along the given axis.
Returns the indices of the bins to which each value in the input array belongs.
Returns a true division of the inputs, element-wise.
Returns element-wise quotient and remainder simultaneously.
Returns the dot product of two arrays.
Returns the differences between consecutive elements of a tensor.
Calculates the exponential of all elements in the input array.
Calculates
Calculates
Rounds to the nearest integer towards zero.
First array elements raised to powers from second array, element-wise.
Returns the floor of the input, element-wise.
Returns the largest integer smaller than or equal to the division of the inputs.
Returns the element-wise remainder of division.
Returns the greatest common divisor of
Returns the gradient of an N-dimensional array.
Computes the Heaviside step function.
Computes the histogram of a dataset.
Computes the multidimensional histogram of some data.
Computes the multidimensional histogram of some data.
Given the "legs" of a right triangle, returns its hypotenuse.
Returns the inner product of two tensors.
One-dimensional linear interpolation for monotonically increasing sample points.
Computes bit-wise inversion, or bit-wise NOT, element-wise.
Kronecker product of two arrays.
Returns the lowest common multiple of
Returns the natural logarithm, element-wise.
Base-10 logarithm of x.
Returns the natural logarithm of one plus the input array, element-wise.
Base-2 logarithm of x.
Logarithm of the sum of exponentiations of the inputs.
Logarithm of the sum of exponentiations of the inputs in base 2.
Returns the matrix product of two arrays.
Raises a square matrix to the (integer) power n.
Returns the element-wise maximum of array elements.
Computes the arithmetic mean along the specified axis.
Element-wise minimum of tensor elements.
Computes the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
Multiplies arguments element-wise.
Returns the cumulative sum of array elements over a given axis, treating NaNs as zero.
Returns the maximum of an array or maximum along an axis, ignoring any NaNs.
Computes the arithmetic mean along the specified axis, ignoring NaNs.
Returns the minimum of array elements over a given axis, ignoring any NaNs.
Computes the standard deviation along the specified axis, while ignoring NaNs.
Returns the sum of array elements over a given axis, treating NaNs as zero.
Computes the variance along the specified axis, while ignoring NaNs.
Numerical negative, element-wise.
Matrix or vector norm.
Computes the outer product of two vectors.
Finds the sum of two polynomials.
Returns the derivative of the specified order of a polynomial.
Returns an antiderivative (indefinite integral) of a polynomial.
Finds the product of two polynomials.
Difference (subtraction) of two polynomials.
Evaluates a polynomial at specific values.
Numerical positive, element-wise.
First array elements raised to powers from second array, element-wise.
Returns the data type with the smallest size and smallest scalar kind.
Range of values (maximum - minimum) along an axis.
Converts angles from radians to degrees.
Converts angles from degrees to radians.
Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index.
Returns the reciprocal of the argument, element-wise.
Returns element-wise remainder of division.
Returns the type that results from applying the type promotion rules to the arguments.
Rounds elements of the array to the nearest integer.
Finds indices where elements should be inserted to maintain order.
Returns an element-wise indication of the sign of a number.
Trigonometric sine, element-wise.
Hyperbolic sine, element-wise.
Returns the non-negative square root of an array, element-wise.
Returns the element-wise square of the input.
Computes the standard deviation along the specified axis.
Subtracts arguments, element-wise.
Returns the sum of array elements over a given axis.
Computes tangent, element-wise.
Computes hyperbolic tangent, element-wise.
Computes tensor dot product along specified axes.
Integrates along the given axis using the composite trapezoidal rule.
Returns a true division of the inputs, element-wise.
Returns the truncated value of the input, element-wise.
Unwraps by changing deltas between values to
Computes the variance along the specified axis.
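Combining MindSpore NumPy with MindSpore Features
mindspore.numpy can take advantage of MindSpore features such as automatic differentiation of operators and graph-mode acceleration, which helps users build efficient models quickly. MindSpore also supports multiple backend devices, including Ascend, GPU, and CPU, which users can select as needed. Several commonly used methods are listed below:
ms_function: compiles the wrapped code in graph mode to improve execution efficiency.
GradOperation: performs automatic differentiation.
mindspore.set_context: sets the execution mode, the backend device, and other options.
mindspore.nn.Cell: builds deep learning models.
Examples of each follow.
ms_function Usage Example
First, take the matrix multiplication and matrix addition operators commonly used in neural networks as an example: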
import mindspore.numpy as np

x = np.arange(8).reshape(2, 4).astype('float32')
w1 = np.ones((4, 8))
b1 = np.zeros((8,))
w2 = np.ones((8, 16))
b2 = np.zeros((16,))
w3 = np.ones((16, 4))
b3 = np.zeros((4,))

# Three stacked affine layers: x @ w + b
def forward(x, w1, b1, w2, b2, w3, b3):
    x = np.dot(x, w1) + b1
    x = np.dot(x, w2) + b2
    x = np.dot(x, w3) + b3
    return x

print(forward(x, w1, b1, w2, b2, w3, b3))
The output is as follows:
[[ 768. 768. 768. 768.]
 [2816. 2816. 2816. 2816.]]
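The printed values can be checked by hand. Multiplying by an all-ones (4, 8) matrix replaces every entry with its row sum (6 for [0, 1, 2, 3] and 22 for [4, 5, 6, 7]); the next two all-ones layers multiply by 8 and then by 16, so the results are 6 × 128 = 768 and 22 × 128 = 2816. The sketch below repeats the computation in standard NumPy (assumption: `numpy` installed, imported as `onp`); it does not exercise MindSpore.

```python
import numpy as onp

x = onp.arange(8).reshape(2, 4).astype('float32')
w1, b1 = onp.ones((4, 8), onp.float32), onp.zeros((8,), onp.float32)
w2, b2 = onp.ones((8, 16), onp.float32), onp.zeros((16,), onp.float32)
w3, b3 = onp.ones((16, 4), onp.float32), onp.zeros((4,), onp.float32)

def forward(x, w1, b1, w2, b2, w3, b3):
    # Same three affine layers as above, evaluated with standard NumPy
    x = onp.dot(x, w1) + b1
    x = onp.dot(x, w2) + b2
    x = onp.dot(x, w3) + b3
    return x

out = forward(x, w1, b1, w2, b2, w3, b3)
row_sums = x.sum(axis=1)                # [6.0, 22.0]
print(out.tolist())                     # [[768.0, 768.0, 768.0, 768.0], [2816.0, 2816.0, 2816.0, 2816.0]]
print((row_sums * 8 * 16).tolist())     # [768.0, 2816.0]
```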
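For the example above, ms_function can compile all the operators into a single static graph to speed up execution:
from mindspore import ms_function

# Compile forward into a static graph, then call the compiled version
forward_compiled = ms_function(forward)
print(forward_compiled(x, w1, b1, w2, b2, w3, b3))
The output is as follows:
[[ 768. 768. 768. 768.]
 [2816. 2816. 2816. 2816.]]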
Note
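Static graph mode currently cannot run in the Python interactive mode, and it has some syntax restrictions. For more information about ms_function, see the ms_function API documentation.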
GradOperation Usage Example
GradOperation performs automatic differentiation. The following example computes the gradients of the forward function defined above (the version not decorated with ms_function):
from mindspore import ops

# get_all=True returns the gradients with respect to every input of forward
grad_all = ops.composite.GradOperation(get_all=True)
print(grad_all(forward)(x, w1, b1, w2, b2, w3, b3))
The output is as follows:
(Tensor(shape=[2, 4], dtype=Float32, value=
[[ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
 [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]]),
 Tensor(shape=[4, 8], dtype=Float32, value=
[[ 2.56000000e+02, 2.56000000e+02, 2.56000000e+02 ... 2.56000000e+02, 2.56000000e+02, 2.56000000e+02],
 [ 3.84000000e+02, 3.84000000e+02, 3.84000000e+02 ... 3.84000000e+02, 3.84000000e+02, 3.84000000e+02],
 [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02 ... 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]
 [ 6.40000000e+02, 6.40000000e+02, 6.40000000e+02 ... 6.40000000e+02, 6.40000000e+02, 6.40000000e+02]]),
 ...
 Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
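The gradient values above can be derived by hand. Since forward returns a (2, 4) tensor and no sensitivity is passed, the output gradient is seeded with ones, so each result is the gradient of the sum of all outputs; the 2.0 entries for b3 (one per batch row) are consistent with that. The sketch below writes out the backward pass of the three affine layers in standard NumPy (assumption: `numpy` installed, imported as `onp`). It is a hand derivation for checking the numbers, not how MindSpore computes gradients.

```python
import numpy as onp

# Same inputs as the forward example above
x = onp.arange(8).reshape(2, 4).astype('float32')
w1, w2, w3 = onp.ones((4, 8)), onp.ones((8, 16)), onp.ones((16, 4))
b1, b2 = onp.zeros(8), onp.zeros(16)

# Forward pass, keeping the activations the backward pass needs
h1 = x @ w1 + b1   # (2, 8)
h2 = h1 @ w2 + b2  # (2, 16)

# Backward pass with an all-ones upstream gradient (gradient of out.sum())
g_out = onp.ones((2, 4))
g_w3, g_b3 = h2.T @ g_out, g_out.sum(axis=0)
g_h2 = g_out @ w3.T                        # every entry 4
g_w2, g_b2 = h1.T @ g_h2, g_h2.sum(axis=0)
g_h1 = g_h2 @ w2.T                         # every entry 64
g_w1, g_b1 = x.T @ g_h1, g_h1.sum(axis=0)
g_x = g_h1 @ w1.T                          # every entry 512

print(g_x[0].tolist())      # [512.0, 512.0, 512.0, 512.0]
print(g_w1[:, 0].tolist())  # [256.0, 384.0, 512.0, 640.0]
print(g_b3.tolist())        # [2.0, 2.0, 2.0, 2.0]
```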
To compute gradients of forward decorated with ms_function, first use set_context to set the execution mode to graph mode:
from mindspore import ops, ms_function, set_context, GRAPH_MODE

# Gradients through an ms_function-compiled function require graph mode
set_context(mode=GRAPH_MODE)
grad_all = ops.composite.GradOperation(get_all=True)
print(grad_all(ms_function(forward))(x, w1, b1, w2, b2, w3, b3))
The output is as follows:
(Tensor(shape=[2, 4], dtype=Float32, value=
[[ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
 [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]]),
 Tensor(shape=[4, 8], dtype=Float32, value=
[[ 2.56000000e+02, 2.56000000e+02, 2.56000000e+02 ... 2.56000000e+02, 2.56000000e+02, 2.56000000e+02],
 [ 3.84000000e+02, 3.84000000e+02, 3.84000000e+02 ... 3.84000000e+02, 3.84000000e+02, 3.84000000e+02],
 [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02 ... 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]
 [ 6.40000000e+02, 6.40000000e+02, 6.40000000e+02 ... 6.40000000e+02, 6.40000000e+02, 6.40000000e+02]]),
 ...
 Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
For more details, see the GradOperation API documentation.
mindspore.set_context Usage Example
MindSpore supports multiple computation backends, which can be configured with mindspore.set_context. Most mindspore.numpy operators can run in either graph mode or PyNative mode, and on various backend devices such as CPU, GPU, or Ascend.
from mindspore import set_context, GRAPH_MODE, PYNATIVE_MODE

# Execution in static graph mode
set_context(mode=GRAPH_MODE)
# Execution in PyNative mode
set_context(mode=PYNATIVE_MODE)
# Execution on CPU backend
set_context(device_target="CPU")
# Execution on GPU backend
set_context(device_target="GPU")
# Execution on Ascend backend
set_context(device_target="Ascend")
...
For more details, see the mindspore.set_context API documentation.
mindspore.numpy Usage Example
This section gives an example of building a network model with mindspore.numpy.
mindspore.numpy interfaces can be used inside an nn.Cell to build a network:
import mindspore.numpy as np
from mindspore import set_context, GRAPH_MODE
from mindspore.nn import Cell

set_context(mode=GRAPH_MODE)

x = np.arange(8).reshape(2, 4).astype('float32')
w1 = np.ones((4, 8))
b1 = np.zeros((8,))
w2 = np.ones((8, 16))
b2 = np.zeros((16,))
w3 = np.ones((16, 4))
b3 = np.zeros((4,))

class NeuralNetwork(Cell):
    # construct defines the forward computation of the Cell
    def construct(self, x, w1, b1, w2, b2, w3, b3):
        x = np.dot(x, w1) + b1
        x = np.dot(x, w2) + b2
        x = np.dot(x, w3) + b3
        return x

net = NeuralNetwork()
print(net(x, w1, b1, w2, b2, w3, b3))
The output is as follows:
[[ 768. 768. 768. 768.]
 [2816. 2816. 2816. 2816.]]