mindspore.ops.derivative

查看源文件
mindspore.ops.derivative(fn, primals, order)[源代码]

计算函数或网络输出对输入的高阶微分。给定待求导函数的原始输入和求导的阶数n,将返回函数输出对输入的第n阶导数。输入的初始1阶导数在内部默认设置为1,其他阶设置为0。

说明

  • primals 是int型的tensor,会被转化成float32格式进行计算。

参数:
  • fn (Union[Cell, function]) - 待求导的函数或网络。

  • primals (Union[Tensor, tuple[Tensor]]) - fn 的输入。

  • order (int) - 求导的阶数。

返回:

tuple(out_primals, out_series)

  • out_primals (Union[Tensor, list[Tensor]]) - fn(primals) 的结果。

  • out_series (Union[Tensor, list[Tensor]]) - fn 输出对输入的第n阶导数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import nn
>>> mindspore.set_context(mode=mindspore.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = mindspore.ops.Sin()
...         self.exp = mindspore.ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>>
>>> primals = mindspore.tensor([[1, 2], [3, 4]], mindspore.float32)
>>> order = 3
>>> net = Net()
>>> out_primals, out_series = mindspore.ops.derivative(net, primals, order)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[-4.0515366   3.6724353 ]
 [ 0.5053504  -0.52061415]]