mindspore.ops.Custom
- class mindspore.ops.Custom(func, bprop=None, out_dtype=None, func_type='hybrid', out_shape=None, reg_info=None)[源代码]
Custom 算子是MindSpore自定义算子的统一接口。用户可以利用该接口自行定义MindSpore内置算子库尚未包含的算子。 根据输入函数的不同,你可以创建多个自定义算子,并且把它们用在神经网络中。 关于自定义算子的详细说明和介绍,包括参数的正确书写,见 自定义算子教程 。
警告
这是一个实验性API,后续可能修改或删除。
说明
不同自定义算子的函数类型(func_type)支持的平台类型不同。每种类型支持的平台如下:
“hybrid”: [“Ascend”, “GPU”, “CPU”].
“akg”: [“Ascend”, “GPU”, “CPU”].
“tbe”: [“Ascend”].
“aot”: [“GPU”, “CPU”].
“pyfunc”: [“CPU”].
“julia”: [“CPU”].
“aicpu”: [“Ascend”].
当运行在ge后端时,通过 CustomRegOp 生成”aicpu”和”tbe”类型的自定义算子的算子信息,通过 custom_info_register 将算子信息绑定到”tbe”类型的自定义算子的 func 上,然后将”aicpu”类型的自定义算子的算子信息以及”tbe”类型的自定义算子的 func 实现保存在一个或多个文件里,并且将这些文件保存在一个单独的目录里,在网络运行前将此目录的绝对路径设置到环境变量”MS_DEV_CUSTOM_OPP_PATH”。
- 参数:
func (Union[function, str]) - 自定义算子的函数表达。
out_shape (Union[function, list, tuple]) - 自定义算子的输入的形状或者输出形状的推导函数。默认值:
None
。out_dtype (Union[function,
mindspore.dtype
, tuple[mindspore.dtype
]]) - 自定义算子的输入的数据类型或者输出数据类型的推导函数。默认值:None
。func_type (str) - 自定义算子的函数类型,必须是[
"hybrid"
,"akg"
,"tbe"
,"aot"
,"pyfunc"
,"julia"
,"aicpu"
]中之一。默认值:"hybrid"
。bprop (function) - 自定义算子的反向函数。默认值:
None
。reg_info (Union[str, dict, list, tuple]) - 自定义算子的算子注册信息。默认值:
None
。
- 输入:
input (Union(tuple, list)) - 输入要计算的Tensor。
- 输出:
Tensor。自定义算子的计算结果。
- 异常:
TypeError - 如果输入 func 不合法,或者 func 对应的注册信息类型不对。
ValueError - func_type 的值不在列表内。
ValueError - 算子注册信息不合法,包括支持平台不匹配,算子输入和属性与函数不匹配。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> from mindspore import Tensor, ops >>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, kernel >>> from mindspore import dtype as mstype >>> from mindspore.nn import Cell >>> input_x = Tensor(np.ones([16, 16]).astype(np.float32)) >>> input_y = Tensor(np.ones([16, 16]).astype(np.float32)) >>> >>> # Example, func_type = "hybrid" >>> # This is the default func_type in Custom, >>> # and both out_shape and out_dtype can be None(default value). >>> # In this case, the input func must be a function written in the Hybrid DSL >>> # and decorated by @kernel. >>> @kernel ... def add_script(a, b): ... c = output_tensor(a.shape, a.dtype) ... for i0 in range(a.shape[0]): ... for i1 in range(a.shape[1]): ... c[i0, i1] = a[i0, i1] + b[i0, i1] ... return c >>> >>> test_op_hybrid = ops.Custom(add_script) >>> output = test_op_hybrid(input_x, input_y) >>> # the result will be a 16 * 16 tensor with all elements 2 >>> print(output.shape) (16, 16) >>> # Example, func_type = "tbe" >>> square_with_bias_op_info = CustomRegOp() \ ... .fusion_type("OPAQUE") \ ... .attr("bias", "required", "float") \ ... .input(0, "x") \ ... .output(0, "y") \ ... .dtype_format(DataType.F32_Default, DataType.F32_Default) \ ... .dtype_format(DataType.F16_Default, DataType.F16_Default) \ ... .target("Ascend") \ ... .get_op_info() >>> >>> @custom_info_register(square_with_bias_op_info) ... def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"): ... import te.lang.cce ... from te import tvm ... from topi.cce import util ... ... shape = input_x.get("shape") ... dtype = input_x.get("dtype").lower() ... ... shape = util.shape_refine(shape) ... data = tvm.placeholder(shape, name="data", dtype=dtype) ... ... with tvm.target.cce(): ... res0 = te.lang.cce.vmul(data, data) ... res = te.lang.cce.vadds(res0, bias) ... sch = te.lang.cce.auto_schedule(res) ... ... config = {"print_ir": False, ... "name": kernel_name, ... "tensor_list": [data, res]} ... ... te.lang.cce.cce_build_code(sch, config) >>> >>> def test_tbe(): ... square_with_bias = ops.Custom(square_with_bias, out_shape=lambda x, _: x, \ ... out_dtype=lambda x, _: x, func_type="tbe") ... res = self.square_with_bias(input_x, 1.0) ... return res >>> >>> # Example, func_type = "aicpu" >>> resize_bilinear_op_info = CustomRegOp("ResizeBilinear") \ ... .fusion_type("OPAQUE") \ ... .input(0, "input", "required") \ ... .output(1, "output", "required") \ ... .attr("align_corners", "required", "bool") \ ... .attr("cust_aicpu", "optional", "str", "aicpu_kernels") \ ... .dtype_format(DataType.F32_Default, DataType.F32_Default) \ ... .dtype_format(DataType.F16_Default, DataType.F32_Default) \ ... .target("Ascend") \ ... .get_op_info() >>> >>> @custom_info_register(resize_bilinear_op_info) ... def resize_bilinear_aicpu(): ... return >>> >>> def test_aicpu(x): ... resize_bilinear_op = ops.Custom(resize_bilinear_aicpu, out_shape=[1, 1, 9, 9], \ ... out_dtype=mstype.float32, func_type="aicpu") ... res = resize_bilinear_op(x, True, "aicpu_kernels") ... return res >>> >>> # Example, func_type = "aot" >>> def test_aot(x, y, out_shapes, out_types): ... program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot") ... out = program(x, y) ... return out >>> >>> # Example, func_type = "pyfunc" >>> def func_multi_output(x1, x2): ... return (x1 + x2), (x1 - x2) >>> >>> test_pyfunc = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc") >>> output = test_pyfunc(input_x, input_y) >>> >>> # Example, func_type = "julia" >>> # julia code: >>> # add.jl >>> # module Add >>> # function add(x, y, z) >>> # z .= x + y >>> # return z >>> # end >>> # end >>> def test_julia(x, y, out_shapes, out_types): ... program = ops.Custom("./add.jl:Add:add", out_shapes, out_types, "julia") ... out = program(x, y) ... return out