mindspore.ops.PrimitiveWithCheck

class mindspore.ops.PrimitiveWithCheck(name)[源代码]

PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。

可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。

如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。

参数:
  • name (str) - 当前Primitive的名称。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
>>> # init a Primitive class with check
>>> class Flatten(PrimitiveWithCheck):
...     @prim_attr_register
...     def __init__(self):
...         pass
...     def check_shape(self, input_x):
...         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
...
...     def check_dtype(self, input_x):
...         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
...
>>> # init a Primitive obj
>>> add = Flatten()
check_dtype(*args)[源代码]

检查输入参数的数据类型。

参数:
返回:

None。

check_shape(*args)[源代码]

检查输入参数的shape。

Note

Scalar的shape是一个空元组。

参数:
  • args (tuple(int)) - 输入tensor的shape。

返回:

None。