mindspore.ops.PrimitiveWithCheck
- class mindspore.ops.PrimitiveWithCheck(name)
PrimitiveWithCheck is the base class of Python primitives. It defines functions that check the input arguments of an operator, but it uses the infer method registered in the C++ source code.
Three methods can be overridden to define the check logic of a primitive: __check__(), check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest priority and is the method that gets called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic for input shapes and data types. The method infer_value() can also be defined (as in PrimitiveWithInfer) for constant propagation.
For more on how to customize an operator, refer to Custom Operators.
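The priority rule above can be modeled in plain Python. This is a simplified sketch, not MindSpore's implementation: MiniPrimitiveWithCheck, RankAtLeastOne and CustomCheck are hypothetical names, and each input is assumed to arrive as a dict with "shape" and "dtype" keys. In this model the default __check__ delegates to check_dtype and then check_shape, while a class that overrides __check__ bypasses both hooks. The real base class may differ in detail.

```python
# Simplified, pure-Python model of the check dispatch (hypothetical; this is
# NOT MindSpore's class). Inputs are modeled as {"shape": ..., "dtype": ...}.


class MiniPrimitiveWithCheck:
    """Stand-in base class whose default __check__ calls the two hooks."""

    def __init__(self, name):
        self.name = name

    def check_shape(self, *shapes):
        # Default hook: accept any shapes.
        return None

    def check_dtype(self, *dtypes):
        # Default hook: accept any data types.
        return None

    def __check__(self, *args):
        # Default dispatch: one dtype and one shape per input, in order.
        self.check_dtype(*(a["dtype"] for a in args))
        self.check_shape(*(a["shape"] for a in args))


class RankAtLeastOne(MiniPrimitiveWithCheck):
    # Only check_shape is overridden, so the default __check__ reaches it.
    def check_shape(self, x_shape):
        if len(x_shape) < 1:
            raise ValueError(f"{self.name}: input_x rank must be >= 1")


class CustomCheck(RankAtLeastOne):
    # __check__ is overridden, so it takes priority; the inherited
    # check_shape is never called through the default dispatch.
    def __check__(self, *args):
        if len(args) != 1:
            raise ValueError(f"{self.name}: expected exactly one input")


scalar = {"shape": [], "dtype": "float32"}  # rank-0 input

try:
    RankAtLeastOne("Flatten").__check__(scalar)
    hook_result = "passed"
except ValueError as err:
    hook_result = str(err)

CustomCheck("Flatten").__check__(scalar)  # passes: check_shape is bypassed
print(hook_result)  # Flatten: input_x rank must be >= 1
```

The hook-based class rejects the rank-0 input, while CustomCheck accepts the same input because its own __check__ never calls check_shape.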
- Parameters
name (str) – Name of the current Primitive.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> from mindspore import dtype as mstype
>>> # _checkparam is MindSpore's internal (private) validation module
>>> from mindspore import _checkparam as validator
>>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
>>> # define a Primitive class with checks
>>> class Flatten(PrimitiveWithCheck):
...     @prim_attr_register
...     def __init__(self):
...         pass
...     def check_shape(self, input_x):
...         validator.check_int(len(input_x), 1, validator.GE, 'input_x rank', self.name)
...     def check_dtype(self, input_x):
...         validator.check_subclass("input_x", input_x, mstype.tensor_type, self.name)
...
>>> # instantiate the Primitive
>>> flatten = Flatten()
- check_dtype(*args)
Check data types of input args.
- Parameters
args (mindspore.dtype) – data types of the inputs.
- Returns
None.
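A check_dtype override receives one data type per input, in input order, returns None when the check passes, and reports a failure by raising. The sketch below illustrates that contract in plain Python under stated assumptions: SameDtypeAdd and ALLOWED are hypothetical, and plain strings stand in for mindspore.dtype values.

```python
# Sketch of a check_dtype(*args) override. Hypothetical: SameDtypeAdd is an
# illustrative class, and strings stand in for mindspore.dtype values.

ALLOWED = {"float16", "float32", "int32"}  # hypothetical supported types


class SameDtypeAdd:
    name = "SameDtypeAdd"

    def check_dtype(self, *args):
        # args holds one data type per input, in input order.
        for i, dtype in enumerate(args):
            if dtype not in ALLOWED:
                raise TypeError(f"{self.name}: input {i} has unsupported dtype {dtype}")
        # All inputs must share a single data type.
        if len(set(args)) > 1:
            raise TypeError(f"{self.name}: inputs must share one dtype, got {args}")
        # Returning None signals that the check passed.
        return None


op = SameDtypeAdd()
ok = op.check_dtype("float32", "float32")
try:
    op.check_dtype("float32", "int32")
    mismatch = None
except TypeError as err:
    mismatch = str(err)

print(ok)        # None
print(mismatch)  # SameDtypeAdd: inputs must share one dtype, got ('float32', 'int32')
```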