mindspore.ops.Custom

class mindspore.ops.Custom(func, out_shape=None, out_dtype=None, func_type='hybrid', bprop=None, reg_info=None)

Custom primitive is used for user-defined operators and is intended to extend the expressive ability of built-in primitives. You can construct a Custom object with a predefined function that describes the computation logic of a user-defined operator, and construct further Custom objects from other predefined functions as needed. These Custom objects can then be used directly in neural networks. For a detailed introduction to user-defined operators, including how to write the parameters correctly, refer to the Custom Operators Tutorial.

Warning

  • This is an experimental API that is subject to change.

Note

The supported platforms depend on func_type, as follows:

  • "hybrid": supports ["GPU", "CPU"].

  • "akg": supports ["GPU", "CPU"].

  • "aot": supports ["GPU", "CPU", "ASCEND"].

  • "pyfunc": supports ["CPU"].

  • "julia": supports ["CPU"].
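The platform list above can be encoded as a lookup table for a pre-flight check before constructing a Custom operator. This is a minimal sketch: `SUPPORTED_PLATFORMS` and `check_func_type` are hypothetical helpers, not part of MindSpore, and the table only mirrors the note above, so it may lag behind the installed version.

```python
# Hypothetical helper: supported platforms per func_type, copied from the note above.
SUPPORTED_PLATFORMS = {
    "hybrid": {"GPU", "CPU"},
    "akg": {"GPU", "CPU"},
    "aot": {"GPU", "CPU", "ASCEND"},
    "pyfunc": {"CPU"},
    "julia": {"CPU"},
}


def check_func_type(func_type: str, target: str) -> None:
    """Raise ValueError if func_type is unknown or does not support target."""
    if func_type not in SUPPORTED_PLATFORMS:
        raise ValueError(f"unknown func_type: {func_type!r}")
    # Compare case-insensitively, so "Ascend" and "ASCEND" are treated alike.
    if target.upper() not in SUPPORTED_PLATFORMS[func_type]:
        raise ValueError(f"func_type {func_type!r} does not support target {target!r}")
```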

Parameters
  • func (Union[function, str]) –

    • function: If func is of function type, then func should be a Python function which describes the computation logic of a user defined operator. The function can be one of the following:

      1. An AKG operator implementation function, which can use the IR builder, TVM compute, or hybrid grammar.

      2. A pure Python function.

      3. A function written in the Hybrid DSL and decorated with @kernel.

    • str: If func is of str type, then func should be the path of a file together with a function name. This form is used when func_type is "aot" or "julia".

      1. for "aot":

        a) GPU/CPU platform. "aot" means ahead of time: Custom directly launches a user-defined "xxx.so" file as an operator. Users need to compile a hand-written "xxx.cu"/"xxx.cc" file into "xxx.so" ahead of time, and provide the path of the file along with a function name.

        • "xxx.so" file generation:

          1) GPU platform: given a user-defined "xxx.cu" file (e.g. "{path}/add.cu"), compile it with nvcc (e.g. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu").

          2) CPU platform: given a user-defined "xxx.cc" file (e.g. "{path}/add.cc"), compile it with g++/gcc (e.g. "g++ --shared -fPIC -o add.so add.cc").
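As a sketch, the two compile commands above can be assembled programmatically, for example in a build script. `aot_compile_cmd` is a hypothetical helper that only reproduces the argument lists shown in the examples; actually running the result (e.g. with `subprocess.run(cmd, check=True)`) assumes nvcc or g++ is installed.

```python
def aot_compile_cmd(src: str, out: str) -> list[str]:
    """Return the compiler argv that turns src into the shared library out."""
    if src.endswith(".cu"):
        # GPU: nvcc builds a shared library; -fPIC is forwarded to the host compiler.
        return ["nvcc", "--shared", "-Xcompiler", "-fPIC", "-o", out, src]
    if src.endswith(".cc"):
        # CPU: g++ builds a position-independent shared library.
        return ["g++", "--shared", "-fPIC", "-o", out, src]
    raise ValueError(f"expected a .cu or .cc source file, got {src!r}")
```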

        • Define a "xxx.cc"/"xxx.cu" file:

          The "aot" interface is the same on every platform: functions defined in "xxx.cc" or "xxx.cu" take the same arguments. Typically, the function should look like this:

          int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                  void *stream, void *extra)
          

          Parameters:

          • nparam(int): total number of inputs plus outputs; if the operator has 2 inputs and 3 outputs, then nparam=5.

          • params(void **): an array of pointers to the inputs and outputs, each of type void *; if the operator has 2 inputs and 3 outputs, then the first input's pointer is params[0] and the second output's pointer is params[3].

          • ndims(int *): an array holding the number of dimensions of each input and output; if params[i] is a 1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2 and ndims[j]=3.

          • shapes(int64_t **): an array of the shapes (int64_t *) of the inputs and outputs; the size of dimension j of the i-th input or output is shapes[i][j] (0<=j<ndims[i]); if params[i] is a 2x3 tensor and params[j] is a 3x3x4 tensor, then shapes[i][0]=2 and shapes[j][2]=4.

          • dtypes(const char **): an array of the data types (const char *) of the inputs and outputs (e.g. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "bool").

          • stream(void *): stream pointer, used only in CUDA files.

          • extra(void *): reserved for further extension.

          Return Value(int):

          • 0: the aot kernel executed successfully, and MindSpore continues to run.

          • any other value: MindSpore raises an exception and exits.

          Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
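To make the calling convention concrete, the sketch below emulates it in pure Python with `ctypes`: `_add_kernel` plays the role of a compiled `add.so` kernel with 2 float32 inputs and 1 float32 output, and `launch` packs the arguments the way the documented signature describes. Both names are hypothetical; in real use the kernel is compiled C/CUDA code, and MindSpore builds these arrays itself, possibly differently from this illustration.

```python
import ctypes
from ctypes import POINTER, c_char_p, c_float, c_int, c_int64, c_void_p

# Prototype of the documented aot signature:
# int func(int nparam, void **params, int *ndims, int64_t **shapes,
#          const char **dtypes, void *stream, void *extra)
AOT_KERNEL = ctypes.CFUNCTYPE(
    c_int,                      # return value: 0 means success
    c_int,                      # nparam
    POINTER(c_void_p),          # params
    POINTER(c_int),             # ndims
    POINTER(POINTER(c_int64)),  # shapes
    POINTER(c_char_p),          # dtypes
    c_void_p,                   # stream (unused on CPU)
    c_void_p,                   # extra
)


def _add_kernel(nparam, params, ndims, shapes, dtypes, stream, extra):
    """Hypothetical kernel: z = x + y element-wise, 2 inputs + 1 output."""
    if nparam != 3:
        return 1  # non-zero return reports failure
    if any(dtypes[i] != b"float32" for i in range(nparam)):
        return 1
    # Element count is the product of the output's dimensions (index 2).
    size = 1
    for d in range(ndims[2]):
        size *= shapes[2][d]
    x = ctypes.cast(params[0], POINTER(c_float))
    y = ctypes.cast(params[1], POINTER(c_float))
    z = ctypes.cast(params[2], POINTER(c_float))
    for k in range(size):
        z[k] = x[k] + y[k]
    return 0


add_kernel = AOT_KERNEL(_add_kernel)


def launch(kernel, a, b, dtype=b"float32"):
    """Pack two 1-D float lists as params/ndims/shapes/dtypes and call kernel."""
    n = len(a)
    bufs = [(c_float * n)(*a), (c_float * n)(*b), (c_float * n)()]
    params = (c_void_p * 3)(*[ctypes.addressof(buf) for buf in bufs])
    ndims = (c_int * 3)(1, 1, 1)
    shape = (c_int64 * 1)(n)  # all three tensors are 1-D with n elements
    shapes = (POINTER(c_int64) * 3)(*[ctypes.cast(shape, POINTER(c_int64))] * 3)
    dtypes = (c_char_p * 3)(dtype, dtype, dtype)
    ret = kernel(3, params, ndims, shapes, dtypes, None, None)
    return ret, list(bufs[2])
```

Note how the output sits after the inputs in params, ndims, shapes and dtypes, matching the index rules in the parameter list.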

        • Use it in Custom:

          Custom(func="{dir_path}/{file_name}:{func_name}",...)
          (e.g. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
          func_type="aot"))
          

        b) Ascend platform. Before using Custom operators on the Ascend platform, users must first develop custom operators based on Ascend C and compile them. For operator development, refer to the Quick Start for End-to-End Operator Development tutorial; for compiling custom operators, use the Offline Compilation of Ascend C Custom Operators tool (https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom_ascendc.html). Taking AddCustom as the operator name given in the custom operator implementation, the name can be passed to func in several ways:

        • Using TBE: func="AddCustom"

        • Using AclNN: func="aclnnAddCustom"

        • Using C++ shape inference: func="infer_shape.cc:aclnnAddCustom", where infer_shape.cc implements the operator's shape inference in C++.
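The three naming forms above differ only in the string passed as func. The hypothetical helper below derives each form from the operator name; its `mode` values are illustrative labels chosen here, not MindSpore options.

```python
def ascend_func_name(op_name: str, mode: str, infer_file: str = "infer_shape.cc") -> str:
    """Build the func string for an Ascend C operator named op_name.

    Illustrative modes:
      "tbe"       -> bare operator name
      "aclnn"     -> name prefixed with "aclnn"
      "cpp_infer" -> "<infer_file>:aclnn<op_name>"
    """
    if mode == "tbe":
        return op_name
    if mode == "aclnn":
        return "aclnn" + op_name
    if mode == "cpp_infer":
        return f"{infer_file}:aclnn{op_name}"
    raise ValueError(f"unknown mode: {mode!r}")
```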

      2. for "julia":

        Currently "julia" supports only the CPU platform (Linux only). Julia uses a JIT compiler and provides a C API for calling Julia code, so Custom can directly launch a user-defined "xxx.jl" file as an operator. Users need to write an "xxx.jl" file containing modules and functions, and provide the path of the file along with a module name and a function name.

        Examples: see details in tests/st/ops/graph_kernel/custom/julia_test_files/

        • Use it in Custom:

          Custom(func="{dir_path}/{file_name}:{module_name}:{func_name}",...)
          (e.g. Custom(func="./add.jl:Add:add", out_shape=[1], out_dtype=mstype.float32, func_type="julia"))
          
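The func strings for "aot" on GPU/CPU and for "julia" are colon-separated. The sketch below splits them with `str.rsplit` from the right, so a colon inside the directory part (such as a Windows drive letter) does not break the split. `parse_func_spec` is a hypothetical helper; MindSpore's own parsing may differ, and it does not cover Ascend names such as "AddCustom" that carry no file path.

```python
def parse_func_spec(spec: str, func_type: str) -> tuple:
    """Split a Custom func string into its parts.

    "aot":   "{dir_path}/{file_name}:{func_name}"               -> (path, func_name)
    "julia": "{dir_path}/{file_name}:{module_name}:{func_name}" -> (path, module, func)
    """
    if func_type == "aot":
        parts = spec.rsplit(":", 1)
        expected = 2
    elif func_type == "julia":
        parts = spec.rsplit(":", 2)
        expected = 3
    else:
        raise ValueError(f"string func is only used with 'aot' or 'julia', got {func_type!r}")
    if len(parts) != expected or not all(parts):
        raise ValueError(f"malformed func string for {func_type!r}: {spec!r}")
    return tuple(parts)
```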

  • out_shape (Union[function, list, tuple]) –

    The output shape inference function, or the value of the output shape of func. Default: None.

    If func has a single output, the value of the output shape is a list or tuple of int.

    If func has multiple outputs, the value of the output shape is a tuple whose items are the shapes of the respective outputs.

    It can be None only when func_type is "hybrid", in which case the output shape is inferred automatically.

  • out_dtype (Union[function, mindspore.dtype, tuple[mindspore.dtype]]) –

    The output data type inference function, or the value of the output data type of func. Default: None.

    If func has a single output, the value of the output data type is a mindspore.dtype.

    If func has multiple outputs, the value of the output data type is a tuple of mindspore.dtype whose items are the data types of the respective outputs.

    It can be None only when func_type is "hybrid", in which case the output data type is inferred automatically.
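When out_shape and out_dtype are functions, the "pyfunc" example in this page suggests they receive one argument per input (that input's shape or data type) and return the output shape(s) or data type(s), as its lambda x, _: (x, x) does. A minimal sketch for an element-wise operator with two same-shaped inputs, assuming that calling convention; the function names are hypothetical. Because the dtype function only passes its argument through, it needs no mindspore import.

```python
def infer_shape_single(x_shape, y_shape):
    """One element-wise output: same shape as the first input."""
    return list(x_shape)


def infer_shape_multi(x_shape, y_shape):
    """Two outputs (e.g. x + y and x - y): a tuple with one shape per output."""
    return (list(x_shape), list(x_shape))


def infer_dtype_multi(x_dtype, y_dtype):
    """Two outputs that keep the first input's data type."""
    return (x_dtype, x_dtype)
```

Under that assumption, ops.Custom(func_multi_output, infer_shape_multi, infer_dtype_multi, "pyfunc") should behave like the lambda-based call in the examples.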

  • func_type (str) –

    The implementation type of func; should be one of [ "hybrid" , "akg" , "aot" , "pyfunc" , "julia" ]. Default: "hybrid".

  • bprop (function) – The back propagation function of func. Default: None.

  • reg_info (Union[str, dict, list, tuple]) –

    The registration information (reg info) of func, in JSON format, given as a str or dict. The reg info specifies the supported data types and formats of the inputs and outputs, the attributes, and the target of func. Default: None.

    If reg info is a list or tuple, each item should be a JSON-format str or dict representing the registration information of func for a specific target. Invoke CustomRegOp or a subclass of RegOp to generate the reg info for func, then either invoke custom_info_register to bind the reg info to func or pass it directly to the reg_info parameter. The reg_info parameter takes priority over custom_info_register, and the reg info for a given target is registered only once.

    If reg info is not set, the data types and formats are inferred from the inputs of the Custom operator.

    Note that if func_type is "tbe", or func supports only specific data types and formats, or it has attribute inputs, you should set the reg info for func.

Inputs:
  • input (Union[tuple, list]) - A tuple or list of tensors, optionally followed by attribute values.

Outputs:

Tensor or tuple[Tensor], execution results.

Raises
  • TypeError – If the type of func is invalid or the type of register information for func is invalid.

  • ValueError – If func_type is invalid.

  • ValueError – If the register information is invalid, for example the target is not supported, or the number of inputs or the attributes of func differ between targets.

Supported Platforms:

GPU CPU ASCEND

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, kernel
>>> from mindspore import dtype as mstype
>>> from mindspore.nn import Cell
>>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
>>> input_y = Tensor(np.ones([16, 16]).astype(np.float32))
>>>
>>> # Example, func_type = "hybrid"
>>> # This is the default func_type in Custom,
>>> # and both out_shape and out_dtype can be None (the default value).
>>> # In this case, the input func must be a function written in the Hybrid DSL
>>> # and decorated with @kernel.
>>> @kernel
... def add_script(a, b):
...     c = output_tensor(a.shape, a.dtype)
...     for i0 in range(a.shape[0]):
...         for i1 in range(a.shape[1]):
...             c[i0, i1] = a[i0, i1] + b[i0, i1]
...     return c
>>>
>>> test_op_hybrid = ops.Custom(add_script)
>>> output = test_op_hybrid(input_x, input_y)
>>> # the result will be a 16 * 16 tensor with all elements 2
>>> print(output.shape)
(16, 16)
>>> # Example, func_type = "aot"
>>> def test_aot(x, y, out_shapes, out_types):
...     program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot")
...     out = program(x, y)
...     return out
>>>
>>> # Example, func_type = "pyfunc"
>>> def func_multi_output(x1, x2):
...     return (x1 + x2), (x1 - x2)
>>>
>>> test_pyfunc = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
>>> output = test_pyfunc(input_x, input_y)
>>>
>>> # Example, func_type = "julia"
>>> # julia code:
>>> # add.jl
>>> # module Add
>>> # function add(x, y, z)
>>> #   z .= x + y
>>> #   return z
>>> # end
>>> # end
>>> def test_julia(x, y, out_shapes, out_types):
...     program = ops.Custom("./add.jl:Add:add", out_shapes, out_types, "julia")
...     out = program(x, y)
...     return out