mindspore.ops.PrimitiveWithCheck
- class mindspore.ops.PrimitiveWithCheck(name)[source]
PrimitiveWithCheck is the base class for primitives that define their input-argument checks in Python but use the infer method registered in the C++ source code.
Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest priority and is the one called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic for shape and type. The infer_value() method can also be defined (as in PrimitiveWithInfer) for constant propagation.
- Parameters
name (str) – Name of the current Primitive.
Examples
>>> # init a Primitive class with check
>>> class Flatten(PrimitiveWithCheck):
>>>     @prim_attr_register
>>>     def __init__(self):
>>>         pass
>>>     def check_shape(self, input_x):
>>>         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
>>>
>>>     def check_dtype(self, input_x):
>>>         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
>>>
>>> # init a Primitive obj
>>> add = Flatten()
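The priority rule described above (__check__() first, otherwise check_shape() and check_dtype()) can be illustrated without MindSpore installed. The following is a minimal plain-Python sketch of that rule only. The names SketchPrimitiveWithCheck, run_checks, RankAtLeastOne and AlwaysCustom are hypothetical and do not exist in MindSpore, and the __check__() signature used here is an assumption made for illustration. The real dispatch happens inside the framework and may differ in detail.

```python
# Illustrative stand-in only: this is NOT MindSpore code.
class SketchPrimitiveWithCheck:
    """Hypothetical base class mimicking the documented check priority."""

    def __init__(self, name):
        self.name = name

    def check_shape(self, *shapes):
        # Default: accept any shape.
        return None

    def check_dtype(self, *dtypes):
        # Default: accept any dtype.
        return None

    def run_checks(self, shapes, dtypes):
        """Hypothetical dispatcher: prefer __check__ when a subclass defines it."""
        custom = getattr(self, "__check__", None)
        if custom is not None:
            custom(shapes, dtypes)
            return "__check__"
        self.check_shape(*shapes)
        self.check_dtype(*dtypes)
        return "check_shape+check_dtype"


class RankAtLeastOne(SketchPrimitiveWithCheck):
    """Defines only check_shape, so the fallback path is taken."""

    def check_shape(self, input_x):
        if len(input_x) < 1:
            raise ValueError(f"{self.name}: input_x rank must be >= 1")


class AlwaysCustom(SketchPrimitiveWithCheck):
    """Defines __check__, so check_shape is never consulted."""

    def __check__(self, shapes, dtypes):
        self.seen = (shapes, dtypes)

    def check_shape(self, *shapes):
        raise AssertionError("not reached while __check__ is defined")


flatten_like = RankAtLeastOne("Flatten")
path_taken = flatten_like.run_checks([(2, 3)], ["float32"])

custom = AlwaysCustom("Custom")
custom_path = custom.run_checks([(2,)], ["int32"])
```

In this sketch, path_taken records that the shape and dtype checks ran, while custom_path records that __check__() took over and the overridden check_shape() was skipped.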
- check_dtype(*args)[source]
Check data types of input args.
- Parameters
args (mindspore.dtype) – Data type of inputs.
- Returns
None.