mindspore.ops.PrimitiveWithCheck

class mindspore.ops.PrimitiveWithCheck(name)

PrimitiveWithCheck is the base class for Python primitives that define their own checks on operator inputs while relying on the infer method registered in the C++ source code.

Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape() and check_dtype(). If the primitive defines __check__(), that method has the highest priority and is the one called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the checks on the input shapes and data types, respectively. As with PrimitiveWithInfer, infer_value() can also be defined for constant propagation.
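
The dispatch order described above can be modeled in plain Python. This is a minimal sketch, not MindSpore's implementation: SketchPrimitiveWithCheck, RankAtLeastOne and CustomCheck are hypothetical names, and representing each input as a dict with "shape" and "dtype" keys is an assumption made for illustration. In the sketch, the default __check__() delegates to check_dtype() and check_shape(), so overriding only check_shape() still takes effect, while overriding __check__() replaces that delegation entirely.

```python
# Pure-Python model of the check dispatch; it does NOT import MindSpore.
# All class names and the {"shape", "dtype"} input dicts are hypothetical.


class SketchPrimitiveWithCheck:
    """Hypothetical stand-in for PrimitiveWithCheck's check dispatch."""

    def __init__(self, name):
        self.name = name

    def check_shape(self, *shapes):
        # Default: accept any shapes.
        return None

    def check_dtype(self, *dtypes):
        # Default: accept any data types.
        return None

    def __check__(self, *args):
        # Default __check__: route dtypes to check_dtype, shapes to check_shape.
        self.check_dtype(*(arg["dtype"] for arg in args))
        self.check_shape(*(arg["shape"] for arg in args))


class RankAtLeastOne(SketchPrimitiveWithCheck):
    # Only check_shape is overridden, so the default __check__ still reaches it.
    def check_shape(self, input_x):
        if len(input_x) < 1:
            raise ValueError(f"For '{self.name}', input_x rank must be >= 1")


class CustomCheck(RankAtLeastOne):
    # An overridden __check__ takes priority: check_shape is never consulted.
    def __check__(self, *args):
        self.seen = [arg["dtype"] for arg in args]


scalar = {"shape": (), "dtype": "float32"}

# The default __check__ passes the scalar's shape () to check_shape, which rejects it.
rank_message = None
try:
    RankAtLeastOne("Flatten").__check__(scalar)
except ValueError as err:
    rank_message = str(err)

# The overridden __check__ runs instead, so the same input is accepted.
custom = CustomCheck("Flatten")
custom.__check__(scalar)
```

Because CustomCheck inherits from RankAtLeastOne, the accepted scalar shows that the overriding __check__() alone decides the outcome in this model.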

Parameters

name (str) – Name of the current Primitive.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import dtype as mstype
>>> # _checkparam is MindSpore's internal validation module (MindSpore 2.x layout)
>>> from mindspore import _checkparam as validator
>>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
>>> # define a Primitive subclass with check methods
>>> class Flatten(PrimitiveWithCheck):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def check_shape(self, input_x):
...         # the input must have rank >= 1
...         validator.check_int(len(input_x), 1, validator.GE, 'input_x rank', self.name)
...
...     def check_dtype(self, input_x):
...         # the input must be a tensor
...         validator.check_subclass("input_x", input_x, mstype.tensor_type, self.name)
...
>>> # instantiate the primitive
>>> flatten = Flatten()

check_dtype(*args)

Check the data types of the input arguments.

Parameters

args (mindspore.dtype) – Data types of the inputs, one per input.

Returns

None.
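
Because check_dtype() returns None, a failed check has to be signaled by raising an exception. The sketch below illustrates that contract for a hypothetical two-input primitive that requires both inputs to share a data type. SameDtypeChecker and the "AddLike" name are made up, plain strings stand in for mindspore.dtype objects so it runs without MindSpore, and raising TypeError is an assumed convention rather than a documented requirement.

```python
# Hypothetical check_dtype for a two-input primitive. Strings stand in for
# mindspore.dtype objects so the sketch runs without MindSpore.


class SameDtypeChecker:
    """Hypothetical primitive-like class; only check_dtype is modeled."""

    name = "AddLike"

    def check_dtype(self, x_dtype, y_dtype):
        # A passing check falls through and returns None; a failing one raises.
        if x_dtype != y_dtype:
            raise TypeError(
                f"For '{self.name}', x and y must have the same dtype, "
                f"but got {x_dtype} and {y_dtype}"
            )


dtype_checker = SameDtypeChecker()
ok_result = dtype_checker.check_dtype("float32", "float32")  # matching dtypes

mismatch_message = None
try:
    dtype_checker.check_dtype("float32", "int32")  # mismatched dtypes
except TypeError as err:
    mismatch_message = str(err)
```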

check_shape(*args)

Check the shapes of the input arguments.

Note

The shape of a scalar is an empty tuple, ().

Parameters

args (tuple(int)) – Shapes of the input tensors, one tuple per input.

Returns

None.
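
The note on scalar shapes matters for rank checks: since a scalar's shape is (), its rank len(()) is 0. The sketch below mirrors the rank >= 1 rule from the Flatten example in plain Python so it runs without MindSpore; FlattenLikeChecker is a hypothetical name, and raising ValueError stands in for whatever error MindSpore's validator helpers would report.

```python
# Hypothetical check_shape mirroring the Flatten example's rank >= 1 rule.
# Shapes are plain tuples of ints; a scalar's shape is the empty tuple ().


class FlattenLikeChecker:
    name = "Flatten"

    def check_shape(self, input_x):
        # len(()) == 0, so a scalar input fails this rank check.
        if len(input_x) < 1:
            raise ValueError(
                f"For '{self.name}', input_x rank must be >= 1, but got {len(input_x)}"
            )


shape_checker = FlattenLikeChecker()
shape_checker.check_shape((4, 3, 2))  # rank 3: passes and returns None

scalar_message = None
try:
    shape_checker.check_shape(())  # scalar shape
except ValueError as err:
    scalar_message = str(err)
```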