mindspore.ops.PrimitiveWithCheck
- class mindspore.ops.PrimitiveWithCheck(name)[源代码]
PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。
可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。
如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。
- 参数:
name (str) - 当前Primitive的名称。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> from mindspore import dtype as mstype >>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck >>> # init a Primitive class with check >>> class Flatten(PrimitiveWithCheck): ... @prim_attr_register ... def __init__(self): ... pass ... def check_shape(self, input_x): ... validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) ... ... def check_dtype(self, input_x): ... validator.check_subclass("input_x", input_x, mstype.tensor, self.name) ... >>> # init a Primitive obj >>> add = Flatten()
- check_dtype(*args)[源代码]
检查输入参数的数据类型。
- 参数:
args (
mindspore.dtype
) - 输入的数据类型。
- 返回:
None。