mindspore.ops.jet
- mindspore.ops.jet(fn, primals, series)
This function calculates the higher-order derivatives of a given composite function. To compute the first- to n-th-order derivatives, the original inputs (primals) must be provided together with the first- to n-th-order derivatives of those inputs (series). In general, it is recommended to set the first-order derivative in series to 1 and all higher-order derivatives to 0. This corresponds to differentiating each input with respect to itself.
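The mechanism can be illustrated with a minimal, scalar, pure-Python sketch of Taylor-mode propagation. This is not MindSpore's implementation, and the helper names (`to_taylor`, `to_derivs`, `taylor_exp`, `taylor_sin_cos`, `jet_exp_sin`) are hypothetical. It pushes a truncated Taylor series of an input x(t) through sin and then exp using the standard coefficient recurrences, then rescales the coefficients by k! to report derivatives. The sketch treats the entries of series as derivative values, which is consistent with the example output on this page. How MindSpore normalizes series internally is an assumption here.

```python
import math
from typing import List, Tuple


def to_taylor(derivs: List[float]) -> List[float]:
    """Convert derivatives [f, f', f'', ...] to Taylor coefficients f^(k)/k!."""
    return [d / math.factorial(k) for k, d in enumerate(derivs)]


def to_derivs(coeffs: List[float]) -> List[float]:
    """Convert Taylor coefficients f^(k)/k! back to derivatives f^(k)."""
    return [c * math.factorial(k) for k, c in enumerate(coeffs)]


def taylor_exp(x: List[float]) -> List[float]:
    """Taylor coefficients of exp(x(t)), derived from y' = y * x'."""
    y = [math.exp(x[0])]
    for k in range(1, len(x)):
        y.append(sum(j * x[j] * y[k - j] for j in range(1, k + 1)) / k)
    return y


def taylor_sin_cos(x: List[float]) -> Tuple[List[float], List[float]]:
    """Taylor coefficients of sin(x(t)) and cos(x(t)), from s' = c*x', c' = -s*x'."""
    s = [math.sin(x[0])]
    c = [math.cos(x[0])]
    for k in range(1, len(x)):
        s.append(sum(j * x[j] * c[k - j] for j in range(1, k + 1)) / k)
        c.append(-sum(j * x[j] * s[k - j] for j in range(1, k + 1)) / k)
    return s, c


def jet_exp_sin(primal: float, series: List[float]) -> Tuple[float, List[float]]:
    """Hypothetical scalar analogue of jet for the composite exp(sin(x))."""
    x = to_taylor([primal] + list(series))
    s, _ = taylor_sin_cos(x)
    y = to_derivs(taylor_exp(s))
    return y[0], y[1:]


# First-order derivative 1, higher orders 0: differentiate x with respect to itself.
out_primal, out_series = jet_exp_sin(1.0, [1.0, 0.0, 0.0])
print(out_primal, out_series)  # approximately 2.319777 [1.2533808, -1.2748207, -4.0515366]
```

Working with normalized coefficients f^(k)/k! keeps each recurrence a simple convolution. Multiplying back by k! recovers derivative values, and for x = 1 these match the first element of each order in the example output.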
Note
If primals is a Tensor of int type, it is converted to a Tensor of float type.
- Parameters
fn (Union[Cell, function]) – Function to do TaylorOperation.
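primals (Union[Tensor, tuple[Tensor]]) – The inputs to fn.
series (Union[Tensor, tuple[Tensor]]) – If tuple, its length and the type of each element must match those of primals. For each Tensor, the first dimension indexes the derivative order: entry i (counting from 0) holds the (i+1)-th-order derivative of the corresponding input. A first dimension of length n therefore yields the 1st- to n-th-order derivatives of the output with respect to the inputs.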
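The layout of series described above can be sketched with NumPy. `make_series` is a hypothetical helper, not part of MindSpore. It builds an array of shape (order, *primals.shape) and applies the recommended values (1 for the first order, 0 for all higher orders), which reproduces the series used in the example.

```python
import numpy as np


def make_series(primals: np.ndarray, order: int) -> np.ndarray:
    """Hypothetical helper: build a series array covering `order` derivative orders.

    Entry k along the first axis holds the (k+1)-th-order derivative of the input.
    Only the first order is set to 1, so each input is differentiated with respect
    to itself; all higher orders are 0.
    """
    series = np.zeros((order,) + primals.shape, dtype=np.float32)
    series[0] = 1.0
    return series


primals = np.array([[1, 2], [3, 4]], dtype=np.float32)
series = make_series(primals, order=3)
print(series.shape)  # (3, 2, 2)
```

In MindSpore, the resulting array would be wrapped with `mindspore.Tensor` before being passed to jet, as the example does.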
- Returns
Tuple, tuple of out_primals and out_series.
out_primals (Union[Tensor, list[Tensor]]) - The output of fn(primals).
out_series (Union[Tensor, list[Tensor]]) - The 1st- to n-th-order derivatives of the output with respect to the inputs, where n is the length of the first dimension of series.
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> from mindspore.ops import jet
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = ops.Sin()
...         self.exp = ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
>>> net = Net()
>>> out_primals, out_series = jet(net, primals, series)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[[ 1.2533808  -1.0331168 ]
  [-1.1400385  -0.3066662 ]]
 [[-1.2748207  -1.8274734 ]
  [ 0.966121    0.55551505]]
 [[-4.0515366   3.6724353 ]
  [ 0.5053504  -0.52061415]]]
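The example output can be checked independently. For f(x) = exp(sin(x)), the first three derivatives have the closed forms f' = f cos x, f'' = f (cos²x - sin x) and f''' = f cos x (cos²x - 3 sin x - 1). The NumPy sketch below evaluates them at the example primals and compares the results with the values copied from the printed output. It uses NumPy only, not MindSpore.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float64)
s, c = np.sin(x), np.cos(x)
f = np.exp(s)                      # exp(sin x), corresponds to out_primals
d1 = f * c                         # first derivative
d2 = f * (c**2 - s)                # second derivative
d3 = f * c * (c**2 - 3 * s - 1)    # third derivative

# Values copied from the printed example output.
printed_primals = np.array([[2.319777, 2.4825778], [1.1515628, 0.4691642]])
printed_series = np.array([
    [[1.2533808, -1.0331168], [-1.1400385, -0.3066662]],
    [[-1.2748207, -1.8274734], [0.966121, 0.55551505]],
    [[-4.0515366, 3.6724353], [0.5053504, -0.52061415]],
])

print(np.allclose(f, printed_primals, atol=1e-4))                           # True
print(np.allclose(np.stack([d1, d2, d3]), printed_series, atol=1e-4))      # True
```

The agreement shows that, with series set to 1 for the first order and 0 for the rest, each slice out_series[k] holds the (k+1)-th derivative of fn with respect to its input.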