mindspore.ops.derivative

查看源文件
mindspore.ops.derivative(fn, primals, order)[源代码]

计算函数或网络输出对输入的高阶微分。给定待求导函数的原始输入和求导的阶数n,将返回函数输出对输入的第n阶导数。输入的初始1阶导数在内部默认设置为1,其他阶设置为0。

说明

  • primals 是int型的Tensor,会被转化成float32格式进行计算。

参数:
  • fn (Union[Cell, function]) - 待求导的函数或网络。

  • primals (Union[Tensor, tuple[Tensor]]) - fn 的输入,单输入的type为Tensor,多输入的type为Tensor组成的tuple。

  • order (int) - 求导的阶数。

返回:

tuple,由 out_primalsout_series 组成。

  • out_primals (Union[Tensor, list[Tensor]]) - fn(primals) 的结果。

  • out_series (Union[Tensor, list[Tensor]]) - fn 输出对输入的第n阶导数。

异常:
  • TypeError - primals 不是Tensor或tuple。

  • TypeError - order 不是int。

  • ValueError - order 不是正数。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = ops.Sin()
...         self.exp = ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> order = 3
>>> net = Net()
>>> out_primals, out_series = ops.derivative(net, primals, order)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[-4.0515366   3.6724353 ]
 [ 0.5053504  -0.52061415]]