Advanced Usage of Custom Operators
Registering the Operator Information
The operator information describes an operator implementation: the data types and formats supported by its inputs and outputs, its attributes, and its target (platform information). The backend uses it to select and map operators. You define operator information with the CustomRegOp API, then bind it to the operator implementation either with the custom_info_register decorator or by passing it to the reg_info parameter of the Custom primitive. The operator information is eventually registered to the operator information library on the MindSpore C++ side. If both are used, the reg_info parameter takes priority over the custom_info_register decorator.
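The precedence rule can be pictured with a toy model. The sketch below does not use MindSpore: `toy_info_register` and `resolve_reg_info` are hypothetical stand-ins for `custom_info_register` and the `reg_info` parameter, and operator information is assumed to be a plain dict. It only illustrates that an explicitly passed `reg_info` wins over information attached by a decorator.

```python
# Toy model (hypothetical, not MindSpore code): a decorator attaches operator
# information to a function, and an explicit reg_info argument, when given,
# overrides whatever the decorator attached.

def toy_info_register(*infos):
    """Hypothetical stand-in for custom_info_register: store infos on the function."""
    def decorator(func):
        func.toy_reg_info = list(infos)
        return func
    return decorator


def resolve_reg_info(func, reg_info=None):
    """Return the operator information that would be bound to func.

    An explicit reg_info (a single dict or a list of dicts) wins over
    information attached by the decorator; None means nothing was provided.
    """
    if reg_info is not None:
        return reg_info if isinstance(reg_info, list) else [reg_info]
    return getattr(func, "toy_reg_info", None)


@toy_info_register({"target": "CPU", "dtypes": ["float32"]})
def add_impl(a, b):
    return a + b


def plain_impl(a):
    return a


print(resolve_reg_info(add_impl))                                        # decorator info
print(resolve_reg_info(add_impl, {"target": "GPU", "dtypes": ["float16"]}))  # reg_info wins
print(resolve_reg_info(plain_impl))                                      # None
```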
The target value in operator information can be "Ascend", "GPU", or "CPU" and indicates which target the operator information applies to. The same operator implementation may support different data types on different targets, so the target value lets you distinguish between them. The operator information for a given target is registered only once.
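A hypothetical per-target registry makes the "registered only once" rule concrete. `ToyOpInfoLibrary` is an illustration, not the MindSpore C++ operator information library; it assumes operator information is keyed by (operator name, target) and that a repeated registration for the same key is simply ignored.

```python
# Toy model (hypothetical): operator information keyed by (op name, target).
# A second registration for the same key is ignored, mirroring "registered only once".

VALID_TARGETS = {"Ascend", "GPU", "CPU"}


class ToyOpInfoLibrary:
    def __init__(self):
        self._infos = {}

    def register(self, op_name, target, dtypes):
        """Register supported dtypes for op_name on target; return True if stored."""
        if target not in VALID_TARGETS:
            raise ValueError(f"unknown target: {target}")
        key = (op_name, target)
        if key in self._infos:
            return False  # already registered for this target: keep the first entry
        self._infos[key] = tuple(dtypes)
        return True

    def supported_dtypes(self, op_name, target):
        return self._infos.get((op_name, target))


lib = ToyOpInfoLibrary()
lib.register("my_add", "GPU", ["float16", "float32"])
lib.register("my_add", "CPU", ["float32"])
lib.register("my_add", "CPU", ["float64"])  # ignored: CPU info already registered
print(lib.supported_dtypes("my_add", "CPU"))  # ('float32',)
print(lib.supported_dtypes("my_add", "GPU"))  # ('float16', 'float32')
```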
The number and order of the inputs and outputs defined in the operator information must match those of the parameters of the operator implementation.
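One way to catch a mismatch early is to compare the declared names against the implementation's signature. `check_io_order` below is a hypothetical helper, not a MindSpore API; it assumes a Python implementation whose leading parameters are the inputs followed by the outputs, and uses parameter names as a proxy for position.

```python
import inspect

# Hypothetical helper (not a MindSpore API): checks that the input and output
# names declared in operator information match the leading parameters of a
# Python implementation, in the same number and order.

def check_io_order(func, input_names, output_names):
    """Return True if func's first parameters are input_names followed by output_names."""
    declared = list(input_names) + list(output_names)
    params = list(inspect.signature(func).parameters)
    return params[:len(declared)] == declared


# Hypothetical implementation that takes its inputs first, then its output.
def my_add_impl(x, y, output):
    pass


print(check_io_order(my_add_impl, ["x", "y"], ["output"]))  # True
print(check_io_order(my_add_impl, ["y", "x"], ["output"]))  # False: order differs
print(check_io_order(my_add_impl, ["x"], ["output"]))       # False: count differs
```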
For a custom operator of akg type, you need to register operator information if the operator has attributes, and the attribute names in the operator information must match those used in the operator implementation. For a custom operator of tbe type, operator information must be registered. For a custom operator of aot type, the implementation must be compiled into a dynamic library in advance, so operator information cannot be bound with a decorator and can only be passed through the reg_info parameter.
If a custom operator supports only specific input and output data types or data formats, register operator information so that the backend can check the data type and data format when selecting the operator. If no operator information is provided, this information is derived from the inputs of the current operator.
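The difference between having and not having operator information during selection can be sketched as follows. This is a toy model, not MindSpore's selection logic: `select_kernel` is hypothetical, and dtypes and formats are represented as plain strings.

```python
# Toy model (hypothetical, not MindSpore's selection logic): with operator
# information, the input (dtype, format) pair must be a registered combination;
# without it, the pair is simply taken from the current inputs.

def select_kernel(input_dtype, input_format, reg_combos=None):
    """Return the (dtype, format) pair the kernel would use, or raise if unsupported."""
    pair = (input_dtype, input_format)
    if reg_combos is None:
        return pair  # no operator information: derive from the inputs
    if pair not in reg_combos:
        raise TypeError(f"unsupported input {input_dtype}/{input_format}")
    return pair


combos = {("float32", "DefaultFormat")}
print(select_kernel("float16", "DefaultFormat"))          # ('float16', 'DefaultFormat')
print(select_kernel("float32", "DefaultFormat", combos))  # ('float32', 'DefaultFormat')
try:
    select_kernel("float16", "DefaultFormat", combos)
except TypeError as err:
    print(err)  # unsupported input float16/DefaultFormat
```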
Defining the bprop Function for Operators
If an operator needs to support automatic differentiation, first define its backpropagation (bprop) function and then pass it to the bprop parameter of the Custom primitive. The bprop function describes the backward computation that uses the forward inputs, the forward output, and the output gradients to obtain the input gradients. The backward computation can be composed of built-in operators or custom operators.
Note the following points when defining the bprop function for an operator:
The input parameters of the bprop function are, in order, the forward inputs, the forward output, and the output gradients. For a multi-output operator, the forward output and the output gradients are provided as tuples.
The bprop function returns a tuple of input gradients, in the same order as the forward input parameters. Even if there is only one input gradient, the return value must be a tuple.
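The calling convention for a multi-output operator can be checked in plain Python before any MindSpore code is involved. The two-input, two-output operator below is hypothetical; the sketch only shows how the forward output and output gradients arrive as tuples and how the input gradients are returned as a tuple in forward-input order.

```python
# Plain-Python sketch (no MindSpore needed) of the bprop calling convention
# for a hypothetical two-input, two-output operator:
#   forward:  (x, y) -> (x + y, x * y)
# bprop receives the forward inputs, the forward output tuple, and the output
# gradient tuple, and returns a tuple with one gradient per forward input.

def forward(x, y):
    return (x + y, x * y)


def custom_bprop(x, y, out, dout):
    # out is the forward output tuple; it is passed in but not needed here.
    dout_sum, dout_prod = dout       # output gradients arrive as a tuple
    dx = dout_sum + dout_prod * y    # d(x+y)/dx = 1, d(x*y)/dx = y
    dy = dout_sum + dout_prod * x    # d(x+y)/dy = 1, d(x*y)/dy = x
    return (dx, dy)                  # same order as the forward inputs


x, y = 3.0, 5.0
out = forward(x, y)
print(custom_bprop(x, y, out, (1.0, 1.0)))  # (6.0, 4.0)
```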
Take test_grad.py as an example to show the usage of the backpropagation function:
```python
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
import mindspore.ops as ops

ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")

# Forward computation of custom operator
def square(x):
    return x * x

# Backward computation of custom operator
def square_grad(x, dout):
    dx = 2.0 * x * dout
    return dx

# Backpropagation function
def bprop():
    # The output has the same shape and dtype as the first input (x)
    op = ops.Custom(square_grad, lambda x, _: x, lambda x, _: x, func_type="pyfunc")

    def custom_bprop(x, out, dout):
        dx = op(x, dout)
        return (dx,)

    return custom_bprop

class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Define a custom operator of pyfunc type and provide a backpropagation function
        self.op = ops.Custom(square, lambda x: x, lambda x: x, bprop=bprop(), func_type="pyfunc")

    def construct(self, x):
        return self.op(x)

if __name__ == "__main__":
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    dx = ms.grad(Net())(ms.Tensor(x))
    print(dx)
```
The following point about this example needs explanation:
The backpropagation function uses a custom operator of pyfunc type, and the definition and use of that operator must be separated: the custom operator is defined outside the custom_bprop function and used inside it.
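The "define outside, use inside" structure of bprop() is an ordinary closure pattern. The analogy below uses a hypothetical `CountingOp` class in place of ops.Custom and no MindSpore at all; it shows that the operator object is built once in the factory and only called inside the returned function, however many times that function runs.

```python
# Plain-Python analogy (hypothetical class, no MindSpore): the operator object is
# constructed once in the factory and merely called inside the returned closure,
# matching how bprop() builds `op` outside custom_bprop in test_grad.py.

class CountingOp:
    """Hypothetical stand-in for ops.Custom that counts how often it is constructed."""
    instances = 0

    def __init__(self, fn):
        CountingOp.instances += 1
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)


def bprop():
    op = CountingOp(lambda x, dout: 2.0 * x * dout)  # defined once, outside

    def custom_bprop(x, out, dout):
        return (op(x, dout),)  # only used here

    return custom_bprop


grad_fn = bprop()
results = [grad_fn(v, v * v, 1.0) for v in (1.0, 4.0, 9.0)]
print(results)               # [(2.0,), (8.0,), (18.0,)]
print(CountingOp.instances)  # 1
```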
Run the example:

```shell
python test_grad.py
```

The execution result is as follows:

```text
[ 2.  8. 18.]
```
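The printed gradients correspond to an output gradient of all ones, so they should equal 2x for x = [1, 4, 9]. A quick way to validate a hand-written backward function such as square_grad is to compare it with a central finite difference. The sketch below uses only NumPy; `numeric_grad` is a hypothetical helper, and the tolerance is an assumption chosen for these small values.

```python
import numpy as np

# Cross-check square_grad from test_grad.py against a central finite difference.
# The numeric estimate is computed in float64 to keep rounding error small.

def square(x):
    return x * x


def square_grad(x, dout):
    return 2.0 * x * dout


def numeric_grad(f, x, eps=1e-6):
    """Central difference df/dx, element-wise, for an element-wise function f."""
    x = x.astype(np.float64)
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


x = np.array([1.0, 4.0, 9.0], dtype=np.float32)
analytic = square_grad(x, np.ones_like(x))  # dout = ones, as in the example run
numeric = numeric_grad(square, x)
print(analytic)                                   # [ 2.  8. 18.]
print(np.allclose(analytic, numeric, atol=1e-4))  # True
```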
More examples can be found in the MindSpore source code tests/st/graph_kernel/custom.