mindspore.ops.GradOperation

class mindspore.ops.GradOperation(get_all=False, get_by_list=False, sens_param=False)

A higher-order function which is used to generate the gradient function for the input function.

The gradient function generated by the GradOperation higher-order function can be customized through its constructor arguments.

For example, consider an input function net = Net() that takes x and y as inputs and has a parameter z (see Net in Examples).

  • Used to get the derivative of the input:

    1. Returns gradients with respect to the first input (see GradNetWrtX in Examples).

      1. Construct a GradOperation higher-order function with default arguments: grad_op = GradOperation().

      2. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

      3. Call the gradient function with the input function’s inputs to get the gradients with respect to the first input: gradient_function(x, y).

    2. Returns gradients with respect to all inputs (see GradNetWrtXY in Examples).

      1. Construct a GradOperation higher-order function with get_all=True, which requests the gradients with respect to all inputs (x and y for the example function Net()): grad_op = GradOperation(get_all=True).

      2. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

      3. Call the gradient function with input function’s inputs to get the gradients with respect to all inputs: gradient_function(x, y).

  • Used to get the derivative of the parameters:

    Returns gradients with respect to given parameters (see GradNetWithWrtParams in Examples).

    1. Construct a GradOperation higher-order function with get_by_list=True: grad_op = GradOperation(get_by_list=True).

    2. Construct a ParameterTuple that will be passed along with the input function to the GradOperation higher-order function. It acts as a parameter filter that determines which gradients to return: params = ParameterTuple(net.trainable_params()).

    3. Call it with input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

    4. Call the gradient function with input function’s inputs to get the gradients with respect to given parameters: gradient_function(x, y).

  • Used to get the derivative of the inputs and parameters at the same time:

    Returns gradients with respect to all inputs and given parameters in the format of ((dx, dy), (dz)) (see GradNetWrtInputsAndParams in Examples).

    1. Construct a GradOperation higher-order function with get_all=True and get_by_list=True: grad_op = GradOperation(get_all=True, get_by_list=True).

    2. Construct a ParameterTuple that will be passed along with the input function to the GradOperation higher-order function: params = ParameterTuple(net.trainable_params()).

    3. Call it with input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

    4. Call the gradient function with input function’s inputs to get the gradients with respect to all inputs and given parameters: gradient_function(x, y).

  • The sensitivity (gradient with respect to output) can be configured by setting sens_param to True and passing an extra sensitivity input to the gradient function. The sensitivity input should have the same shape and type as the input function’s output (see GradNetWrtXYWithSensParam in Examples).

    1. Construct a GradOperation higher-order function with get_all=True and sens_param=True: grad_op = GradOperation(get_all=True, sens_param=True).

    2. Define grad_wrt_output as the sensitivity, which works as the gradient with respect to the output (shape [2, 3] for the x and y used in Examples): grad_wrt_output = Tensor(np.ones([2, 3]).astype(np.float32)).

    3. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

    4. Call the gradient function with input function’s inputs and sens_param to get the gradients with respect to all inputs: gradient_function(x, y, grad_wrt_output).
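The effect of the sensitivity input can be checked by hand for the example Net, which computes out = matmul(x * z, y). The following is a NumPy-only sketch, not MindSpore code: it assumes the usual chain rule for a matrix product, and the names dx and dy are illustrative. The inputs are those of GradNetWrtXYWithSensParam in Examples, and the results agree with the output printed there.

```python
import numpy as np

# Inputs and sensitivity taken from GradNetWrtXYWithSensParam in Examples.
x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=np.float32)
y = np.array([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=np.float32)
z = np.float32(1.0)  # value of the parameter z in Net
grad_wrt_output = np.array([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=np.float32)

# Net computes out = (x * z) @ y; the sensitivity must have the shape of out.
out = (x * z) @ y
assert grad_wrt_output.shape == out.shape

# Chain rule for a matrix product, weighted by the sensitivity.
dx = (grad_wrt_output @ y.T) * z   # gradient with respect to x, shape (2, 3)
dy = (x * z).T @ grad_wrt_output   # gradient with respect to y, shape (3, 3)

print(dx)
print(dy)
```

Replacing grad_wrt_output with an array of ones gives the gradients that are computed when sens_param is False.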

Note

For the gradient functions above, the form of the returned result depends on the number of gradients:

  • A single value is returned if there is only one result.

  • A tuple is returned if there are multiple results.

  • An empty tuple is returned if there is no result.

Parameters
  • get_all (bool) – If True, get all the gradients with respect to inputs. Default: False.

  • get_by_list (bool) – If True, get all the gradients with respect to Parameter free variables. If get_all and get_by_list are both False, get the gradient with respect to the first input. If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter free variables at the same time in the form of (“gradients with respect to inputs”, “gradients with respect to parameter free variables”). Default: False.

  • sens_param (bool) – Whether to append the sensitivity (gradient with respect to output) as an input. If sens_param is False, a ‘ones_like(outputs)’ sensitivity is attached automatically. If sens_param is True, the sensitivity must be passed to the gradient function as a positional argument or as a keyword argument; when it is passed as a keyword argument, the key must be sens. Default: False.
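As an illustration of the default ‘ones_like(outputs)’ sensitivity, the sketch below uses NumPy only (it does not call MindSpore) and assumes that an all-ones sensitivity is equivalent to differentiating the sum of the outputs. The function net and the names dz and dz_numeric are hypothetical helpers that mirror the example Net. The inputs are those of GradNetWithWrtParams in Examples, and the result agrees with the gradient for z printed there.

```python
import numpy as np

# Inputs taken from GradNetWithWrtParams in Examples.
x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
y = np.array([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]])


def net(x, y, z):
    """NumPy stand-in for the example Net: out = (x * z) @ y."""
    return (x * z) @ y


# Default sensitivity: ones with the same shape as the output.
sens = np.ones_like(net(x, y, 1.0))

# Analytic gradient with respect to z: d(out)/dz = x @ y, weighted by sens.
dz = np.sum(sens * (x @ y))

# Numerical check: central difference of sum(out) with respect to z.
eps = 1e-6
dz_numeric = (net(x, y, 1.0 + eps).sum() - net(x, y, 1.0 - eps).sum()) / (2 * eps)

print(dz, dz_numeric)
```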

Returns

The higher-order function, which takes a function as its argument and returns the gradient function for it.

Raises

TypeError – If get_all, get_by_list or sens_param is not a bool.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor, ops, nn, Parameter, ParameterTuple
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.matmul = ops.MatMul()
...         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
...     def construct(self, x, y):
...         x = x * self.z
...         out = self.matmul(x, y)
...         return out
...
>>> class GradNetWrtX(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtX, self).__init__()
...         self.net = net
...         self.grad_op = ops.GradOperation()
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y)
...
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> output = GradNetWrtX(Net())(x, y)
>>> print(output)
[[1.4100001 1.5999999 6.6      ]
 [1.4100001 1.5999999 6.6      ]]
>>>
>>> class GradNetWrtXY(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtXY, self).__init__()
...         self.net = net
...         self.grad_op = ops.GradOperation(get_all=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y)
...
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.1, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtXY(Net())(x, y)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 4.50000000e+00,  2.70000005e+00,  3.60000014e+00],
 [ 4.50000000e+00,  2.70000005e+00,  3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 2.59999990e+00,  2.59999990e+00,  2.59999990e+00],
 [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
 [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]]))
>>>
>>> class GradNetWrtXYWithSensParam(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtXYWithSensParam, self).__init__()
...         self.net = net
...         self.grad_op = ops.GradOperation(get_all=True, sens_param=True)
...         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y, self.grad_wrt_output)
...
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtXYWithSensParam(Net())(x, y)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 2.21099997e+00,  5.09999990e-01,  1.49000001e+00],
 [ 5.58800030e+00,  2.68000007e+00,  4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.51999998e+00,  2.81999993e+00,  2.14000010e+00],
 [ 1.09999990e+00,  2.04999995e+00,  1.54999995e+00],
 [ 9.00000036e-01,  1.54999995e+00,  1.25000000e+00]]))
>>>
>>> class GradNetWithWrtParams(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWithWrtParams, self).__init__()
...         self.net = net
...         self.params = ParameterTuple(net.trainable_params())
...         self.grad_op = ops.GradOperation(get_by_list=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net, self.params)
...         return gradient_function(x, y)
...
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWithWrtParams(Net())(x, y)
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
>>>
>>> class GradNetWrtInputsAndParams(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtInputsAndParams, self).__init__()
...         self.net = net
...         self.params = ParameterTuple(net.trainable_params())
...         self.grad_op = ops.GradOperation(get_all=True, get_by_list=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net, self.params)
...         return gradient_function(x, y)
...
>>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
>>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtInputsAndParams(Net())(x, y)
>>> print(output)
((Tensor(shape=[2, 3], dtype=Float32, value=
[[ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00],
 [ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 6.00000024e-01,  6.00000024e-01,  6.00000024e-01],
 [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
 [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
 [ 1.29020004e+01]),))