mindspore.ops.jet

mindspore.ops.jet(fn, primals, series)

Computes higher-order derivatives of a given composite function. To obtain the first- through n-th-order derivatives of the output, the original inputs must be supplied together with the first- through n-th-order derivatives of those inputs. It is generally recommended to set the first-order input derivative to 1 and all higher-order input derivatives to 0, which corresponds to differentiating the original input with respect to itself.

Note

If primals is a Tensor of an integer type, it is converted to a Tensor of a float type.

Parameters
  • fn (Union[Cell, function]) – The function to which the Taylor operation is applied.

  • primals (Union[Tensor, tuple[Tensor]]) – The inputs to fn.

  • series (Union[Tensor, tuple[Tensor]]) – The derivatives of the inputs. If a tuple, its length and element types should match those of primals. For each Tensor, the first dimension indexes the derivative order: index i holds the (i+1)-th order derivative of the input, and the 1st through (i+1)-th order derivatives of the output with respect to the inputs are computed.
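
The recommended setup (first-order input derivative set to 1, higher orders set to 0) can be sketched with NumPy as follows. This is only an illustration: `make_unit_series` is a hypothetical helper and not part of MindSpore, and the resulting array would still need to be wrapped in a `Tensor` before being passed as series.

```python
import numpy as np

def make_unit_series(primals, order):
    """Hypothetical helper (not a MindSpore API).

    Builds an array of shape (order, *primals.shape) whose first slice
    (the first-order input derivative) is all ones and whose remaining
    slices (higher-order input derivatives) are all zeros.
    """
    primals = np.asarray(primals, dtype=np.float32)
    series = np.zeros((order,) + primals.shape, dtype=np.float32)
    series[0] = 1.0
    return series

primals = np.array([[1, 2], [3, 4]], dtype=np.float32)
series = make_unit_series(primals, order=3)
print(series.shape)  # (3, 2, 2)
```

With `order=3` this yields the same values as the series array written out by hand in the Examples section.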

Returns

A tuple of (out_primals, out_series).

  • out_primals (Union[Tensor, list[Tensor]]) - The output of fn(primals).

  • out_series (Union[Tensor, list[Tensor]]) - The first- through n-th-order derivatives of the output with respect to the inputs, where n is the length of the first dimension of series.

Raises
  • TypeError – If primals is not a tensor or tuple of tensors.

  • TypeError – If the type of primals is not the same as the type of series.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = ops.Sin()
...         self.exp = ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
>>> net = Net()
>>> out_primals, out_series = ops.jet(net, primals, series)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[[ 1.2533808  -1.0331168 ]
  [-1.1400385  -0.3066662 ]]
 [[-1.2748207  -1.8274734 ]
  [ 0.966121    0.55551505]]
 [[-4.0515366   3.6724353 ]
  [ 0.5053504  -0.52061415]]]
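
As a sanity check, the printed values can be compared against closed-form derivatives of f(x) = exp(sin(x)), which is the function that Net computes. The sketch below uses only the Python standard library; `exp_sin_derivatives` is a hypothetical name introduced for illustration, and the formulas were derived by hand for this specific function rather than taken from MindSpore.

```python
import math

def exp_sin_derivatives(x):
    """Return f(x) and its first three derivatives for f(x) = exp(sin(x)).

    Hand-derived formulas (illustrative, not a MindSpore API):
      f'   = cos(x) * f
      f''  = (cos(x)**2 - sin(x)) * f
      f''' = cos(x) * (cos(x)**2 - 3*sin(x) - 1) * f
    """
    s, c = math.sin(x), math.cos(x)
    f = math.exp(s)
    d1 = c * f
    d2 = (c * c - s) * f
    d3 = c * (c * c - 3.0 * s - 1.0) * f
    return f, d1, d2, d3

# Evaluate at the same points as the primals tensor in the example.
for x in (1.0, 2.0, 3.0, 4.0):
    f, d1, d2, d3 = exp_sin_derivatives(x)
    print(f"x={x}: f={f:.6f}, f'={d1:.6f}, f''={d2:.6f}, f'''={d3:.6f}")
```

For x = 1 the four values correspond to the first entry of out_primals and the first entry of each of the three slices of out_series, which suggests that, with the unit series used here, out_series holds the plain first- through third-order derivatives.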